# cdistnet_decoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock  # , Beam
from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder

def generate_square_subsequent_mask(sz):
    r"""Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0)))
    return mask
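
# For illustration, generate_square_subsequent_mask(3) returns the additive
# attention mask
#     tensor([[0., -inf, -inf],
#             [0.,   0., -inf],
#             [0.,   0.,   0.]])
# so position i may only attend to positions <= i (a causal mask).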

class SEM_Pre(nn.Module):

    def __init__(
        self,
        d_model=512,
        dst_vocab_size=40,
        residual_dropout_rate=0.1,
    ):
        super(SEM_Pre, self).__init__()
        self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model,
        )

    def forward(self, tgt):
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
        return tgt, tgt_mask
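
# Shape sketch for the semantic branch (N = batch size, T = target length,
# both hypothetical): token ids (N, T) int64 -> embedded features
# (N, T, d_model), plus a causal tgt_mask of shape (T, T) built by
# generate_square_subsequent_mask above.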

class POS_Pre(nn.Module):

    def __init__(
        self,
        d_model=512,
    ):
        super(POS_Pre, self).__init__()
        self.pos_encoding = PositionalEncoding(
            dropout=0.1,
            dim=d_model,
        )
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        pos = tgt.new_zeros(*tgt.shape)
        pos = self.pos_encoding(pos)
        pos2 = self.linear2(F.relu(self.linear1(pos)))
        pos = self.norm2(pos + pos2)
        return pos
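
# Note that tgt is used only for its shape, dtype, and device: the positional
# query starts from zeros plus sinusoidal encodings and is refined by a small
# feed-forward block, so it is independent of the token content.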

class DSF(nn.Module):

    def __init__(self, d_model, fusion_num):
        super(DSF, self).__init__()
        self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where N is batch size, T is sequence length,
                and E is the model dimension
            v_feature: (N, T, E), same shape as l_feature
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature
        return output
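
# Minimal sketch of the gated fusion (shapes are hypothetical):
#   fuse = DSF(d_model=512, fusion_num=2)
#   l = torch.randn(2, 25, 512)
#   v = torch.randn(2, 25, 512)
#   out = fuse(l, v)  # (2, 25, 512)
# The sigmoid gate f_att interpolates element-wise between the two streams.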

class MDCDP(nn.Module):
    r"""Multi-Domain Character Distance Perception."""

    def __init__(self, d_model, n_head, d_inner, num_layers):
        super(MDCDP, self).__init__()
        self.num_layers = num_layers
        # step 1: SAE (Self-Attention Enhancement)
        self.layers_pos = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(num_layers)
        ])
        # step 2: CBI (Cross-Branch Interaction)
        self.layers2 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        self.layers3 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        # step 3: DSF (Dynamic Shared Fusion)
        self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
        self,
        sem,
        vis,
        pos,
        tgt_mask=None,
        memory_mask=None,
    ):
        for i in range(self.num_layers):
            # ---------- step 1 ----------: SAE: Self-Attention Enhancement
            pos = self.layers_pos[i](pos, self_mask=tgt_mask)
            # ---------- step 2 ----------: CBI: Cross-Branch Interaction
            # CBI-V: position queries attend to the visual features
            pos_vis = self.layers2[i](
                pos,
                vis,
                cross_mask=memory_mask,
            )
            # CBI-S: position queries attend to the semantic features
            pos_sem = self.layers3[i](
                pos,
                sem,
                cross_mask=tgt_mask,
            )
            # ---------- step 3 ----------: DSF: Dynamic Shared Fusion
            pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
        output = pos
        return output
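
# Per-layer data flow, informally:
#   pos     <- SelfAttn(pos, causal mask)            (SAE)
#   pos_vis <- CrossAttn(query=pos, key/value=vis)   (CBI-V)
#   pos_sem <- CrossAttn(query=pos, key/value=sem)   (CBI-S)
#   pos     <- DSF(pos_vis, pos_sem)                 (gated fusion)
# The fused positional stream is the query for the next layer.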

class ConvBnRelu(nn.Module):
    # Padding is derived from kernel_size, so the kernel can be changed
    # without updating the padding by hand.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        conv=nn.Conv2d,
        stride=2,
        inplace=True,
    ):
        super().__init__()
        # "same"-style padding for odd kernel sizes, computed per dimension
        p_size = [int(k // 2) for k in kernel_size]
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=p_size,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
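
# For example (illustrative shapes): with kernel_size=(1, 3) and
# stride=(1, 2), padding becomes (0, 1), so an input of (N, C, 8, 32)
# keeps its height and halves its width to (N, C, 8, 16).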

class CDistNetDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_head=None,
                 num_encoder_blocks=3,
                 num_decoder_blocks=3,
                 beam_size=0,
                 max_len=25,
                 residual_dropout_rate=0.1,
                 add_conv=False,
                 **kwargs):
        super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        self.ignore_index = dst_vocab_size - 1
        self.bos = dst_vocab_size - 2
        self.eos = 0
        self.beam_size = beam_size
        self.max_len = max_len
        self.add_conv = add_conv
        d_model = in_channels
        dim_feedforward = d_model * 4
        n_head = n_head if n_head is not None else d_model // 32
        if add_conv:
            self.convbnrelu = ConvBnRelu(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, 3),
                stride=(1, 2),
            )
        if num_encoder_blocks > 0:
            self.positional_encoding = PositionalEncoding(
                dropout=0.1,
                dim=d_model,
            )
            self.trans_encoder = Transformer_Encoder(
                n_layers=num_encoder_blocks,
                n_head=n_head,
                d_model=d_model,
                d_inner=dim_feedforward,
            )
        else:
            self.trans_encoder = None
        self.semantic_branch = SEM_Pre(
            d_model=d_model,
            dst_vocab_size=dst_vocab_size,
            residual_dropout_rate=residual_dropout_rate,
        )
        self.positional_branch = POS_Pre(d_model=d_model)
        self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
                           num_decoder_blocks)
        self._reset_parameters()
        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # we don't predict <bos> nor <pad>
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)
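
    # Hypothetical construction (the values are illustrative, not from the
    # source): CDistNetDecoder(in_channels=512, out_channels=40) yields
    # d_model=512, n_head = 512 // 32 = 16, bos=38, eos=0, and a pad/ignore
    # index of 39, with a projection over the remaining 38 symbols.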

    def forward(self, x, data=None):
        if self.add_conv:
            x = self.convbnrelu(x)
        # flatten the feature map: (B, C, H, W) -> (B, H*W, C),
        # i.e. rearrange "b c h w -> b (h w) c"
        x = x.flatten(2).transpose(1, 2)
        if self.trans_encoder is not None:
            x = self.positional_encoding(x)
            vis_feat = self.trans_encoder(x, src_mask=None)
        else:
            vis_feat = x
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
            res = self.forward_train(vis_feat, tgt)
        else:
            if self.beam_size > 0:
                res = self.forward_beam(vis_feat)
            else:
                res = self.forward_test(vis_feat)
        return res

    def forward_train(self, vis_feat, tgt):
        sem_feat, sem_mask = self.semantic_branch(tgt)
        pos_feat = self.positional_branch(sem_feat)
        output = self.mdcdp(
            sem_feat,
            vis_feat,
            pos_feat,
            tgt_mask=sem_mask,
            memory_mask=None,
        )
        logit = self.tgt_word_prj(output)
        return logit
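
    # Shape sketch for training (N and L are hypothetical): tgt holds
    # (N, 1 + L) token ids starting with <bos>; the returned logit is
    # (N, 1 + L, dst_vocab_size - 2), one prediction per input position.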

    def forward_test(self, vis_feat):
        bs = vis_feat.size(0)
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(
                dec_seq[:, :len_dec_seq + 1])
            pos_feat = self.positional_branch(sem_feat)
            output = self.mdcdp(
                sem_feat,
                vis_feat,
                pos_feat,
                tgt_mask=sem_mask,
                memory_mask=None,
            )
            dec_output = output[:, -1:, :]
            # per-step probabilities over the predictable symbols
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            # greedy decode: append the most likely next token to the input
            # (len_dec_seq + 1 <= max_len always holds inside this loop)
            dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # efficient batch decoding: stop once every sequence in the
            # batch contains at least one <eos> token
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits
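
    # Note that despite the name, `logits` holds softmax probabilities of
    # shape (N, steps, dst_vocab_size - 2), where steps <= max_len because
    # decoding can stop early once every sequence has emitted <eos>.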

    def forward_beam(self, x):
        """Translation work in one batch."""
        # TODO: beam search is not implemented yet
        raise NotImplementedError('beam search decoding is not implemented')

    def _reset_parameters(self):
        r"""Initialize parameters of the model with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
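
if __name__ == '__main__':
    # Hedged usage sketch; the shapes and hyper-parameters below are
    # illustrative assumptions, not values from the source.
    feats = torch.randn(2, 512, 8, 32)  # backbone features (N, C, H, W)
    decoder = CDistNetDecoder(in_channels=512, out_channels=40)
    decoder.eval()
    with torch.no_grad():
        # greedy decoding path: (2, <=25, 38) per-step probabilities
        probs = decoder(feats)
    print(probs.shape)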