import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock  # , Beam
from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder

def generate_square_subsequent_mask(sz):
    r"""Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0)))
    return mask
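
# Illustrative note (not part of the original implementation): for sz=3 the mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# i.e. position i may only attend to positions <= i; the mask is added to the
# attention logits before softmax, giving the usual causal (autoregressive) masking.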

class SEM_Pre(nn.Module):
    """Semantic branch: token embedding plus positional encoding, with a causal mask."""

    def __init__(
        self,
        d_model=512,
        dst_vocab_size=40,
        residual_dropout_rate=0.1,
    ):
        super(SEM_Pre, self).__init__()
        self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model,
        )

    def forward(self, tgt):
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
        return tgt, tgt_mask

class POS_Pre(nn.Module):
    """Positional branch: sinusoidal position encoding refined by a small residual FFN."""

    def __init__(
        self,
        d_model=512,
    ):
        super(POS_Pre, self).__init__()
        self.pos_encoding = PositionalEncoding(
            dropout=0.1,
            dim=d_model,
        )
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        # Start from zeros so this branch carries only positional information;
        # `tgt` is used just for its shape, dtype and device.
        pos = tgt.new_zeros(*tgt.shape)
        pos = self.pos_encoding(pos)
        pos2 = self.linear2(F.relu(self.linear1(pos)))
        pos = self.norm2(pos + pos2)
        return pos

class DSF(nn.Module):
    """Dynamic Shared Fusion: gated element-wise blending of two feature streams."""

    def __init__(self, d_model, fusion_num):
        super(DSF, self).__init__()
        self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where N is the batch size, T the sequence length
                and E the model dimension.
            v_feature: (N, T, E), same shape as l_feature.
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature
        return output
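
# Illustrative usage sketch (shapes are assumptions, not taken from the original file):
#
#   fuse = DSF(d_model=512, fusion_num=2)
#   a = torch.randn(2, 25, 512)   # e.g. the CBI-V output
#   b = torch.randn(2, 25, 512)   # e.g. the CBI-S output
#   out = fuse(a, b)              # (2, 25, 512)
#
# The learned gate g = sigmoid(W [a; b]) is applied element-wise, so
# out = g * b + (1 - g) * a mixes the two branches per position and channel.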

class MDCDP(nn.Module):
    r"""Multi-Domain Character Distance Perception module."""

    def __init__(self, d_model, n_head, d_inner, num_layers):
        super(MDCDP, self).__init__()
        self.num_layers = num_layers
        # Step 1, SAE: self-attention enhancement of the positional branch.
        self.layers_pos = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(num_layers)
        ])
        # Step 2, CBI: cross-branch interaction with the visual and semantic branches.
        self.layers2 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        self.layers3 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        # Step 3, DSF: dynamic shared fusion of the two interaction results.
        self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
        self,
        sem,
        vis,
        pos,
        tgt_mask=None,
        memory_mask=None,
    ):
        for i in range(self.num_layers):
            # Step 1, SAE: Self-Attention Enhancement.
            pos = self.layers_pos[i](pos, self_mask=tgt_mask)
            # Step 2, CBI: Cross-Branch Interaction.
            # CBI-V: attend to the visual features.
            pos_vis = self.layers2[i](
                pos,
                vis,
                cross_mask=memory_mask,
            )
            # CBI-S: attend to the semantic features.
            pos_sem = self.layers3[i](
                pos,
                sem,
                cross_mask=tgt_mask,
            )
            # Step 3, DSF: Dynamic Shared Fusion.
            pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
        output = pos
        return output
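
# Shape walk-through (illustrative; B = batch, T = decoded length, S = number of visual
# tokens, E = d_model): `sem` and `pos` are (B, T, E), `vis` is (B, S, E). Every layer
# keeps the positional stream at (B, T, E): SAE applies causally masked self-attention,
# CBI-V and CBI-S cross-attend it against `vis` and `sem`, and DSF gates the two results
# back into a single (B, T, E) tensor that feeds the next layer.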

class ConvBnRelu(nn.Module):
    # Padding is derived per axis from the kernel size, so changing the kernel keeps
    # "same"-style padding.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        conv=nn.Conv2d,
        stride=2,
        inplace=True,
    ):
        super().__init__()
        p_size = [int(k // 2) for k in kernel_size]
        # p_size = int(kernel_size // 2)
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=p_size,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
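
# Illustrative sizing example (values assumed, not from the original configs):
# ConvBnRelu(512, 512, kernel_size=(1, 3), stride=(1, 2)) pads by (0, 1), so a
# (B, 512, 8, 32) feature map becomes (B, 512, 8, 16): the height is preserved and the
# width is halved, which is how CDistNetDecoder uses it when add_conv=True.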

class CDistNetDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_head=None,
                 num_encoder_blocks=3,
                 num_decoder_blocks=3,
                 beam_size=0,
                 max_len=25,
                 residual_dropout_rate=0.1,
                 add_conv=False,
                 **kwargs):
        super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        # Special token layout: <eos> is index 0, <bos> is the second-to-last index,
        # and the last index serves as the padding / ignore index.
        self.ignore_index = dst_vocab_size - 1
        self.bos = dst_vocab_size - 2
        self.eos = 0
        self.beam_size = beam_size
        self.max_len = max_len
        self.add_conv = add_conv

        d_model = in_channels
        dim_feedforward = d_model * 4
        n_head = n_head if n_head is not None else d_model // 32

        if add_conv:
            self.convbnrelu = ConvBnRelu(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, 3),
                stride=(1, 2),
            )
        if num_encoder_blocks > 0:
            self.positional_encoding = PositionalEncoding(
                dropout=0.1,
                dim=d_model,
            )
            self.trans_encoder = Transformer_Encoder(
                n_layers=num_encoder_blocks,
                n_head=n_head,
                d_model=d_model,
                d_inner=dim_feedforward,
            )
        else:
            self.trans_encoder = None
        self.semantic_branch = SEM_Pre(
            d_model=d_model,
            dst_vocab_size=dst_vocab_size,
            residual_dropout_rate=residual_dropout_rate,
        )
        self.positional_branch = POS_Pre(d_model=d_model)
        self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
                           num_decoder_blocks)
        self._reset_parameters()

        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # <bos> and the padding token are never predicted.
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)

    def forward(self, x, data=None):
        if self.add_conv:
            x = self.convbnrelu(x)
        # x = rearrange(x, "b c h w -> b (w h) c")
        x = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        if self.trans_encoder is not None:
            x = self.positional_encoding(x)
            vis_feat = self.trans_encoder(x, src_mask=None)
        else:
            vis_feat = x
        if self.training:
            # `data` is assumed to hold the padded label ids (with <bos> prepended) and the
            # label lengths; the labels are truncated to the longest sequence in the batch.
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
            res = self.forward_train(vis_feat, tgt)
        else:
            if self.beam_size > 0:
                res = self.forward_beam(vis_feat)
            else:
                res = self.forward_test(vis_feat)
        return res

    def forward_train(self, vis_feat, tgt):
        # Teacher forcing: the semantic branch consumes the <bos>-prefixed target tokens,
        # and the causal mask keeps every step blind to future characters.
        sem_feat, sem_mask = self.semantic_branch(tgt)
        pos_feat = self.positional_branch(sem_feat)
        output = self.mdcdp(
            sem_feat,
            vis_feat,
            pos_feat,
            tgt_mask=sem_mask,
            memory_mask=None,
        )
        logit = self.tgt_word_prj(output)
        return logit
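
    # Note (assumption about the surrounding training code, not stated in this file):
    # since `tgt` starts with <bos>, the returned logits are typically matched against the
    # labels shifted left by one position, i.e. step t predicts character t+1 of the target.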

    def forward_test(self, vis_feat):
        bs = vis_feat.size(0)
        # Start from a sequence of padding tokens with <bos> in the first slot.
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(dec_seq[:, :len_dec_seq + 1])
            pos_feat = self.positional_branch(sem_feat)
            output = self.mdcdp(
                sem_feat,
                vis_feat,
                pos_feat,
                tgt_mask=sem_mask,
                memory_mask=None,
            )
            dec_output = output[:, -1:, :]
            # Note: despite the variable name, these are softmax probabilities, not raw logits.
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decoding: append the most likely token to the decoder input.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Stop early once every sequence in the batch contains at least one <eos> token.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits

    def forward_beam(self, x):
        """Translation work in one batch (beam search; not implemented yet)."""
        # TODO: beam search decoding has not been implemented; raise instead of
        # silently returning None.
        raise NotImplementedError(
            'Beam search is not implemented; set beam_size=0 to use greedy decoding.')

    def _reset_parameters(self):
        r"""Initialize multi-dimensional parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
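

# Minimal smoke test (illustrative only; the sizes below are assumptions, not values taken
# from the original configs). It exercises the greedy decoding path with random weights,
# so the predicted characters are meaningless; the point is just the tensor shapes.
if __name__ == '__main__':
    decoder = CDistNetDecoder(in_channels=512, out_channels=40, beam_size=0, max_len=25)
    decoder.eval()
    feats = torch.randn(2, 512, 8, 32)  # (B, C, H, W) feature map from a backbone
    with torch.no_grad():
        probs = decoder(feats)
    # (2, T, 38) with T <= 25: per-step probabilities over the vocab without <bos>/<pad>.
    print(probs.shape)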