import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SAREncoder(nn.Module):
    """Holistic feature encoder of SAR (Show, Attend and Read)."""

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 in_channels=512,
                 d_enc=512,
                 **kwargs):
        super().__init__()
        # LSTM encoder
        bidirectional = bool(enc_bi_rnn)
        hidden_size = d_enc
        self.rnn_encoder = nn.LSTM(input_size=in_channels,
                                   hidden_size=hidden_size,
                                   num_layers=2,
                                   dropout=enc_drop_rnn,
                                   bidirectional=bidirectional,
                                   batch_first=True)
        # global feature transformation; the RNN output doubles in width
        # when the encoder is bidirectional
        encoder_rnn_out_size = hidden_size * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
    def forward(self, feat):
        # collapse the height axis with vertical max-pooling
        h_feat = feat.shape[2]
        feat_v = F.max_pool2d(feat,
                              kernel_size=(h_feat, 1),
                              stride=1,
                              padding=0)
        feat_v = feat_v.squeeze(2)                     # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]    # bsz * T * hidden_size
        # take the last time step as the holistic representation
        valid_hf = holistic_feat[:, -1, :]             # bsz * hidden_size
        holistic_feat = self.linear(valid_hf)          # bsz * C
        return holistic_feat
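

# Usage sketch (assumption: a backbone that emits bsz * 512 * h * w feature
# maps, e.g. h = 6, w = 40 for a 48 x 160 input image):
#
#   encoder = SAREncoder(in_channels=512, d_enc=512)
#   feat = torch.rand(2, 512, 6, 40)
#   holistic = encoder(feat)  # -> (2, 512)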


class SARDecoder(nn.Module):
    """2D-attention decoder of SAR."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 pred_dropout=0.1,
                 pred_concat=True,
                 mask=True,
                 use_lstm=True,
                 **kwargs):
        super().__init__()
        self.num_classes = out_channels
        # special tokens sit at the end of the class table; the end token
        # reuses index 0
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.end_idx = 0
        self.max_seq_len = max_len + 1
        self.pred_concat = pred_concat
        self.mask = mask
        self.use_lstm = use_lstm
        enc_dim = in_channels
        d = in_channels
        embedding_dim = in_channels
        dec_dim = in_channels
        if use_lstm:
            # encoder module
            self.encoder = SAREncoder(enc_bi_rnn=enc_bi_rnn,
                                      enc_drop_rnn=enc_drop_rnn,
                                      in_channels=in_channels,
                                      d_enc=enc_dim)
        # decoder module
        # 2D attention layer
        self.conv1x1_1 = nn.Linear(dec_dim, d)
        self.conv3x3_1 = nn.Conv2d(in_channels,
                                   d,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
        self.conv1x1_2 = nn.Linear(d, 1)
        # decoder input embedding
        self.embedding = nn.Embedding(self.num_classes,
                                      embedding_dim,
                                      padding_idx=self.padding_idx)
        self.rnn_decoder = nn.LSTM(input_size=embedding_dim,
                                   hidden_size=dec_dim,
                                   num_layers=2,
                                   dropout=dec_drop_rnn,
                                   bidirectional=dec_bi_rnn,
                                   batch_first=True)
        # prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        if pred_concat:
            fc_in_channel = in_channels + in_channels + dec_dim
        else:
            fc_in_channel = in_channels
        self.prediction = nn.Linear(fc_in_channel, self.num_classes)
        self.softmax = nn.Softmax(dim=-1)
    def _2d_attention(self, feat, tokens, data=None):
        hidden_state = self.rnn_decoder(tokens)[0]
        attn_query = self.conv1x1_1(hidden_state)
        bsz, seq_len, _ = attn_query.size()
        attn_query = attn_query.unsqueeze(-1).unsqueeze(-1)
        # bsz * (seq_len + 1) * attn_size * 1 * 1
        attn_key = self.conv3x3_1(feat).unsqueeze(1)
        # bsz * 1 * attn_size * h * w
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        attn_weight = self.conv1x1_2(attn_weight)
        _, T, h, w, c = attn_weight.size()
        if self.mask:
            valid_ratios = data[-1]
            # mask attention weights that fall on padded image width
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))
        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()
        # attn_weight: bsz * T * 1 * h * w; feat: bsz * 1 * f_c * h * w
        attn_feat = torch.sum(torch.mul(feat.unsqueeze(1), attn_weight),
                              (3, 4),
                              keepdim=False)
        return hidden_state, attn_feat
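
    # Shape walk-through (assumption: bsz = 2, T = 26 decoding steps,
    # attn_size = 512 and a 6 x 40 feature map):
    #   attn_query: (2, 26, 512, 1, 1) broadcasts against
    #   attn_key:   (2, 1, 512, 6, 40) -> tanh sum: (2, 26, 512, 6, 40);
    #   conv1x1_2 collapses the channel axis to (2, 26, 6, 40, 1), the
    #   softmax over the flattened h * w grid yields one spatial map per
    #   step, and the weighted sum over (h, w) gives attn_feat: (2, 26, 512).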
    def forward_train(self, feat, holistic_feat, data):
        max_len = data[1].max()
        label = data[0][:, :1 + max_len]
        label_embedding = self.embedding(label)
        holistic_feat = holistic_feat.unsqueeze(1)
        # prepend the holistic feature as the first decoding step
        tokens = torch.cat((holistic_feat, label_embedding), dim=1)
        hidden_state, attn_feat = self._2d_attention(feat, tokens, data)
        bsz, seq_len, _ = hidden_state.size()
        # linear transformation
        if self.pred_concat:
            f_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, f_c)
            preds = self.prediction(
                torch.cat((hidden_state, attn_feat, holistic_feat), 2))
        else:
            preds = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        preds = self.pred_dropout(preds)
        # drop the step driven by the holistic feature itself
        return preds[:, 1:, :]
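
    # Assumed layout of `data` (not defined in this file; adapt to your
    # data pipeline):
    #   data[0]  - padded label tensor, bsz * max_len, dtype long
    #   data[1]  - per-sample label lengths, bsz
    #   data[-1] - per-sample valid width ratios in (0, 1], used for masking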
    def forward_test(self, feat, holistic_feat, data=None):
        bsz = feat.shape[0]
        seq_len = self.max_seq_len
        holistic_feat = holistic_feat.unsqueeze(1)
        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        start_token = self.embedding(start_token)
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # torch.cat materialises a fresh tensor, so the in-place writes
        # below are safe
        tokens = torch.cat((holistic_feat, start_token), dim=1)
        outputs = []
        finished = torch.zeros(bsz, dtype=torch.bool, device=feat.device)
        for i in range(1, seq_len + 1):
            hidden_state, attn_feat = self._2d_attention(feat, tokens, data)
            if self.pred_concat:
                f_c = holistic_feat.size(-1)
                hf = holistic_feat.expand(bsz, seq_len + 1, f_c)
                preds = self.prediction(
                    torch.cat((hidden_state, attn_feat, hf), 2))
            else:
                preds = self.prediction(attn_feat)
            # bsz * (seq_len + 1) * num_classes
            char_output = preds[:, i, :]
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = torch.argmax(char_output, dim=1)
            # feed the greedy prediction back in as the next input token
            char_embedding = self.embedding(max_idx)
            if i < seq_len:
                tokens[:, i + 1, :] = char_embedding
            # stop once every sample in the batch has emitted the end token
            finished |= (max_idx == self.end_idx)
            if finished.all():
                break
        outputs = torch.stack(outputs, 1)
        return outputs
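
    # Decoding sketch (assumption: a hypothetical charset list `chars` whose
    # index 0 doubles as the end token):
    #
    #   probs = decoder(feat)        # eval mode: (bsz, steps, num_classes)
    #   idxs = probs.argmax(-1)[0]   # greedy path of the first sample
    #   text = []
    #   for idx in idxs.tolist():
    #       if idx == decoder.end_idx:
    #           break
    #       text.append(chars[idx])
    #   text = ''.join(text)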
    def forward(self, feat, data=None):
        if self.use_lstm:
            holistic_feat = self.encoder(feat)  # bsz * C
        else:
            # flatten(1) keeps the batch axis even when bsz == 1
            holistic_feat = F.adaptive_avg_pool2d(feat, (1, 1)).flatten(1)
        if self.training:
            preds = self.forward_train(feat, holistic_feat, data=data)
        else:
            preds = self.forward_test(feat, holistic_feat, data=data)
        # (bsz, seq_len, num_classes)
        return preds
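

# Smoke-test sketch (assumptions: 38 classes = 36 characters + start token at
# index 36 + padding at index 37, end token reusing index 0; mask disabled so
# no `data` is needed):
#
#   decoder = SARDecoder(in_channels=512, out_channels=38, max_len=25,
#                        mask=False)
#   decoder.eval()
#   feat = torch.rand(2, 512, 6, 40)
#   with torch.no_grad():
#       probs = decoder(feat)  # (2, <=26, 38); may stop early at end token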