import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init


class Embedding(nn.Module):
    """Projects the flattened encoder output to a fixed-size semantic embedding."""

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim  # kept for interface compatibility; not used below
        # Embed the flattened encoder output into a word-embedding-like vector.
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        x = x.flatten(1)    # [B, T, C] -> [B, T * C]
        x = self.eEmbed(x)  # [B, T * C] -> [B, embed_dim]
        return x


class Attn_Rnn_Block(nn.Module):
    """One decoding step: additive attention over encoder features plus a GRU update."""

    def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
                 attndim):
        super(Attn_Rnn_Block, self).__init__()
        self.attndim = attndim
        self.embedding_dim = embedding_dim
        self.feat_embed = nn.Linear(featdim, attndim)
        self.hidden_embed = nn.Linear(hiddendim, attndim)
        self.attnfeat_embed = nn.Linear(attndim, 1)
        self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
                          hidden_size=hiddendim,
                          batch_first=True)
        self.fc = nn.Linear(hiddendim, out_channels)
        self.init_weights()

    def init_weights(self):
        init.normal_(self.hidden_embed.weight, std=0.01)
        init.constant_(self.hidden_embed.bias, 0)
        init.normal_(self.attnfeat_embed.weight, std=0.01)
        init.constant_(self.attnfeat_embed.bias, 0)

    def _attn(self, feat, h_state):
        b, t, _ = feat.shape
        feat = self.feat_embed(feat)  # [B, T, attndim]
        h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
        h_state = h_state.expand(b, t, self.attndim)
        sumTanh = torch.tanh(feat + h_state)
        attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)
        attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)
        return attn_w  # attention weights: [B, 1, T]

    def forward(self, feat, h_state, label_input):
        attn_w = self._attn(feat, h_state)
        attn_feat = attn_w @ feat      # attention glimpse: [B, 1, featdim]
        attn_feat = attn_feat.squeeze(1)
        # One GRU step on the concatenated [previous-token embedding, glimpse].
        output, h_state = self.gru(
            torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
        pred = self.fc(output)  # [B, 1, out_channels]
        return pred, h_state
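

# A minimal shape-check sketch for Attn_Rnn_Block (not part of the model itself).
# The dimensions below (featdim=512, 37 output classes, T=25, batch of 2) are
# illustrative assumptions, not values taken from any particular configuration.
def _demo_attn_rnn_block():
    block = Attn_Rnn_Block(featdim=512, hiddendim=256, embedding_dim=256,
                           out_channels=37, attndim=256)
    feat = torch.randn(2, 25, 512)     # encoder features: [B, T, featdim]
    h_state = torch.zeros(1, 2, 256)   # GRU hidden state: [1, B, hiddendim]
    prev_token = torch.randn(2, 256)   # embedded previous token: [B, embedding_dim]
    pred, h_state = block(feat, h_state, prev_token)
    # pred: [2, 1, 37], h_state: [1, 2, 256]
    return pred, h_state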


class ASTERDecoder(nn.Module):
    """ASTER-style attentional decoder: attends over encoder features and predicts
    one character per step, using teacher forcing during training and greedy
    decoding at inference time."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 embedding_dim=256,
                 hiddendim=256,
                 attndim=256,
                 max_len=25,
                 seed=False,
                 time_step=32,
                 **kwargs):
        super(ASTERDecoder, self).__init__()
        self.num_classes = out_channels
        # Token-id conventions: <eos> = 0, <bos> = num_classes - 2,
        # <pad> = num_classes - 1.
        self.bos = out_channels - 2
        self.eos = 0
        self.padding_idx = out_channels - 1
        self.seed = seed
        if seed:
            # Optional semantic branch: embed the whole feature map into a
            # 300-d vector used to initialize the decoder hidden state.
            self.embeder = Embedding(
                in_timestep=time_step,
                in_planes=in_channels,
            )

        self.word_embedding = nn.Embedding(self.num_classes,
                                           embedding_dim,
                                           padding_idx=self.padding_idx)

        self.attndim = attndim
        self.hiddendim = hiddendim
        self.max_seq_len = max_len + 1
        self.featdim = in_channels

        self.attn_rnn_block = Attn_Rnn_Block(
            featdim=self.featdim,
            hiddendim=hiddendim,
            embedding_dim=embedding_dim,
            out_channels=out_channels - 2,
            attndim=attndim,
        )
        self.embed_fc = nn.Linear(300, self.hiddendim)

    def get_initial_state(self, embed, tile_times=1):
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # [N, hiddendim]
        if tile_times != 1:
            state = state.unsqueeze(1)
            trans_state = state.transpose(0, 1)
            state = trans_state.tile([tile_times, 1, 1])
            trans_state = state.transpose(0, 1)
            state = trans_state.reshape(-1, self.hiddendim)
        state = state.unsqueeze(0)  # [1, N, hiddendim]
        return state

    def forward(self, feat, data=None):
        # feat: [B, T, in_channels], e.g. [B, 25, 512]
        b = feat.size(0)
        if self.seed:
            embedding_vectors = self.embeder(feat)
            h_state = self.get_initial_state(embedding_vectors)
        else:
            h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
        outputs = []
        if self.training:
            label = data[0]
            label_embedding = self.word_embedding(label)  # [B, L, embedding_dim]
            tokens = label_embedding[:, 0, :]
            max_len = data[1].max() + 1
        else:
            tokens = torch.full([b, 1],
                                self.bos,
                                device=feat.device,
                                dtype=torch.long)
            tokens = self.word_embedding(tokens.squeeze(1))
            max_len = self.max_seq_len
        pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
        outputs.append(pred)
        # dec_seq caches the greedy predictions so inference can stop early
        # once every sequence in the batch has emitted <eos>.
        dec_seq = torch.full((feat.shape[0], max_len),
                             self.padding_idx,
                             dtype=torch.int64,
                             device=feat.device)
        dec_seq[:, :1] = torch.argmax(pred, dim=-1)
        for i in range(1, max_len):
            if not self.training:
                # Greedy decoding: feed back the previous prediction.
                max_idx = torch.argmax(pred, dim=-1).squeeze(1)
                tokens = self.word_embedding(max_idx)
                dec_seq[:, i] = max_idx
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
            else:
                # Teacher forcing: feed the ground-truth token embedding.
                tokens = label_embedding[:, i, :]
            pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
            outputs.append(pred)
        preds = torch.cat(outputs, 1)
        if self.seed and self.training:
            return [embedding_vectors, preds]
        return preds if self.training else F.softmax(preds, -1)
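

# A minimal, hypothetical smoke test for inference. The feature shape
# [2, 25, 512] and the 38-class vocabulary are assumptions for illustration,
# not values taken from any real configuration.
if __name__ == '__main__':
    decoder = ASTERDecoder(in_channels=512, out_channels=38, time_step=25)
    decoder.eval()
    feat = torch.randn(2, 25, 512)  # encoder output: [B, T, in_channels]
    with torch.no_grad():
        probs = decoder(feat)
    # probs: [B, L, out_channels - 2] with L <= max_len + 1; greedy decoding
    # may stop early once every sequence has produced <eos>.
    print(probs.shape)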