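"""PyTorch implementation of the SRN decoder ("Towards Accurate Scene Text
Recognition with Semantic Reasoning Networks", Yu et al., CVPR 2020).

Three modules make up the decoder:

* PVAM: Parallel Visual Attention Module, which pools the 2-D feature map
  into one visual feature per reading-order step, all steps in parallel.
* GSRM: Global Semantic Reasoning Module, which refines the per-step
  predictions with forward and backward transformer streams.
* VSFD: Visual-Semantic Fusion Decoder, which gates the visual and semantic
  features into the final per-step classification.
"""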

import torch
import torch.nn as nn
import torch.nn.functional as F

from .nrtr_decoder import Embeddings, TransformerBlock


class PVAM(nn.Module):
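    """Parallel Visual Attention Module.

    Computes, in parallel, one additive-attention map per reading-order step
    over the flattened [b, h*w, c] feature map and pools the visual features
    into a fixed-length sequence of `max_text_length` step features.

    Note that `hidden_dims` must equal `in_channels`: the positional
    embedding is added directly to the convolutional features.
    """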

    def __init__(self,
                 in_channels,
                 char_num,
                 max_text_length,
                 num_heads,
                 hidden_dims,
                 dropout_rate=0):
        super(PVAM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.hidden_dims = hidden_dims
        self.dropout_rate = dropout_rate
        # TODO: 256 hard-codes the maximum number of spatial positions
        # (h * w) of the incoming feature map, e.g. 8 * 32.
        self.emb = nn.Embedding(num_embeddings=256,
                                embedding_dim=hidden_dims,
                                sparse=False)
        self.drop_out = nn.Dropout(dropout_rate)
        self.feat_emb = nn.Linear(in_channels, in_channels)
        self.token_emb = nn.Embedding(max_text_length, in_channels)
        self.score = nn.Linear(in_channels, 1, bias=False)

    def feat_pos_mix(self, conv_features, encoder_word_pos, dropout_rate):
        # conv_features: [b, h*w, c]
        pos_emb = self.emb(encoder_word_pos)
        # pos_emb = pos_emb.detach()
        enc_input = conv_features + pos_emb
        if dropout_rate:
            enc_input = self.drop_out(enc_input)
        return enc_input

    def forward(self, inputs):
        b, c, h, w = inputs.shape
        conv_features = inputs.view(-1, c, h * w)
        conv_features = conv_features.permute(0, 2, 1).contiguous()
        # [b, h*w, c]
        b, t, c = conv_features.shape
        # mix learned positional embeddings into the flattened feature map
        encoder_feat_pos = torch.arange(t, dtype=torch.long).to(inputs.device)
        enc_inputs = self.feat_pos_mix(conv_features, encoder_feat_pos,
                                       self.dropout_rate)
        inputs = self.feat_emb(enc_inputs)  # feature embedding
        inputs = inputs.unsqueeze(1).expand(-1, self.max_length, -1, -1)
        # [b, max_len, h*w, c]
        tokens_pos = torch.arange(self.max_length,
                                  dtype=torch.long).to(inputs.device)
        tokens_pos = tokens_pos.unsqueeze(0).expand(b, -1)
        tokens_pos_emd = self.token_emb(tokens_pos)
        tokens_pos_emd = tokens_pos_emd.unsqueeze(2).expand(-1, -1, t, -1)
        # [b, max_len, h*w, c]
        # additive attention: one score per (reading step, spatial position)
        attention_weight = torch.tanh(tokens_pos_emd + inputs)
        attention_weight = torch.squeeze(self.score(attention_weight),
                                         -1)  # [b, max_len, h*w]
        attention_weight = F.softmax(attention_weight, dim=-1)
        # pool the visual features per reading step: [b, max_len, c]
        pvam_features = torch.matmul(attention_weight, enc_inputs)
        return pvam_features


class GSRM(nn.Module):
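    """Global Semantic Reasoning Module.

    First classifies each step directly from the PVAM features, then embeds
    the resulting symbol sequence and runs it through two masked transformer
    streams (forward: left-to-right context, backward: right-to-left) to
    reason over global semantics before re-classifying each step. The last
    two indices of the charset are reserved for BOS and EOS.
    """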

    def __init__(self,
                 in_channel,
                 char_num,
                 max_len,
                 num_heads,
                 hidden_dims,
                 num_layers,
                 dropout_rate=0,
                 attention_dropout=0.1):
        super(GSRM, self).__init__()
        self.char_num = char_num
        self.max_len = max_len
        self.num_heads = num_heads
        self.cls_op = nn.Linear(in_channel, self.char_num)
        self.cls_final = nn.Linear(in_channel, self.char_num)
        self.word_emb = Embeddings(d_model=hidden_dims, vocab=char_num)
        # NOTE: positions only run up to max_len, so sizing this table by
        # char_num over-allocates rows.
        self.pos_emb = nn.Embedding(char_num, hidden_dims)
        self.dropout_rate = dropout_rate
        self.emb_drop_out = nn.Dropout(dropout_rate)
        self.forward_self_attn = nn.ModuleList([
            TransformerBlock(
                d_model=hidden_dims,
                nhead=num_heads,
                attention_dropout_rate=attention_dropout,
                residual_dropout_rate=0.1,
                dim_feedforward=hidden_dims,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.backward_self_attn = nn.ModuleList([
            TransformerBlock(
                d_model=hidden_dims,
                nhead=num_heads,
                attention_dropout_rate=attention_dropout,
                residual_dropout_rate=0.1,
                dim_feedforward=hidden_dims,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])

    def _pos_emb(self, word_seq, pos, dropout_rate):
        """Mix word embeddings with positional embeddings.

        word_seq: [bsz, len]
        pos: [bsz, len]
        """
        word_emb_seq = self.word_emb(word_seq)
        pos_emb_seq = self.pos_emb(pos)
        # pos_emb_seq = pos_emb_seq.detach()
        input_mix = word_emb_seq + pos_emb_seq
        if dropout_rate > 0:
            input_mix = self.emb_drop_out(input_mix)
        return input_mix

    def forward(self, inputs):
        bos_idx = self.char_num - 2
        eos_idx = self.char_num - 1
        b, t, c = inputs.size()  # [b, max_len, c]
        inputs = inputs.view(-1, c)
        # first-pass (visual) classification from the PVAM features
        cls_res = self.cls_op(inputs)  # [b * max_len, char_num]
        word_pred_PVAM = F.softmax(cls_res, dim=-1).argmax(-1)
        word_pred_PVAM = word_pred_PVAM.view(-1, t, 1)
        # [b, max_len, 1]
        # forward stream input: prepend BOS, drop the last symbol (shift right)
        word1 = F.pad(word_pred_PVAM, [0, 0, 1, 0], 'constant', value=bos_idx)
        word_forward = word1[:, :-1, :].squeeze(-1)
        word_backward = word_pred_PVAM.squeeze(-1)
        # causal masks: with the forward mask each position attends only to
        # itself and earlier positions, with the backward mask only to itself
        # and later positions, e.g. for max_len = 3:
        #   forward              backward
        #   [[0, -inf, -inf],    [[   0,    0, 0],
        #    [0,    0, -inf],     [-inf,    0, 0],
        #    [0,    0,    0]]     [-inf, -inf, 0]]
        attn_mask_forward = torch.triu(
            torch.full((self.max_len, self.max_len),
                       fill_value=-torch.inf,
                       dtype=torch.float32),
            diagonal=1,
        ).to(inputs.device)
        attn_mask_forward = attn_mask_forward.unsqueeze(0).expand(
            self.num_heads, -1, -1)
        attn_mask_backward = torch.tril(
            torch.full((self.max_len, self.max_len),
                       fill_value=-torch.inf,
                       dtype=torch.float32),
            diagonal=-1,
        ).to(inputs.device)
        attn_mask_backward = attn_mask_backward.unsqueeze(0).expand(
            self.num_heads, -1, -1)
        pos = torch.arange(self.max_len, dtype=torch.long).to(inputs.device)
        pos = pos.unsqueeze(0).expand(b, -1)  # [b, max_len]
        word_front_mix = self._pos_emb(word_forward, pos, self.dropout_rate)
        word_backward_mix = self._pos_emb(word_backward, pos,
                                          self.dropout_rate)
        # [b, max_len, emb_dim]
        for attn_layer in self.forward_self_attn:
            word_front_mix = attn_layer(word_front_mix,
                                        self_mask=attn_mask_forward)
        for attn_layer in self.backward_self_attn:
            word_backward_mix = attn_layer(word_backward_mix,
                                           self_mask=attn_mask_backward)
        # backward stream output: append EOS, drop the first step (shift left)
        eos_emd = self.word_emb(torch.full(
            (1, ), eos_idx).to(inputs.device)).expand(b, 1, -1)
        word_backward_mix = torch.cat((word_backward_mix, eos_emd), dim=1)
        word_backward_mix = word_backward_mix[:, 1:, :]
        gsrm_features = word_front_mix + word_backward_mix
        gsrm_out = self.cls_final(gsrm_features)
        # weight-tied alternative:
        # gsrm_out = torch.matmul(gsrm_features,
        #                         self.word_emb.embedding.weight.permute(1, 0))
        b, t, c = gsrm_out.size()
        # [b, max_len, char_num]
        gsrm_out = gsrm_out.view(-1, c).contiguous()
        return gsrm_features, cls_res, gsrm_out


class VSFD(nn.Module):
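    """Visual-Semantic Fusion Decoder.

    Predicts a per-channel sigmoid gate from the concatenated PVAM and GSRM
    features and blends the two streams with it before the final linear
    classifier.
    """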

    def __init__(self, in_channels, out_channels):
        super(VSFD, self).__init__()
        self.char_num = out_channels
        self.fc0 = nn.Linear(in_channels * 2, in_channels)
        self.fc1 = nn.Linear(in_channels, self.char_num)

    def forward(self, pvam_feature, gsrm_feature):
        _, t, c1 = pvam_feature.size()
        _, t, c2 = gsrm_feature.size()
        combine_features = torch.cat([pvam_feature, gsrm_feature], dim=-1)
        combine_features = combine_features.view(-1, c1 + c2).contiguous()
        # per-channel gate in (0, 1): 1 favors the visual (PVAM) features,
        # 0 favors the semantic (GSRM) features
        atten = self.fc0(combine_features)
        atten = torch.sigmoid(atten)
        atten = atten.view(-1, t, c1)
        combine_features = atten * pvam_feature + (1 - atten) * gsrm_feature
        combine_features = combine_features.view(-1, c1).contiguous()
        out = self.fc1(combine_features)  # [b * t, char_num]
        return out


class SRNDecoder(nn.Module):
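    """SRN decoder head chaining PVAM -> GSRM -> VSFD.

    In training mode, returns the raw logits of all three stages, each of
    shape [b * max_text_length, out_channels]; in eval mode, returns the
    softmax over the fused VSFD logits only.
    """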

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_dims,
                 num_decoder_layers=4,
                 max_text_length=25,
                 num_heads=8,
                 **kwargs):
        super(SRNDecoder, self).__init__()
        self.max_text_length = max_text_length
        self.num_heads = num_heads
        self.pvam = PVAM(in_channels=in_channels,
                         char_num=out_channels,
                         max_text_length=max_text_length,
                         num_heads=num_heads,
                         hidden_dims=hidden_dims,
                         dropout_rate=0.1)
        self.gsrm = GSRM(in_channel=in_channels,
                         char_num=out_channels,
                         max_len=max_text_length,
                         num_heads=num_heads,
                         num_layers=num_decoder_layers,
                         hidden_dims=hidden_dims)
        self.vsfd = VSFD(in_channels=in_channels, out_channels=out_channels)

    def forward(self, feat, data=None):
        # feat: [b, 512, 8, 32]
        pvam_feature = self.pvam(feat)
        gsrm_features, pvam_preds, gsrm_preds = self.gsrm(pvam_feature)
        vsfd_preds = self.vsfd(pvam_feature, gsrm_features)
        if not self.training:
            preds = F.softmax(vsfd_preds, dim=-1)
            return preds
        return [pvam_preds, gsrm_preds, vsfd_preds]
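

# Usage sketch (illustrative only, not part of this module): assumes a
# backbone producing a [B, 512, 8, 32] feature map and a 38-symbol charset
# whose last two indices are BOS and EOS; hidden_dims must equal in_channels.
#
#     decoder = SRNDecoder(in_channels=512, out_channels=38, hidden_dims=512)
#     feat = torch.randn(2, 512, 8, 32)
#     decoder.eval()
#     probs = decoder(feat)  # [2 * 25, 38] softmax probabilities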