# igtr_decoder.py

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class CrossAttention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
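
# Usage sketch (added comment; shapes are illustrative, not from the original
# source): with dim=512 and num_heads=8 the cross-attention maps a (B, Lq, 512)
# query stream against a (B, Lkv, 512) key/value stream and returns
# (B, Lq, 512), e.g.
#
#     attn = CrossAttention(dim=512, num_heads=8, qkv_bias=True)
#     out = attn(torch.zeros(2, 26, 512), torch.zeros(2, 50, 512))
#
# An optional additive `key_mask` broadcastable to (B, Lq, Lkv) (typically
# (B, 1, Lkv)) can carry -inf entries to block padded key positions.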


class EdgeDecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        # self.c = nn.Linear(dim, dim*2)
        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x


class DecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.normkv = eval(norm_layer)(dim, eps=epsilon)
        self.mixer = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = q + self.drop_path(
            self.mixer(self.norm1(q), self.normkv(kv), key_mask))
        x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
        return x


class CMFFLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        epsilon=1e-6,
    ):
        super().__init__()
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        self.question_to_images_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.normmlp = nn.LayerNorm(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        query_add = torch.concat([question_f, prompt_f, visual_f], 1)

        query_add = query_add + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(query_add),
                                               self.normkv1(prompt_f), mask))
        query_add = query_add + self.drop_path(
            self.question_to_images_cross_attn(
                self.normq2(query_add),
                self.normkv2(query_add[:, -visual_f.shape[1]:, :])))
        query_updated = query_add + self.drop_path(
            self.mlp(self.normmlp(query_add)))

        question_f_updated = query_updated[:, :question_f.shape[1], :]
        prompt_f_updated = query_updated[:, question_f.shape[1]:-visual_f.shape[1], :]
        visual_f_updated = query_updated[:, -visual_f.shape[1]:, :]

        return question_f_updated, prompt_f_updated, visual_f_updated
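
# Note on CMFFLayer (added comment; lengths are illustrative): the three input
# streams are concatenated along the sequence axis, so question_f of length Lq,
# prompt_f of length Lp and visual_f of length Lv form a fused query of length
# Lq + Lp + Lv. After prompt- and image-conditioned cross-attention plus the
# MLP, the tensor is sliced back into the same three spans, so the layer
# updates all three streams jointly while preserving their lengths.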


class IGTRDecoder(nn.Module):
    """
    IGTRDecoder is a neural network module designed for decoding tasks in OCR (Optical Character Recognition) systems.
    It utilizes a combination of embedding layers, multi-head attention layers, and linear layers to process input
    sequences and generate output sequences.

    Args:
        in_channels (int): Number of input channels.
        dim (int): Dimension of the model.
        out_channels (int): Number of output channels.
        num_layer (int, optional): Number of layers in the decoder. Default is 2.
        drop_path_rate (float, optional): Drop path rate for stochastic depth. Default is 0.1.
        max_len (int, optional): Maximum length of the sequence. Default is 25.
        vis_seq (int, optional): Length of the visual sequence. Default is 50.
        ch (bool, optional): Flag for character embedding. Default is False.
        ar (bool, optional): Flag for autoregressive decoding. Default is False.
        refine_iter (int, optional): Number of refinement iterations. Default is 0.
        quesall (bool, optional): Flag to use all questions. Default is True.
        next_pred (bool, optional): Flag for next prediction. Default is False.
        ds (bool, optional): Flag for downsampling. Default is False.
        pos2d (bool, optional): Flag for 2D positional embedding. Default is False.
        check_search (bool, optional): Flag for checking search. Default is False.
        max_size (list, optional): Maximum size for 2D positional embedding. Default is [8, 32].
        **kwargs: Additional keyword arguments.

    Methods:
        _init_weights(m): Initializes the weights of the module.
        no_weight_decay(): Returns the parameters that should not have weight decay.
        question_encoder(targets, train_i): Encodes the questions based on the targets and training index.
        forward(x, data=None): Forward pass of the decoder. Calls either forward_train or forward_test based on the mode.
        forward_test(x): Forward pass during testing.
        forward_train(x, targets=None): Forward pass during training.

    Returns:
        Depending on the mode (training or testing), the forward method returns either the loss and logits (during
        training) or the predicted indices and probabilities (during testing).
    """

    def __init__(self,
                 in_channels,
                 dim,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ch=False,
                 ar=False,
                 refine_iter=0,
                 quesall=True,
                 next_pred=False,
                 ds=False,
                 pos2d=False,
                 check_search=False,
                 max_size=[8, 32],
                 **kwargs):
        super(IGTRDecoder, self).__init__()

        self.out_channels = out_channels
        self.dim = dim
        self.max_len = max_len + 3  # max_len + eos + bos
        self.ch = ch
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
        self.ignore_index = out_channels - 1
        self.ar = ar
        self.refine_iter = refine_iter
        self.bos = self.out_channels - 2
        self.eos = 0
        self.next_pred = next_pred
        self.quesall = quesall
        self.check_search = check_search
        dpr = np.linspace(0, drop_path_rate, num_layer + 2)

        self.cmff_decoder = nn.ModuleList([
            CMFFLayer(dim=dim,
                      num_heads=dim // 32,
                      mlp_ratio=4.0,
                      qkv_bias=True,
                      drop_path=dpr[i]) for i in range(num_layer)
        ])

        self.answer_to_question_layer = DecoderLayer(dim=dim,
                                                     num_heads=dim // 32,
                                                     mlp_ratio=4.0,
                                                     qkv_bias=True,
                                                     drop_path=dpr[-2])
        self.answer_to_image_layer = DecoderLayer(dim=dim,
                                                  num_heads=dim // 32,
                                                  mlp_ratio=4.0,
                                                  qkv_bias=True,
                                                  drop_path=dpr[-1])

        self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)
        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.vis_pos_embed = nn.Parameter(
                pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
                                                  max_size[1]),
                requires_grad=True,
            )
        self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)

        self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
                                                     dtype=torch.float32),
                                         requires_grad=True)
        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 2)
        self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
        self.ques3_head = nn.Linear(dim, self.max_len - 1)
        self.ques4_head = nn.Linear(dim, self.max_len - 1)
        trunc_normal_(self.char_pos_embed, std=0.02)
        trunc_normal_(self.appear_num_embed, std=0.02)
        trunc_normal_(self.answer_query, std=0.02)
        trunc_normal_(self.prompt_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
            'answer_query', 'char_embed'
        }

    def question_encoder(self, targets, train_i):
        (
            prompt_pos_idx,
            prompt_char_idx,
            ques_pos_idx,
            ques1_answer,
            ques2_char_idx,
            ques2_answer,
            ques4_char_num,
            ques_len,
            ques2_len,
            prompt_len,
        ) = targets
        max_ques_len = torch.max(ques_len)
        max_ques2_len = torch.max(ques2_len)
        max_prompt_len = torch.max(prompt_len)
        if self.next_pred and (train_i == 2 or train_i == 3):
            prompt_pos = self.prompt_pos_embed
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        else:
            prompt_pos = F.embedding(
                prompt_pos_idx[:, :max_prompt_len],
                self.char_pos_embed)  # bs lp [0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        prompt_char = self.char_embed(prompt_char_idx)  # bs lp
        prompt = prompt_pos + prompt_char

        mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
                                float('-inf'), 0)
        ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
                            self.char_pos_embed)  # bs lq1 dim
        ques1_answer = ques1_answer[:, :max_ques_len]
        if self.quesall or train_i == 0:
            ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
            ques2 = ques2_char + F.embedding(
                ques2_char_idx[:, :max_ques2_len, 0],
                self.char_pos_embed)  # bs lq2 dim
            ques2_answer = ques2_answer[:, :max_ques2_len]
            ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
                                     self.ques2_head.weight)
            ques4_char = self.char_embed(ques1_answer)
            ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
                                       self.appear_num_embed)
            ques4 = ques4_char + ques4_ap_num
            ques4_answer = ques_pos_idx[:, :max_ques_len]

            return (
                prompt,
                ques1,
                ques2,
                ques2_head,
                ques4,
                ques1_answer,
                ques2_answer,
                ques4_answer,
                mask_1234.unsqueeze(1),
            )
        else:
            return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)

    def forward_test(self, x):
        """
        Perform the forward pass for the test phase.

        Args:
            x (torch.Tensor): Visual feature tensor of shape (batch_size, seq_len, dim),
                or (batch_size, dim, height, width) when 2D positional embeddings are used.

        Returns:
            torch.Tensor or List[torch.Tensor]: The output logits, or a list containing
            predicted indices and probabilities.

        The function handles different modes of operation based on the attributes:
        - `self.ds`: Determines if positional embedding is added to the input tensor.
        - `self.pos2d`: Determines if the positional embedding is 2D.
        - `self.ar`: Determines if autoregressive decoding is used.
        - `self.check_search`: Determines if check-based greedy search decoding is used.
        - `self.next_pred`: Determines if next token prediction is used.
        - `self.refine_iter`: Number of refinement iterations for the predictions.

        The function performs the following steps:
        1. Adds positional embeddings to the input tensor if required.
        2. Initializes the BOS (beginning of sequence) prompt.
        3. Depending on the mode, performs decoding using different strategies:
           - Check-based greedy search decoding.
           - Autoregressive decoding.
           - Next token prediction.
        4. If refinement iterations are specified, refines the predictions.
        5. Returns the final logits or the predicted indices and probabilities.
        """
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        bs = x.shape[0]

        prompt_bos = self.char_embed(
            torch.full([bs, 1], self.bos, dtype=torch.long,
                       device=x.get_device())
        ) + self.char_pos_embed[:1, :].unsqueeze(0)  # BOS prompt
        ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))

        if not self.ar:
            if self.check_search:
                tgt_in = torch.full((bs, self.max_len),
                                    self.ignore_index,
                                    dtype=torch.long,
                                    device=x.get_device())
                tgt_in[:, 0] = self.bos
                logits = []
                for j in range(1, self.max_len):
                    visual_f_check = visual_f
                    ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
                        torch.arange(self.out_channels - 2,
                                     device=x.get_device())).unsqueeze(0)
                    prompt_check = ques_all[:, :j] + self.char_embed(
                        tgt_in[:, :j])
                    # prompt_check = prompt_bos
                    mask = torch.where(
                        (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                        float('-inf'), 0)
                    for layer in self.cmff_decoder:
                        ques_check_i, prompt_check, visual_f_check = layer(
                            ques_check_i, prompt_check, visual_f_check,
                            mask.unsqueeze(1))
                    answer_query_i = self.answer_to_question_layer(
                        ques_check_i, prompt_check, mask.unsqueeze(1))
                    answer_pred_i = self.norm_pred(
                        self.answer_to_image_layer(
                            answer_query_i, visual_f_check))  # B, 26, 37
                    # the next token probability is in the output's ith token position
                    fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
                    fc_2 = fc_2.tile([bs, 1, 1])
                    p_i = fc_2 @ answer_pred_i.transpose(1, 2)
                    # p_i = p_i[:, 0, :]
                    logits.append(p_i)
                    if j < self.max_len - 1:
                        # greedy decode. add the next token index to the target input
                        tgt_in[:, j] = p_i.squeeze().argmax(-1)
                        # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                        if (tgt_in == self.eos).any(dim=-1).all():
                            break
                logits = torch.cat(logits, dim=1)
            else:
                ques_pd = ques_all[:, 1:, :]
                prompt_pd = prompt_bos
                visual_f_pd = visual_f
                for layer in self.cmff_decoder:
                    ques_pd, prompt_pd, visual_f_pd = layer(
                        ques_pd, prompt_pd, visual_f_pd)
                answer_query_pd = self.answer_to_question_layer(
                    ques_pd, prompt_pd)
                answer_feats_pd = self.norm_pred(
                    self.answer_to_image_layer(answer_query_pd,
                                               visual_f_pd))  # B, 26, 37
                logits = self.ques1_head(answer_feats_pd)
        elif self.next_pred:
            ques_pd_1 = ques_all[:, 1:2, :]
            prompt_pd = prompt_bos
            visual_f_pd = visual_f
            for layer in self.cmff_decoder:
                ques_pd_1, prompt_pd, visual_f_pd = layer(
                    ques_pd_1, prompt_pd, visual_f_pd)
            answer_query_pd = self.answer_to_question_layer(
                ques_pd_1, prompt_pd)
            answer_feats_pd = self.norm_pred(
                self.answer_to_image_layer(answer_query_pd,
                                           visual_f_pd))  # B, 26, 37
            logits_pd_1 = self.ques1_head(answer_feats_pd)

            ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
                [bs, 1, 1])
            prompt_next_bos = (self.char_embed(
                torch.full([bs, 1], self.bos, dtype=torch.long,
                           device=x.get_device())) +
                               self.prompt_pos_embed[:, :1, :])
            pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
            pred_prob_list = [pred_prob]
            pred_id_list = [pred_id]
            for j in range(1, 70):
                prompt_next_1 = self.char_embed(
                    pred_id) + self.prompt_pos_embed[:, -1 * pred_id.shape[1]:, :]
                prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                for layer in self.cmff_decoder:
                    ques_next_i, prompt_next, visual_f_pd = layer(
                        ques_next_i, prompt_next, visual_f_i)
                answer_query_next_i = self.answer_to_question_layer(
                    ques_next_i, prompt_next)
                answer_feats_next_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_next_i,
                                               visual_f_i))  # B, 26, 37
                logits_next_i = self.ques1_head(answer_feats_next_i)
                # pred_id = logits_next_i.argmax(-1)
                pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
                pred_prob_list.append(pred_prob_i)
                pred_id_list.append(pred_id_i)
                if (torch.concat(pred_id_list, 1) == self.eos).any(dim=-1).all():
                    break
                if pred_id.shape[1] >= 5:
                    pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
                else:
                    pred_id = torch.concat([pred_id, pred_id_i], 1)
            return [
                torch.concat(pred_id_list, 1),
                torch.concat(pred_prob_list, 1)
            ]
        else:
            tgt_in = torch.full((bs, self.max_len),
                                self.ignore_index,
                                dtype=torch.long,
                                device=x.get_device())
            tgt_in[:, 0] = self.bos
            logits = []
            for j in range(1, self.max_len):
                visual_f_ar = visual_f
                ques_i = ques_all[:, j:j + 1, :]
                prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
                mask = torch.where(
                    (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                    float('-inf'), 0)
                for layer in self.cmff_decoder:
                    ques_i, prompt_ar, visual_f_ar = layer(
                        ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
                answer_query_i = self.answer_to_question_layer(
                    ques_i, prompt_ar, mask.unsqueeze(1))
                answer_pred_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_i,
                                               visual_f_ar))  # B, 26, 37
                # the next token probability is in the output's ith token position
                p_i = self.ques1_head(answer_pred_i)
                logits.append(p_i)
                if j < self.max_len - 1:
                    # greedy decode. add the next token index to the target input
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if (tgt_in == self.eos).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)

        if self.refine_iter > 0:
            pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
            for i in range(self.refine_iter):
                mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1

                ques_check_all = self.char_embed(
                    pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                prompt_check = prompt_bos
                visual_f_check = visual_f
                ques_check = ques_check_all
                for layer in self.cmff_decoder:
                    ques_check, prompt_check, visual_f_check = layer(
                        ques_check, prompt_check, visual_f_check)
                answer_query_check = self.answer_to_question_layer(
                    ques_check, prompt_check)
                answer_pred_check = self.norm_pred(
                    self.answer_to_image_layer(answer_query_check,
                                               visual_f_check))  # B, 26, 37

                ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] + 1, :]
                ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
                answer2_pred = answer_pred_check.matmul(
                    ques2_head.transpose(1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.get_device()).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = F.sigmoid(
                    (answer2_pred * diag_mask).sum(-1)) * mask_check
                check_result = answer2_pred < 0.9  # pred_probs < 0.99

                prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
                mask_refine = torch.where(
                    check_result, float('-inf'), 0) + torch.where(
                        (pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
                        float('-inf'))
                mask_refine = torch.concat(
                    [torch.zeros([bs, 1], device=x.get_device()), mask_refine],
                    1).unsqueeze(1)
                ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                visual_f_refine = visual_f
                for layer in self.cmff_decoder:
                    ques_refine, prompt_refine, visual_f_refine = layer(
                        ques_refine, prompt_refine, visual_f_refine,
                        mask_refine)
                answer_query_refine = self.answer_to_question_layer(
                    ques_refine, prompt_refine, mask_refine)
                answer_pred_refine = self.norm_pred(
                    self.answer_to_image_layer(answer_query_refine,
                                               visual_f_refine))  # B, 26, 37
                answer_refine = self.ques1_head(answer_pred_refine)
                refine_probs, refine_idxs = F.softmax(answer_refine, -1).max(-1)
                pred_idxs_refine = torch.where(check_result, refine_idxs,
                                               pred_idxs)
                pred_idxs = torch.where(mask_check, pred_idxs_refine, pred_idxs)
                pred_probs_refine = torch.where(check_result, refine_probs,
                                                pred_probs)
                pred_probs = torch.where(mask_check, pred_probs_refine,
                                         pred_probs)
            return [pred_idxs, pred_probs]

        return F.softmax(logits, -1)

    def forward_train(self, x, targets=None):
        """
        Forward pass for training the model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ...).
            targets (list, optional): List of target tensors. The list should contain:
                - targets[1]: Tensor of shape (batch_size, ...), prompt position indices.
                - targets[2]: Tensor of shape (batch_size, ...), prompt character indices.
                - targets[3]: Tensor of shape (batch_size, ...), question position indices.
                - targets[4]: Tensor of shape (batch_size, ...), question 1 answers.
                - targets[5]: Tensor of shape (batch_size, ...), question 2 character indices.
                - targets[6]: Tensor of shape (batch_size, ...), question 2 answers.
                - targets[7]: Tensor of shape (batch_size, ..., 2), question 3 character indices and answers.
                - targets[8]: Tensor of shape (batch_size, ...), question 4 character numbers.
                - targets[9]: Tensor of shape (batch_size, ...), question lengths.
                - targets[10]: Tensor of shape (batch_size, ...), prompt lengths.
                - targets[11]: Tensor of shape (batch_size, ...), question 4 answers.

        Returns:
            list: A list containing:
                - loss (dict): Dictionary containing the total loss and individual losses for each question.
                    - 'loss': Total loss.
                    - 'loss1': Loss for question 1.
                    - 'loss2': Loss for question 2.
                    - 'loss3': Loss for question 3.
                    - 'loss4': Loss for question 4.
                - logits (torch.Tensor): Logits for question 1 predictions.
        """
        bs = x.shape[0]
        answer_token = torch.tile(self.answer_query, (bs, 1, 1))
        if self.ch:
            ques3 = self.char_embed(targets[7][:, :, 0]) + answer_token  # bs nc dim
            ques3_answer = targets[7][:, :, 1]
        else:
            ques3 = self.char_embed(
                torch.arange(self.out_channels - 2, device=x.get_device())
            ).unsqueeze(0) + answer_token  # bs nc dim
            ques3_answer = targets[7]
        loss1_list = []
        loss2_list = []
        loss3_list = []
        loss4_list = []
        sampler1_num = 0
        sampler2_num = 0
        sampler3_num = 0
        sampler4_num = 0

        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x

        train_i = 0
        for target_ in zip(
                targets[1].transpose(0, 1),
                targets[2].transpose(0, 1),
                targets[3].transpose(0, 1),
                targets[4].transpose(0, 1),
                targets[5].transpose(0, 1),
                targets[6].transpose(0, 1),
                targets[8].transpose(0, 1),
                targets[9].transpose(0, 1),
                targets[10].transpose(0, 1),
                targets[11].transpose(0, 1),
        ):
            # target_ = [prompt_pos_idx, prompt_char_idx, ques_pos_idx,
            #            ques1_answer, ques2_char_idx, ques2_answer,
            #            ques4_char_num, ques_len, ques2_len, prompt_len]
            visual_f_1234 = visual_f
            if self.quesall or train_i == 0:
                (
                    prompt,
                    ques1,
                    ques2,
                    ques2_head,
                    ques4,
                    ques1_answer,
                    ques2_answer,
                    ques4_answer,
                    mask_1234,
                ) = self.question_encoder(target_, train_i)
                prompt_1234 = prompt
                ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
                for layer in self.cmff_decoder:
                    ques_1234, prompt_1234, visual_f_1234 = layer(
                        ques_1234, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1234 = self.answer_to_question_layer(
                    ques_1234, prompt_1234, mask_1234)
                answer_feats_1234 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1234,
                                               visual_f_1234))  # B, 26, 37
                answer_feats_1 = answer_feats_1234[:, :ques1.shape[1], :]
                answer_feats_2 = answer_feats_1234[:, ques1.shape[1]:(
                    ques1.shape[1] + ques2.shape[1]), :]
                answer_feats_3 = answer_feats_1234[:, (
                    ques1.shape[1] + ques2.shape[1]):-ques4.shape[1], :]
                answer_feats_4 = answer_feats_1234[:, -ques4.shape[1]:, :]

                answer1_pred = self.ques1_head(answer_feats_1)
                if train_i == 0:
                    logits = answer1_pred

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)

                answer2_pred = answer_feats_2.matmul(ques2_head.transpose(1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.get_device()).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = (answer2_pred * diag_mask).sum(-1)
                ques2_answer = ques2_answer.flatten(0, 1)
                non_pad_mask = torch.not_equal(ques2_answer, self.ignore_index)
                n = non_pad_mask.sum().item()
                ques2_answer = torch.where(ques2_answer == self.ignore_index,
                                           0, ques2_answer)
                loss2_none = F.binary_cross_entropy_with_logits(
                    answer2_pred.flatten(0, 1), ques2_answer, reduction='none')
                loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
                sampler2_num += n
                loss2_list.append(loss2)

                answer3_pred = self.ques3_head(answer_feats_3)
                n = (ques3_answer != self.ignore_index).sum().item()
                loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
                                            ques3_answer.flatten(0, 1),
                                            reduction='mean')
                sampler3_num += n
                loss3_list.append(loss3)

                answer4_pred = self.ques4_head(answer_feats_4)
                n = (ques4_answer != self.max_len - 1).sum().item()
                loss4 = n * F.cross_entropy(
                    answer4_pred.flatten(0, 1),
                    ques4_answer.flatten(0, 1),
                    ignore_index=self.max_len - 1,
                    reduction='mean',
                )
                sampler4_num += n
                loss4_list.append(loss4)
            else:
                prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
                    target_, train_i)
                prompt_1234 = prompt
                for layer in self.cmff_decoder:
                    ques1, prompt_1234, visual_f_1234 = layer(
                        ques1, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1 = self.answer_to_question_layer(
                    ques1, prompt_1234, mask_1234)
                answer_feats_1 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1,
                                               visual_f_1234))  # B, 26, 37
                answer1_pred = self.ques1_head(answer_feats_1)

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)
            train_i += 1

        loss_list = [
            sum(loss1_list) / sampler1_num,
            sum(loss2_list) / sampler2_num,
            sum(loss3_list) / sampler3_num,
            sum(loss4_list) / sampler4_num,
        ]
        loss = {
            'loss': sum(loss_list),
            'loss1': loss_list[0],
            'loss2': loss_list[1],
            'loss3': loss_list[2],
            'loss4': loss_list[3],
        }
        return [loss, logits]
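

# Minimal smoke-test sketch (added; not part of the original module). It
# assumes the `openrec` package from OpenOCR is importable and that a CUDA
# device is available, since the decoder calls `x.get_device()` on its inputs.
# The sizes used here (dim=512, out_channels=39, 2x50 visual tokens) are
# illustrative only.
if __name__ == '__main__':
    if torch.cuda.is_available():
        decoder = IGTRDecoder(in_channels=512, dim=512, out_channels=39).cuda()
        decoder.eval()
        # Visual features as produced by an encoder: (batch, vis_seq, dim).
        feats = torch.randn(2, 50, 512, device='cuda')
        with torch.no_grad():
            probs = decoder(feats)
        # The default parallel decoding path returns per-position character
        # probabilities of shape (batch, max_len + 2, out_channels - 2).
        print(probs.shape)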