# ote_decoder.py

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from .nrtr_decoder import TransformerBlock, Embeddings

class CPA(nn.Module):
    """Maps the mean-pooled feature to max_len + 1 position-specific
    queries via a learned positional attention."""

    def __init__(self, dim, max_len=25):
        super(CPA, self).__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.fc3 = nn.Linear(dim, dim)
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, feat):
        # feat: B, L, dim
        feat = feat.mean(1).unsqueeze(1)  # B, 1, dim (global feature)
        x = self.fc1(feat) + self.pos_embed  # B, max_len + 1, dim
        x = F.softmax(self.fc2(torch.tanh(x)), -1)  # B, max_len + 1, dim
        x = self.fc3(feat * x)  # B, max_len + 1, dim
        return x
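
# A minimal usage sketch for CPA (illustrative sizes: dim=256, max_len=25; the
# B x L x dim input would come from an encoder, e.g. a flattened feature map):
#
#   cpa = CPA(dim=256, max_len=25)
#   feat = torch.randn(2, 256, 256)   # B, L, dim
#   out = cpa(feat)                   # (2, 26, 256): one query per position + eos
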
class ARDecoder(nn.Module):
    """Autoregressive transformer decoder over the CPA queries."""

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(ARDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, d_model],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_decoder_layers)
        ])
        # Project to out_channels - 2: the bos and padding ids are never predicted.
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
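
    # Size sketch with d_model = in_channels = 256 (illustrative):
    #   nhead = 256 // 32 = 8, dim_feedforward = 4 * 256 = 1024,
    #   pos_embed: (1, 26, 256) for max_len = 25.
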
    def forward_train(self, src, tgt):
        # Teacher forcing: drop the last token so inputs and targets align.
        tgt = tgt[:, :-1]
        tgt = self.embedding(
            tgt) + src[:, :tgt.shape[1]] + self.pos_embed[:, :tgt.shape[1]]
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1],
                                                        device=src.device)
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit
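
    # Position i of `logit` predicts token i + 1 of the original `tgt`; the
    # training loss (computed outside this module) would presumably compare
    # `logit` against `tgt[:, 1:]`.
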
    def forward(self, src, data=None):
        if self.training:
            # data[0]: padded label ids; data[1]: label lengths.
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res
    def forward_test(self, src):
        bs = src.shape[0]
        src = src + self.pos_embed
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # bs, len_dec_seq + 1, dim
            dec_seq_embed = dec_seq_embed + src[:, :len_dec_seq + 1]
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # bos, a, b, c, ...
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, self_mask=tgt_mask)
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]  # keep only the newest position
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            # Greedy decoding: append the argmax token to the decoder input.
            dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Efficient batch decoding: once every sequence in the batch
            # contains at least one eos token, stop early.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits
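
    # Decoding sketch (hypothetical 3-character word "cat", eos id 0):
    #   step 0: input [bos]           -> predict 'c'
    #   step 1: input [bos, c]        -> predict 'a'
    #   step 2: input [bos, c, a]     -> predict 't'
    #   step 3: input [bos, c, a, t]  -> predict eos; once every sequence in
    #           the batch has an eos, the loop exits early.
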
    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked
        positions are filled with float(0.0).
        """
        mask = torch.triu(
            torch.full((sz, sz), fill_value=float('-inf'),
                       dtype=torch.float32),
            diagonal=1,
        )
        # Add (batch, head) broadcast dims expected by the attention blocks.
        return mask.unsqueeze(0).unsqueeze(0).to(device)
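
    # For sz = 3 the mask (before the broadcast dims) is:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position i can attend only to positions <= i.
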
class OTEDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 num_heads=None,
                 ar=False,
                 num_decoder_layers=1,
                 **kwargs):
        super(OTEDecoder, self).__init__()
        self.out_channels = out_channels - 2  # none + 26 + 10
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 1  # max_len + eos
        self.cpa = CPA(dim=dim, max_len=max_len)
        self.ar = ar
        if ar:
            self.ar_decoder = ARDecoder(in_channels=dim,
                                        out_channels=out_channels,
                                        nhead=num_heads,
                                        num_decoder_layers=num_decoder_layers,
                                        max_len=max_len)
        else:
            self.fc = nn.Linear(dim, self.out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, x, data=None):
        x = self.cpa(x)
        if self.ar:
            return self.ar_decoder(x, data=data)
        logits = self.fc(x)  # B, max_len + 1, num_classes (e.g. B, 26, 37)
        if self.training:
            # Trim padded positions beyond the longest label in the batch.
            logits = logits[:, :data[1].max() + 1]
        return logits
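
# A minimal smoke test (a sketch: sizes and vocab are illustrative, and it
# assumes the file runs inside its package, e.g. `python -m <pkg>.ote_decoder`,
# so the relative `.nrtr_decoder` import resolves):
if __name__ == '__main__':
    decoder = OTEDecoder(in_channels=256, out_channels=39, max_len=25)
    decoder.eval()
    feats = torch.randn(2, 256, 256)  # B, L, dim features from an encoder
    with torch.no_grad():
        out = decoder(feats)
    print(out.shape)  # torch.Size([2, 26, 37])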