# aster_decoder.py
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn import functional as F
  4. from torch.nn import init
  5. class Embedding(nn.Module):
  6. def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
  7. super(Embedding, self).__init__()
  8. self.in_timestep = in_timestep
  9. self.in_planes = in_planes
  10. self.embed_dim = embed_dim
  11. self.mid_dim = mid_dim
  12. self.eEmbed = nn.Linear(
  13. in_timestep * in_planes,
  14. self.embed_dim) # Embed encoder output to a word-embedding like
  15. def forward(self, x):
  16. x = x.flatten(1)
  17. x = self.eEmbed(x)
  18. return x
  19. class Attn_Rnn_Block(nn.Module):
  20. def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
  21. attndim):
  22. super(Attn_Rnn_Block, self).__init__()
  23. self.attndim = attndim
  24. self.embedding_dim = embedding_dim
  25. self.feat_embed = nn.Linear(featdim, attndim)
  26. self.hidden_embed = nn.Linear(hiddendim, attndim)
  27. self.attnfeat_embed = nn.Linear(attndim, 1)
  28. self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
  29. hidden_size=hiddendim,
  30. batch_first=True)
  31. self.fc = nn.Linear(hiddendim, out_channels)
  32. self.init_weights()
  33. def init_weights(self):
  34. init.normal_(self.hidden_embed.weight, std=0.01)
  35. init.constant_(self.hidden_embed.bias, 0)
  36. init.normal_(self.attnfeat_embed.weight, std=0.01)
  37. init.constant_(self.attnfeat_embed.bias, 0)
  38. def _attn(self, feat, h_state):
  39. b, t, _ = feat.shape
  40. feat = self.feat_embed(feat)
  41. h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
  42. h_state = h_state.expand(b, t, self.attndim)
  43. sumTanh = torch.tanh(feat + h_state)
  44. attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)
  45. attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)
  46. # [B,1,25]
  47. return attn_w
  48. def forward(self, feat, h_state, label_input):
  49. attn_w = self._attn(feat, h_state)
  50. attn_feat = attn_w @ feat
  51. attn_feat = attn_feat.squeeze(1)
  52. output, h_state = self.gru(
  53. torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
  54. pred = self.fc(output)
  55. return pred, h_state
  56. class ASTERDecoder(nn.Module):
  57. def __init__(self,
  58. in_channels,
  59. out_channels,
  60. embedding_dim=256,
  61. hiddendim=256,
  62. attndim=256,
  63. max_len=25,
  64. seed=False,
  65. time_step=32,
  66. **kwargs):
  67. super(ASTERDecoder, self).__init__()
  68. self.num_classes = out_channels
  69. self.bos = out_channels - 2
  70. self.eos = 0
  71. self.padding_idx = out_channels - 1
  72. self.seed = seed
  73. if seed:
  74. self.embeder = Embedding(
  75. in_timestep=time_step,
  76. in_planes=in_channels,
  77. )
  78. self.word_embedding = nn.Embedding(self.num_classes,
  79. embedding_dim,
  80. padding_idx=self.padding_idx)
  81. self.attndim = attndim
  82. self.hiddendim = hiddendim
  83. self.max_seq_len = max_len + 1
  84. self.featdim = in_channels
  85. self.attn_rnn_block = Attn_Rnn_Block(
  86. featdim=self.featdim,
  87. hiddendim=hiddendim,
  88. embedding_dim=embedding_dim,
  89. out_channels=out_channels - 2,
  90. attndim=attndim,
  91. )
  92. self.embed_fc = nn.Linear(300, self.hiddendim)
  93. def get_initial_state(self, embed, tile_times=1):
  94. assert embed.shape[1] == 300
  95. state = self.embed_fc(embed) # N * sDim
  96. if tile_times != 1:
  97. state = state.unsqueeze(1)
  98. trans_state = state.transpose(0, 1)
  99. state = trans_state.tile([tile_times, 1, 1])
  100. trans_state = state.transpose(0, 1)
  101. state = trans_state.reshape(-1, self.hiddendim)
  102. state = state.unsqueeze(0) # 1 * N * sDim
  103. return state
  104. def forward(self, feat, data=None):
  105. # b,25,512
  106. b = feat.size(0)
  107. if self.seed:
  108. embedding_vectors = self.embeder(feat)
  109. h_state = self.get_initial_state(embedding_vectors)
  110. else:
  111. h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
  112. outputs = []
  113. if self.training:
  114. label = data[0]
  115. label_embedding = self.word_embedding(label) # [B,25,256]
  116. tokens = label_embedding[:, 0, :]
  117. max_len = data[1].max() + 1
  118. else:
  119. tokens = torch.full([b, 1],
  120. self.bos,
  121. device=feat.device,
  122. dtype=torch.long)
  123. tokens = self.word_embedding(tokens.squeeze(1))
  124. max_len = self.max_seq_len
  125. pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
  126. outputs.append(pred)
  127. dec_seq = torch.full((feat.shape[0], max_len),
  128. self.padding_idx,
  129. dtype=torch.int64,
  130. device=feat.get_device())
  131. dec_seq[:, :1] = torch.argmax(pred, dim=-1)
  132. for i in range(1, max_len):
  133. if not self.training:
  134. max_idx = torch.argmax(pred, dim=-1).squeeze(1)
  135. tokens = self.word_embedding(max_idx)
  136. dec_seq[:, i] = max_idx
  137. if (dec_seq == self.eos).any(dim=-1).all():
  138. break
  139. else:
  140. tokens = label_embedding[:, i, :]
  141. pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
  142. outputs.append(pred)
  143. preds = torch.cat(outputs, 1)
  144. if self.seed and self.training:
  145. return [embedding_vectors, preds]
  146. return preds if self.training else F.softmax(preds, -1)