  1. """This code is refer from:
  2. https://github.com/byeonghu-na/MATRN
  3. """
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length
  9. from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class BaseSemanticVisual_backbone_feature(nn.Module):

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.0,
                 alignment_mask_example_prob=0.9,
                 alignment_mask_candidate_prob=0.9,
                 alignment_num_vis_mask=10,
                 max_length=25,
                 num_classes=37):
        super().__init__()
        self.mask_example_prob = alignment_mask_example_prob
        self.mask_candidate_prob = alignment_mask_candidate_prob
        self.num_vis_mask = alignment_num_vis_mask
        self.nhead = nhead
        self.d_model = d_model
        self.max_length = max_length + 1  # additional stop token
        # Self-attention stack over the concatenated visual + semantic tokens.
        self.model1 = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
                                                  dropout=0,
                                                  max_len=1024)
        self.model2_vis = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.cls_vis = nn.Linear(d_model, num_classes)
        self.cls_sem = nn.Linear(d_model, num_classes)
        # Gate that fuses the visually and semantically enriched features.
        self.w_att = nn.Linear(2 * d_model, d_model)
        # Learnable token used to mask attended visual positions during training.
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)
        self.cls = nn.Linear(d_model, num_classes)
    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
        """
        Args:
            l_feature: (N, T, E), where N is batch size, T is the text length
                (max_length + 1) and E is the model dimension
            v_feature: (N, E, H, W) visual backbone features
            lengths_l: (N,) predicted text lengths from the language branch
            v_attn: (N, T, H, W) position-attention maps from the vision branch
        """
        N, E, H, W = v_feature.size()
        v_feature = v_feature.flatten(2, 3).transpose(1, 2)  # (N, H*W, E)
        v_attn = v_attn.flatten(2, 3)  # (N, T, H*W)
        if self.training:
            # Randomly replace the visual positions most attended by one character
            # with the learnable v_token, so fusion cannot rely on vision alone.
            for idx, length in enumerate(lengths_l):
                if np.random.random() <= self.mask_example_prob:
                    l_idx = np.random.randint(int(length))
                    v_random_idx = v_attn[idx, l_idx].argsort(
                        descending=True).cpu().numpy()[:self.num_vis_mask]
                    v_random_idx = v_random_idx[np.random.random(
                        v_random_idx.shape) <= self.mask_candidate_prob]
                    v_feature[idx, v_random_idx] = self.v_token
        # Spatial positional encodings, projected onto character positions via
        # the attention maps, are added to the semantic features.
        zeros = v_feature.new_zeros((N, H * W, E))  # (N, H*W, E)
        base_pos = self.pos_encoder_tfm(zeros)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos)  # (N, T, E)
        l_feature = l_feature + base_pos
        sv_feature = torch.cat((v_feature, l_feature), dim=1)  # (N, H*W+T, E)
        for decoder_layer in self.model1:
            sv_feature = decoder_layer(sv_feature)  # (N, H*W+T, E)
        sv_to_v_feature = sv_feature[:, :H * W]  # (N, H*W, E)
        sv_to_s_feature = sv_feature[:, H * W:]  # (N, T, E)
        sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = _get_length(sv_to_v_logits)  # (N,)
        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = _get_length(sv_to_s_logits)  # (N,)
        # Gated fusion of the visual and semantic branches.
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature
        logits = self.cls(output)  # (N, T, C)
        pt_lengths = _get_length(logits)
        return {
            'logits': logits,
            'pt_lengths': pt_lengths,
            'v_logits': sv_to_v_logits,
            'pt_v_lengths': pt_v_lengths,
            's_logits': sv_to_s_logits,
            'pt_s_lengths': pt_s_lengths,
            'name': 'alignment'
        }
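

# Illustrative usage sketch (not part of the upstream MATRN code): it runs the
# alignment module above on random tensors to make the expected input/output
# shapes explicit. It assumes the openrec TransformerBlock / PositionAttention
# blocks accept the shapes noted in the comments; the sizes are example values.
def _demo_alignment_shapes():
    N, E, H, W = 2, 512, 8, 32
    T = 25 + 1  # max_length plus the stop token
    module = BaseSemanticVisual_backbone_feature(d_model=E,
                                                 max_length=25,
                                                 num_classes=37).eval()
    l_feature = torch.randn(N, T, E)     # semantic (language) features
    v_feature = torch.randn(N, E, H, W)  # visual backbone features
    v_attn = torch.rand(N, T, H, W)      # position-attention maps
    lengths_l = torch.full((N, ), T)     # per-sample lengths (only used in training)
    with torch.no_grad():
        out = module(l_feature, v_feature, lengths_l=lengths_l, v_attn=v_attn)
    # 'logits', 'v_logits' and 's_logits' are all (N, T, num_classes).
    return out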


class MATRNDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        self.max_length = max_length + 1
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)
        self.iter_size = iter_size
        if iter_size > 0:
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # alignment
            self.semantic_visual = BaseSemanticVisual_backbone_feature(
                d_model=d_model,
                nhead=nhead,
                num_layers=2,
                dim_feedforward=dim_feedforward,
                max_length=max_length,
                num_classes=self.out_channels)
    def forward(self, x, data=None):
        # x: (bs, c, h, w)
        x = x.permute([0, 2, 3, 1])  # bs, h, w, c
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # bs, h*w, c
        feature = self.pos_encoder(feature)  # bs, h*w, c
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)
        # bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # bs, c, h, w
        v_feature, v_attn_input = self.decoder(feature)  # (N, T, E)
        vis_logits = self.cls(v_feature)  # (N, T, C)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []
        for _ in range(self.iter_size):
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(align_lengths, 2,
                                  self.max_length)  # TODO: move to language model
            l_feature, l_logits = self.language(tokens, lengths)
            all_l_res.append(l_logits)
            # alignment
            lengths_l = _get_length(l_logits)
            lengths_l.clamp_(2, self.max_length)
            a_res = self.semantic_visual(l_feature,
                                         feature,
                                         lengths_l=lengths_l,
                                         v_attn=v_attn_input)
            a_v_res = a_res['v_logits']
            all_a_res.append(a_v_res)
            a_s_res = a_res['s_logits']
            align_logits = a_res['logits']
            all_a_res.append(a_s_res)
            all_a_res.append(align_logits)
            align_lengths = a_res['pt_lengths']
        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)
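

if __name__ == '__main__':
    # Hedged smoke test (not part of the original file): feeds a random backbone
    # feature map through the decoder. The (2, 512, 8, 32) shape is only an
    # assumed example; any (N, in_channels, H, W) map should work as long as
    # H * W fits within the positional encoding length.
    decoder = MATRNDecoder(in_channels=512, out_channels=37).eval()
    dummy_feature_map = torch.randn(2, 512, 8, 32)
    with torch.no_grad():
        probs = decoder(dummy_feature_map)  # (N, max_length + 1, out_channels)
    print(probs.shape)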