smtr_decoder_nattn.py

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity
from openrec.modeling.decoders.cppd_decoder import DecoderLayer
from openrec.modeling.decoders.nrtr_decoder import Embeddings
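
# SMTRDecoderNumAttn is a sub-string-matching style decoder: a learnable
# "next"/"pre" query token, conditioned on an embedded character sub-string
# (the prompt), attends to the visual features and predicts the next
# (or previous) character one step at a time.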


class CrossAttention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
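        # q:  (B, QN, C) query tokens; kv: (B, N, C) key/value tokens.
        # key_mask, if provided, is an additive attention mask (0 for valid
        # positions, -inf for masked ones) that is broadcast over the heads,
        # e.g. shape (B, 1, N).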
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SSMatchLayer(nn.Module):

    def __init__(
        self,
        dim,
        nextq2subs_head2=None,
        dynq2img_heads=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        num_layer=2,
        epsilon=1e-6,
    ):
        super().__init__()
        self.dim = dim
        if nextq2subs_head2 is None:
            nextq2subs_head2 = dim // 32
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=nextq2subs_head2,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        # self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        dpr = np.linspace(0, drop_path, num_layer)
        self.question_to_images_cross_attn = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dynq2img_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
                act_layer=act_layer,
            ) for i in range(num_layer)
        ])
        # CrossAttention(
        #     dim,
        #     num_heads=dynq2img_heads,
        #     qkv_bias=qkv_bias,
        #     qk_scale=qk_scale,
        #     attn_drop=attn_drop,
        #     proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()

    def forward(self, question_f, prompt_f, visual_f, mask=None):
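        # question_f: query tokens, (B * n, 1, dim) (n == 1 at inference);
        # prompt_f:   sub-string prompt embeddings, (B * n, sub_str_len + 1, dim);
        # visual_f:   visual features, (B, L, dim);
        # mask:       additive key mask over the prompt positions.
        # Stage 1: each query attends to its own prompt; stage 2: the queries
        # are regrouped per image and refined against visual_f by the
        # DecoderLayer stack.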
        question_f = question_f + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(question_f),
                                               self.normkv1(prompt_f), mask))
        question_f = question_f.reshape(visual_f.shape[0], -1, self.dim)
        question_f = self.normq2(question_f)
        # kv = self.normkv2(visual_f)
        for layer in self.question_to_images_cross_attn:
            question_f = layer(question_f, visual_f)
        return question_f


class SMTRDecoderNumAttn(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 nextq2subs_head2=None,
                 dynq2img_heads=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ds=False,
                 pos2d=False,
                 max_size=[8, 32],
                 sub_str_len=5,
                 next_mode=True,
                 infer_aug=False,
                 **kwargs):
        super(SMTRDecoderNumAttn, self).__init__()
        self.out_channels = out_channels
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 3  # max_len + eos + bos
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
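        # Special ids in the output vocabulary (out_channels entries):
        #   0                -> eos
        #   out_channels - 3 -> bos of the "next" (left-to-right) prompt
        #   out_channels - 2 -> bos of the "pre" (right-to-left) prompt
        #   out_channels - 1 -> ignore / padding index
        # The prediction head (ques1_head) only outputs out_channels - 3
        # classes, so the bos and padding ids are never predicted.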
        self.ignore_index = out_channels - 1
        self.sub_str_len = sub_str_len
        self.bos_next = out_channels - 3
        self.bos_pre = out_channels - 2
        self.eos = 0
        self.next_mode = next_mode
        self.infer_aug = infer_aug
        self.cmff_decoder = SSMatchLayer(dim=dim,
                                         nextq2subs_head2=nextq2subs_head2,
                                         dynq2img_heads=dynq2img_heads,
                                         mlp_ratio=4.0,
                                         qkv_bias=True,
                                         drop_path=drop_path_rate,
                                         num_layer=num_layer)
        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.vis_pos_embed = nn.Parameter(pos_embed.transpose(
                1, 2).reshape(1, dim, max_size[0], max_size[1]),
                                              requires_grad=True)
        self.next_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                   dtype=torch.float32),
                                       requires_grad=True)
        self.pre_token = nn.Parameter(torch.zeros([1, 1, 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        self.prompt_next_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                              requires_grad=True)
        self.prompt_pre_embed = nn.Parameter(torch.zeros(
            [1, 1, self.sub_str_len + 1, dim], dtype=torch.float32),
                                             requires_grad=True)
        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 3)
        trunc_normal_(self.next_token, std=0.02)
        trunc_normal_(self.pre_token, std=0.02)
        trunc_normal_(self.prompt_pre_embed, std=0.02)
        trunc_normal_(self.prompt_next_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'vis_pos_embed', 'pre_token', 'next_token', 'char_embed'}

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            if self.infer_aug:
                return self.forward_test_bi(x)
            return self.forward_test(x)

    def forward_test_bi(self, x):
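        # Bidirectional inference used when `infer_aug` is set. The input batch
        # is expected to hold (at least) three views of the same text line:
        # items 0 and 1 are decoded with the "next" and "pre" queries
        # respectively; if both directions produce more than sub_str_len
        # characters, item 2 is used to continue left-to-right decoding until
        # the running prompt matches the right-to-left result, and the two
        # halves are concatenated.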
        # self.attn_maps = []
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        bs = 2
        if 1:
            next = self.next_token
            pre = self.pre_token
            next_pre = torch.concat([next, pre], 0)
            next_pre = next_pre.squeeze(1)  # 2, 1, dim
            prompt_next_embed = self.prompt_next_embed.squeeze(1)
            prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
            next_id = torch.full([1, self.sub_str_len],
                                 self.bos_next,
                                 dtype=torch.long,
                                 device=x.get_device())
            pre_id = torch.full([1, self.sub_str_len],
                                self.bos_pre,
                                dtype=torch.long,
                                device=x.get_device())
            # prompt_next_bos = self.char_embed(prompt_id)
            # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device())
            next_pred_id_list = torch.full([1, self.max_len],
                                           self.ignore_index,
                                           dtype=torch.long,
                                           device=x.get_device())
            pre_pred_id_list = torch.full([1, self.max_len],
                                          self.ignore_index,
                                          dtype=torch.long,
                                          device=x.get_device())
            next_logits_all = []
            pre_logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.get_device())
            for j in range(0, min(70, self.max_len - 1)):
                prompt_char_next = torch.concat([
                    prompt_next_embed[:, :1, :],
                    prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
                ], 1)  # b, sub_l, dim
                prompt_char_pre = torch.concat([
                    prompt_pre_embed[:, :1, :],
                    prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
                ], 1)  # b, sub_l, dim
                prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
                                           0)  # 2, 6, dim
                # prompt_char = prompt_char.flatten(0, 1)
                mask_next = torch.where(next_id == self.bos_next,
                                        float('-inf'), 0)  # b, subs_l
                mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
                                       0)  # b, subs_l
                mask = torch.concat([mask_next, mask_pre], 0)  # 2, 5
                mask = torch.concat([mask_pad, mask], 1)  # 2, 6
                pred_token = next_pre
                visual_f_i = visual_f[:2]  # 2, l, dim
                pred_token = self.cmff_decoder(pred_token, prompt_char,
                                               visual_f_i, mask.unsqueeze(1))
                logits_next_i = self.ques1_head(self.norm_pred(pred_token))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)  # 2, 1
                # print(pred_id_i.shape)
                next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
                pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
                if not (next_pred_id_list == self.eos).any(dim=-1).all():
                    next_logits_all.append(logits[:1])
                    next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
                if not (pre_pred_id_list == self.eos).any(dim=-1).all():
                    pre_logits_all.append(logits[1:2])
                    pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
                if (next_pred_id_list == self.eos).any(dim=-1).all() and (
                        pre_pred_id_list == self.eos).any(dim=-1).all():
                    break
            # print(next_id, pre_id)
            # exit(0)
            if len(next_logits_all) > self.sub_str_len and len(
                    pre_logits_all) > self.sub_str_len:
                next_logits_all_ = torch.concat(next_logits_all[:-1],
                                                1)  # 1, l
                pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
                                               1)  # 1, l
                next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
                pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
                next_logits_all = []
                ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
                mask_pad = torch.zeros([1, 1],
                                       dtype=torch.float32,
                                       device=x.get_device())
                for j in range(0, min(70, self.max_len - 1)):
                    prompt_next = torch.concat([
                        prompt_next_embed[:, :1, :],
                        prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
                    ], 1)  # b, sub_l, dim
                    mask_next = torch.where(next_id == self.bos_next,
                                            float('-inf'), 0)  # b, subs_l
                    mask = torch.concat([mask_pad, mask_next], 1)
                    # prompt_next = self.char_embed(prompt_id)
                    ques_next_i = ques_next
                    visual_f_i = visual_f[2:3]
                    ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
                                                    visual_f_i,
                                                    mask.unsqueeze(1))
                    logits_next_i = self.ques1_head(
                        self.norm_pred(ques_next_i))
                    logits = F.softmax(logits_next_i, -1)
                    pred_id_i = logits.argmax(-1)
                    next_logits_all.append(logits)
                    next_id = torch.concat([next_id[:, 1:], pred_id_i], 1)
                    if next_id.equal(pre_id):
                        break
                next_logits_all = torch.concat(next_logits_all, 1)
                next_logits_all_ = torch.concat(
                    [next_logits_all_, next_logits_all], 1)
                return torch.concat(
                    [next_logits_all_, pre_logits_all_[:, self.sub_str_len:]],
                    1)
            else:
                return torch.concat(next_logits_all + pre_logits_all[::-1], 1)

    def forward_test(self, x):
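        # Single-direction autoregressive inference: the prompt starts as
        # sub_str_len bos ids, each step predicts one character that is slid
        # into the prompt window, and decoding stops once every sample in the
        # batch has emitted eos (or after max_len - 1 steps). `next_mode`
        # selects the left-to-right branch; otherwise the "pre" token decodes
        # right-to-left.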
        # self.attn_maps = []
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        bs = x.shape[0]
        if self.next_mode:
            ques_next = self.next_token.tile([bs, 1, 1, 1]).squeeze(1)
            prompt_next_embed = self.prompt_next_embed.tile([bs, 1, 1,
                                                             1]).squeeze(1)
            prompt_id = torch.full([bs, self.sub_str_len],
                                   self.bos_next,
                                   dtype=torch.long,
                                   device=x.get_device())
            pred_id_list = torch.full([bs, self.max_len],
                                      self.ignore_index,
                                      dtype=torch.long,
                                      device=x.get_device())
            logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.get_device())
            for j in range(0, self.max_len - 1):
                prompt_next = torch.concat([
                    prompt_next_embed[:, :1, :],
                    prompt_next_embed[:, 1:, :] + self.char_embed(prompt_id)
                ], 1)  # b, sub_l, dim
                mask_next = torch.where(prompt_id == self.bos_next,
                                        float('-inf'), 0)  # b, subs_l
                mask = torch.concat([mask_pad, mask_next], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
                                                visual_f_i, mask.unsqueeze(1))
                # self.attn_maps.append(
                #     self.cmff_decoder[-1].question_to_images_cross_attn.
                #     attn_map[0])
                logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)
                logits_all.append(logits)
                pred_id_list[:, j:j + 1] = pred_id_i
                if (pred_id_list == self.eos).any(dim=-1).all():
                    break
                prompt_id = torch.concat([prompt_id[:, 1:], pred_id_i], 1)
            return torch.concat(logits_all, 1)
        else:
            ques_next = self.pre_token.tile([bs, 1, 1, 1]).squeeze(1)
            prompt_pre_embed = self.prompt_pre_embed.tile([bs, 1, 1,
                                                           1]).squeeze(1)
            prompt_id = torch.full([bs, self.sub_str_len],
                                   self.bos_pre,
                                   dtype=torch.long,
                                   device=x.get_device())
            pred_id_list = torch.full([bs, self.max_len],
                                      self.ignore_index,
                                      dtype=torch.long,
                                      device=x.get_device())
            logits_all = []
            mask_pad = torch.zeros([bs, 1],
                                   dtype=torch.float32,
                                   device=x.get_device())
            for j in range(0, self.max_len - 1):
                prompt_next = torch.concat([
                    prompt_pre_embed[:, :1, :],
                    prompt_pre_embed[:, 1:, :] + self.char_embed(prompt_id)
                ], 1)  # b, sub_l, dim
                mask_next = torch.where(prompt_id == self.bos_pre,
                                        float('-inf'), 0)  # b, subs_l
                mask = torch.concat([mask_pad, mask_next], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                ques_next_i = self.cmff_decoder(ques_next_i, prompt_next,
                                                visual_f_i, mask.unsqueeze(1))
                logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
                logits = F.softmax(logits_next_i, -1)
                pred_id_i = logits.argmax(-1)
                logits_all.append(logits)
                pred_id_list[:, j:j + 1] = pred_id_i
                if (pred_id_list == self.eos).any(dim=-1).all():
                    break
                prompt_id = torch.concat([pred_id_i, prompt_id[:, :-1]], 1)
            return torch.concat(logits_all, 1)

    def forward_train(self, x, targets=None):
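        # Both directions are trained in one pass. From the indexing below,
        # `targets` appears to be laid out as:
        #   targets[1] / targets[4]: next / previous sub-string prompts, (b, n, sub_str_len)
        #   targets[2] / targets[5]: next / previous character labels
        #   targets[3] / targets[6]: valid lengths, used to truncate n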
        bs = x.shape[0]
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            visual_f = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
        else:
            visual_f = x
        max_len_curr = targets[3].max()
        subs = targets[1][:, :max_len_curr, :]  # b, n, subs_l
        mask_next = torch.where(subs == self.bos_next, float('-inf'),
                                0)  # b, n, subs_l
        prompt_next_embed = self.prompt_next_embed.tile(
            [bs, max_len_curr, 1, 1])
        prompt_char_next = torch.concat([
            prompt_next_embed[:, :, :1, :],
            prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, sub_l, dim
        next = self.next_token.tile([bs, max_len_curr, 1, 1])
        max_len_curr_pre = targets[6].max()
        subs = targets[4][:, :max_len_curr_pre, :]  # b, n, subs_l
        mask_pre = torch.where(subs == self.bos_pre, float('-inf'),
                               0)  # b, n, subs_l
        prompt_pre_embed = self.prompt_pre_embed.tile(
            [bs, max_len_curr_pre, 1, 1])
        prompt_char_pre = torch.concat([
            prompt_pre_embed[:, :, :1, :],
            prompt_pre_embed[:, :, 1:, :] + self.char_embed(subs)
        ], 2)  # b, n, sub_l, dim
        pre = self.pre_token.tile([bs, max_len_curr_pre, 1, 1])  # b, n, 1, dim
        prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 1)
        next_pre = torch.concat([next, pre], 1)
        mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1],
                               dtype=torch.float32,
                               device=x.get_device())
        mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1)
        mask = torch.concat([mask_pad, mask], 1)
        next_pre = next_pre.flatten(0, 1)
        prompt_char = prompt_char.flatten(0, 1)
        next_pre = self.cmff_decoder(next_pre, prompt_char, visual_f,
                                     mask.unsqueeze(1))
        answer1_pred = self.ques1_head(self.norm_pred(next_pre))
        logits = answer1_pred[:, :max_len_curr]
        label = torch.concat(
            [targets[2][:, :max_len_curr], targets[5][:, :max_len_curr_pre]],
            1)
        loss1 = F.cross_entropy(answer1_pred.flatten(0, 1),
                                label.flatten(0, 1),
                                ignore_index=self.ignore_index,
                                reduction='mean')
        loss = {'loss': loss1}
        return [loss, logits]
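

if __name__ == '__main__':
    # Minimal smoke test for the CrossAttention block. The sizes below
    # (dim=64, 4 heads, 2 query tokens, 25 key/value tokens) are illustrative
    # assumptions, not values taken from any model config.
    attn = CrossAttention(dim=64, num_heads=4)
    q = torch.randn(3, 2, 64)    # (batch, num_query_tokens, dim)
    kv = torch.randn(3, 25, 64)  # (batch, seq_len, dim)
    # Additive key mask broadcast over heads: 0 keeps a position, -inf drops it.
    key_mask = torch.zeros(3, 1, 25)
    key_mask[:, :, -5:] = float('-inf')
    out = attn(q, kv, key_mask)
    print(out.shape)  # torch.Size([3, 2, 64])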