# smtr_decoder.py

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class CrossAttention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        # q: [B, QN, C] query tokens; kv: [B, N, C] key/value features
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            # additive mask, broadcast over the head dimension
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        if not self.training:
            # keep the attention map for visualization at inference time
            self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
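

# Minimal usage sketch for CrossAttention (illustrative; the sizes below are
# assumptions, not values taken from any SMTR configuration):
#   attn = CrossAttention(dim=64, num_heads=2)
#   q = torch.randn(2, 1, 64)    # [batch, num_queries, dim]
#   kv = torch.randn(2, 50, 64)  # [batch, seq_len, dim]
#   out = attn(q, kv)            # -> [2, 1, 64]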


class SSMatchLayer(nn.Module):

    def __init__(
        self,
        dim,
        nextq2subs_head2=None,
        dynq2img_heads=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        epsilon=1e-6,
        is_last_layer=False,
    ):
        super().__init__()
        self.dim = dim
        if nextq2subs_head2 is None:
            nextq2subs_head2 = dim // 32
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=nextq2subs_head2,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        self.question_to_images_cross_attn = CrossAttention(
            dim,
            num_heads=dynq2img_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.is_last_layer = is_last_layer

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        # 1) enrich the query token with the embedded sub-string prompt
        question_f = question_f + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(question_f),
                                               self.normkv1(prompt_f), mask))
        # 2) regroup the queries per sample and attend to the visual features
        question_f = question_f.reshape(visual_f.shape[0], -1, self.dim)
        question_f = self.question_to_images_cross_attn(
            self.normq2(question_f), self.normkv2(visual_f))
        if self.is_last_layer:
            return question_f
        return question_f.flatten(0, 1).unsqueeze(1)
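

# Shape walk-through for SSMatchLayer.forward as driven by SMTRDecoder below
# (B = batch, n = number of query positions, L = visual length; these symbol
# names are ours, not from the original code):
#   question_f: [B*n, 1, dim]                one query token per position
#   prompt_f:   [B*n, sub_str_len + 1, dim]  embedded sub-string prompt
#   visual_f:   [B, L, dim]                  visual features of each sample
# Step 1 attends each query to its prompt; step 2 regroups the queries to
# [B, n, dim] and attends them to the visual features. The last layer returns
# [B, n, dim]; earlier layers return [B*n, 1, dim] for the next layer.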


class SMTRDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 nextq2subs_head2=None,
                 dynq2img_heads=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ds=False,
                 pos2d=False,
                 max_size=[8, 32],
                 sub_str_len=5,
                 next_mode=True,
                 infer_aug=False,
                 bi_attn=False,
                 **kwargs):
        super(SMTRDecoder, self).__init__()
        self.out_channels = out_channels
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 3  # max_len + eos + bos
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
        # special ids at the top of the vocabulary: the last index is ignored
        # by the loss, and two bos ids mark the next/previous decoding prompts
        self.ignore_index = out_channels - 1
        self.sub_str_len = sub_str_len
        self.bos_next = out_channels - 3
        self.bos_pre = out_channels - 2
        self.eos = 0
        dpr = np.linspace(0, drop_path_rate, num_layer + 2)

        self.next_mode = next_mode
        self.infer_aug = infer_aug
        self.bi_attn = bi_attn
        self.cmff_decoder = nn.ModuleList([
            SSMatchLayer(dim=dim,
                         nextq2subs_head2=nextq2subs_head2,
                         dynq2img_heads=dynq2img_heads,
                         mlp_ratio=4.0,
                         qkv_bias=True,
                         drop_path=dpr[i],
                         is_last_layer=(i == num_layer - 1))
            for i in range(num_layer)
        ])

        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            # 1D positional embedding over a fixed-length visual sequence
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            # 2D positional embedding for feature maps of varying spatial size
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.vis_pos_embed = nn.Parameter(pos_embed.transpose(
                1, 2).reshape(1, dim, max_size[0], max_size[1]),
                                              requires_grad=True)

        # learnable query tokens and prompt embeddings for the two directions
        self.next_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                   dtype=torch.float32),
                                       requires_grad=True)
        self.pre_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        self.prompt_next_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                              requires_grad=True)
        self.prompt_pre_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                             requires_grad=True)

        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 3)
        trunc_normal_(self.next_token, std=0.02)
        trunc_normal_(self.pre_token, std=0.02)
        trunc_normal_(self.prompt_pre_embed, std=0.02)
        trunc_normal_(self.prompt_next_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'vis_pos_embed', 'pre_token', 'next_token', 'char_embed'}

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            if self.infer_aug:
                if self.bi_attn:
                    return self.forward_test_bi_attn(x)
                return self.forward_test_bi(x)
            return self.forward_test(x)

    def forward_test_bi(self, x):
        # self.attn_maps = []
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = visual_f.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, H*W, C]
        else:
            visual_f = x
        # bi-directional inference: the batch is expected to hold three feature
        # sequences, [0]/[1] drive the next/pre decoding, [2] a second next pass
        bs = 2
        if 1:
            next = self.next_token
            pre = self.pre_token
            next_pre = torch.concat([next, pre], 0)
            next_pre = next_pre.squeeze(1)  # 2, 1, dim
            prompt_next_embed = self.prompt_next_embed.squeeze(1)
            prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
            next_id = torch.full([1, self.sub_str_len],
                                 self.bos_next,
                                 dtype=torch.long,
                                 device=x.device)
            pre_id = torch.full([1, self.sub_str_len],
                                self.bos_pre,
                                dtype=torch.long,
                                device=x.device)
            # prompt_next_bos = self.char_embed(prompt_id)
            # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device)
            next_pred_id_list = torch.full([1, self.max_len],
                                           self.ignore_index,
                                           dtype=torch.long,
                                           device=x.device)
            pre_pred_id_list = torch.full([1, self.max_len],
                                          self.ignore_index,
                                          dtype=torch.long,
                                          device=x.device)
            next_logits_all = []
            pre_logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.device)
            for j in range(0, min(70, self.max_len - 1)):
                prompt_char_next = torch.concat([
                    prompt_next_embed[:, :1, :],
                    prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
                ], 1)  # b, sub_l, dim
                prompt_char_pre = torch.concat([
                    prompt_pre_embed[:, :1, :],
                    prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
                ], 1)  # b, sub_l, dim
                prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
                                           0)  # 2, 6, dim
                # prompt_char = prompt_char.flatten(0, 1)
                mask_next = torch.where(next_id == self.bos_next,
                                        float('-inf'), 0)  # b, subs_l
                mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
                                       0)  # b, subs_l
                mask = torch.concat([mask_next, mask_pre], 0)  # 2, 5
                mask = torch.concat([mask_pad, mask], 1)  # 2, 6
                pred_token = next_pre
                visual_f_i = visual_f[:2]  # 2, l, dim
                for layer in self.cmff_decoder:
                    pred_token = layer(pred_token, prompt_char, visual_f_i,
                                       mask.unsqueeze(1))
                logits_next_i = self.ques1_head(self.norm_pred(pred_token))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)  # 2, 1
                # print(pred_id_i.shape)
                next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
                pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
                if not (next_pred_id_list == self.eos).any(dim=-1).all():
                    next_logits_all.append(logits[:1])
                    next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
                if not (pre_pred_id_list == self.eos).any(dim=-1).all():
                    pre_logits_all.append(logits[1:2])
                    pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
                if (next_pred_id_list == self.eos).any(dim=-1).all() and (
                        pre_pred_id_list == self.eos).any(dim=-1).all():
                    break
            # print(next_id, pre_id)
            # exit(0)
            if len(next_logits_all) > self.sub_str_len and len(
                    pre_logits_all) > self.sub_str_len:
                # both directions produced enough characters: run a second next
                # pass on the third feature sequence and merge the results
                next_logits_all_ = torch.concat(next_logits_all[:-1],
                                                1)  # 1, l
                pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
                                               1)  # 1, l
                next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
                pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
                next_logits_all = []
                ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
                mask_pad = torch.zeros([1, 1],
                                       dtype=torch.float32,
                                       device=x.device)
                for j in range(0, min(70, self.max_len - 1)):
                    prompt_next = torch.concat([
                        prompt_next_embed[:, :1, :],
                        prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
                    ], 1)  # b, sub_l, dim
                    mask_next = torch.where(next_id == self.bos_next,
                                            float('-inf'), 0)  # b, subs_l
                    mask = torch.concat([mask_pad, mask_next], 1)
                    # prompt_next = self.char_embed(prompt_id)
                    ques_next_i = ques_next
                    visual_f_i = visual_f[2:3]
                    for layer in self.cmff_decoder:
                        ques_next_i = layer(ques_next_i, prompt_next,
                                            visual_f_i, mask.unsqueeze(1))
                    logits_next_i = self.ques1_head(
                        self.norm_pred(ques_next_i))
                    logits = F.softmax(logits_next_i, -1)
                    pred_id_i = logits.argmax(-1)
                    next_logits_all.append(logits)
                    next_id = torch.concat([next_id[:, 1:], pred_id_i], 1)
                    if next_id.equal(pre_id):
                        break
                next_logits_all = torch.concat(next_logits_all, 1)
                next_logits_all_ = torch.concat(
                    [next_logits_all_, next_logits_all], 1)
                return torch.concat(
                    [next_logits_all_, pre_logits_all_[:, self.sub_str_len:]],
                    1)
            else:
                return torch.concat(next_logits_all + pre_logits_all[::-1], 1)

    def forward_test_bi_attn(self, x):
        # same as forward_test_bi, but also records the attention maps of the
        # last decoder layer for visualization
        self.attn_maps = []
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = visual_f.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, H*W, C]
        else:
            visual_f = x
        bs = 2
        if 1:
            next = self.next_token
            pre = self.pre_token
            next_pre = torch.concat([next, pre], 0)
            next_pre = next_pre.squeeze(1)  # 2, 1, dim
            prompt_next_embed = self.prompt_next_embed.squeeze(1)
            prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
            next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.device)
            pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.device)
            # prompt_next_bos = self.char_embed(prompt_id)
            # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device)
            next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device)
            pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device)
            next_logits_all = []
            pre_logits_all = []
            attn_map_next = []
            attn_map_pre = []
            mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.device)
            for j in range(0, min(70, self.max_len - 1)):
                prompt_char_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1)  # b, sub_l, dim
                prompt_char_pre = torch.concat([prompt_pre_embed[:, :1, :], prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)], 1)  # b, sub_l, dim
                prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 0)  # 2, 6, dim
                # prompt_char = prompt_char.flatten(0, 1)
                mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0)  # b, subs_l
                mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'), 0)  # b, subs_l
                mask = torch.concat([mask_next, mask_pre], 0)  # 2, 5
                mask = torch.concat([mask_pad, mask], 1)  # 2, 6
                pred_token = next_pre
                visual_f_i = visual_f[:2]  # 2, l, dim
                for layer in self.cmff_decoder:
                    pred_token = layer(pred_token, prompt_char, visual_f_i, mask.unsqueeze(1))
                logits_next_i = self.ques1_head(self.norm_pred(pred_token))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)  # 2, 1
                # print(pred_id_i.shape)
                next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
                pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
                if not (next_pred_id_list == self.eos).any(dim=-1).all():
                    next_logits_all.append(logits[:1])
                    attn_map_next.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
                    next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
                if not (pre_pred_id_list == self.eos).any(dim=-1).all():
                    pre_logits_all.append(logits[1:2])
                    attn_map_pre.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[1])
                    pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
                if (next_pred_id_list == self.eos).any(dim=-1).all() and (pre_pred_id_list == self.eos).any(dim=-1).all():
                    break
            # print(next_id, pre_id)
            # exit(0)
            if len(next_logits_all) > self.sub_str_len and len(pre_logits_all) > self.sub_str_len:
                next_logits_all_ = torch.concat(next_logits_all[:-1], 1)  # 1, l
                pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1], 1)  # 1, l
                next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
                pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
                next_logits_all_mid = []
                attn_map_next_mid = []
                ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
                mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.device)
                for j in range(0, min(70, self.max_len - 1)):
                    prompt_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1)  # b, sub_l, dim
                    mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0)  # b, subs_l
                    mask = torch.concat([mask_pad, mask_next], 1)
                    # prompt_next = self.char_embed(prompt_id)
                    ques_next_i = ques_next
                    visual_f_i = visual_f[2:3]
                    for layer in self.cmff_decoder:
                        ques_next_i = layer(ques_next_i, prompt_next, visual_f_i, mask.unsqueeze(1))
                    logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
                    attn_map_next_mid.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
                    logits = F.softmax(logits_next_i, -1)
                    pred_id_i = logits.argmax(-1)
                    next_logits_all_mid.append(logits)
                    next_id = torch.concat([next_id[:, 1:], pred_id_i], 1)
                    if next_id.equal(pre_id):
                        break
                next_logits_all_mid = torch.concat(next_logits_all_mid, 1)
                # next_logits_all_ = torch.concat([next_logits_all_, next_logits_all], 1)
                self.attn_maps = [attn_map_next, attn_map_next_mid, attn_map_pre[::-1]]
                return [torch.concat(next_logits_all, 1), next_logits_all_mid, torch.concat(pre_logits_all[::-1], 1)]
            else:
                self.attn_maps = [attn_map_next, attn_map_pre[::-1]]
                return [torch.concat(next_logits_all, 1), torch.concat(pre_logits_all[::-1], 1)]

    def forward_test(self, x):
        self.attn_maps = []
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = visual_f.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, H*W, C]
        else:
            visual_f = x
        bs = x.shape[0]
        # greedy sub-string decoding: next_mode reads left-to-right,
        # otherwise right-to-left with the pre token and prompt
        if self.next_mode:
            ques_next = self.next_token.tile([bs, 1, 1, 1]).squeeze(1)
            prompt_next_embed = self.prompt_next_embed.tile([bs, 1, 1,
                                                             1]).squeeze(1)
            prompt_id = torch.full([bs, self.sub_str_len],
                                   self.bos_next,
                                   dtype=torch.long,
                                   device=x.device)
            pred_id_list = torch.full([bs, self.max_len],
                                      self.ignore_index,
                                      dtype=torch.long,
                                      device=x.device)
            logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.device)
            for j in range(0, self.max_len - 1):
                prompt_next = torch.concat([
                    prompt_next_embed[:, :1, :],
                    prompt_next_embed[:, 1:, :] + self.char_embed(prompt_id)
                ], 1)  # b, sub_l, dim
                mask_next = torch.where(prompt_id == self.bos_next,
                                        float('-inf'), 0)  # b, subs_l
                mask = torch.concat([mask_pad, mask_next], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                for layer in self.cmff_decoder:
                    ques_next_i = layer(ques_next_i, prompt_next, visual_f_i,
                                        mask.unsqueeze(1))
                # keep the last layer's attention map for visualization
                self.attn_maps.append(
                    self.cmff_decoder[-1].question_to_images_cross_attn.
                    attn_map[0])
                logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)
                logits_all.append(logits)
                pred_id_list[:, j:j + 1] = pred_id_i
                if (pred_id_list == self.eos).any(dim=-1).all():
                    break
                # slide the sub-string prompt: drop the oldest id, append the new one
                prompt_id = torch.concat([prompt_id[:, 1:], pred_id_i], 1)
            return torch.concat(logits_all, 1)
        else:
            ques_next = self.pre_token.tile([bs, 1, 1, 1]).squeeze(1)
            prompt_pre_embed = self.prompt_pre_embed.tile([bs, 1, 1,
                                                           1]).squeeze(1)
            prompt_id = torch.full([bs, self.sub_str_len],
                                   self.bos_pre,
                                   dtype=torch.long,
                                   device=x.device)
            pred_id_list = torch.full([bs, self.max_len],
                                      self.ignore_index,
                                      dtype=torch.long,
                                      device=x.device)
            logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.device)
            for j in range(0, self.max_len - 1):
                prompt_next = torch.concat([
                    prompt_pre_embed[:, :1, :],
                    prompt_pre_embed[:, 1:, :] + self.char_embed(prompt_id)
                ], 1)  # b, sub_l, dim
                mask_next = torch.where(prompt_id == self.bos_pre,
                                        float('-inf'), 0)  # b, subs_l
                mask = torch.concat([mask_pad, mask_next], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                for layer in self.cmff_decoder:
                    ques_next_i = layer(ques_next_i, prompt_next, visual_f_i,
                                        mask.unsqueeze(1))
                logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)
                logits_all.append(logits)
                pred_id_list[:, j:j + 1] = pred_id_i
                if (pred_id_list == self.eos).any(dim=-1).all():
                    break
                # slide the sub-string prompt: prepend the new id, drop the oldest
                prompt_id = torch.concat([pred_id_i, prompt_id[:, :-1]], 1)
            return torch.concat(logits_all, 1)

    def forward_train(self, x, targets=None):
        bs = x.shape[0]
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            # flatten the 2D feature map to a sequence, matching the test-time paths
            visual_f = visual_f.flatten(2).transpose(1, 2)
        else:
            visual_f = x

        # targets[1]/targets[2]/targets[3]: next-direction sub-strings, labels
        # and lengths; targets[4]/targets[5]/targets[6]: the pre counterparts
        max_len_curr = targets[3].max()
        subs = targets[1][:, :max_len_curr, :]  # b, n, subs_l
        mask_next = torch.where(subs == self.bos_next, float('-inf'),
                                0)  # b, n, subs_l
        prompt_next_embed = self.prompt_next_embed.tile(
            [bs, max_len_curr, 1, 1])
        prompt_char_next = torch.concat([
            prompt_next_embed[:, :, :1, :],
            prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, subs_l, dim
        next = self.next_token.tile([bs, max_len_curr, 1, 1])

        max_len_curr_pre = targets[6].max()
        subs = targets[4][:, :max_len_curr_pre, :]  # b, n, subs_l
        mask_pre = torch.where(subs == self.bos_pre, float('-inf'),
                               0)  # b, n, subs_l
        prompt_pre_embed = self.prompt_pre_embed.tile(
            [bs, max_len_curr_pre, 1, 1])
        prompt_char_pre = torch.concat([
            prompt_pre_embed[:, :, :1, :],
            prompt_pre_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, sub_l, dim
        pre = self.pre_token.tile([bs, max_len_curr_pre, 1, 1])  # b, n, 1, dim

        prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 1)
        next_pre = torch.concat([next, pre], 1)
        mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1],
                               dtype=torch.float32,
                               device=x.device)
        mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1)
        mask = torch.concat([mask_pad, mask], 1)
        next_pre = next_pre.flatten(0, 1)
        prompt_char = prompt_char.flatten(0, 1)
        for layer in self.cmff_decoder:
            next_pre = layer(next_pre, prompt_char, visual_f,
                             mask.unsqueeze(1))
        answer1_pred = self.ques1_head(self.norm_pred(next_pre))
        logits = answer1_pred[:, :max_len_curr]

        label = torch.concat(
            [targets[2][:, :max_len_curr], targets[5][:, :max_len_curr_pre]],
            1)
        loss1 = F.cross_entropy(answer1_pred.flatten(0, 1),
                                label.flatten(0, 1),
                                ignore_index=self.ignore_index,
                                reduction='mean')
        loss = {'loss': loss1}
        return [loss, logits]
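

if __name__ == '__main__':
    # Smoke test (illustrative only): the sizes and vocabulary below are
    # assumptions for a quick shape check, not values from an SMTR config.
    decoder = SMTRDecoder(in_channels=64, out_channels=40, num_layer=2,
                          max_len=25, vis_seq=50)
    decoder.eval()
    dummy_visual = torch.randn(2, 50, 64)  # [batch, vis_seq, in_channels]
    with torch.no_grad():
        logits = decoder(dummy_visual)  # next-mode greedy decoding
    print(logits.shape)  # [2, n_steps, out_channels - 3]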