123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- """This code is adapted from:
- https://github.com/jjwei66/BUSNet
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .nrtr_decoder import PositionalEncoding, TransformerBlock
- from .abinet_decoder import _get_mask, _get_length
class BUSDecoder(nn.Module):
    """Three-pass (vision / language / vision-language) decoder from BUSNet.

    A single stack of cross-attention ``TransformerBlock`` layers and a single
    classifier are shared by all three passes.  Every pass uses the same
    positional queries and differs only in the memory sequence it attends to:

    1. vision:          ``[img_feat, l_mask * T]``
    2. language:        ``[v_mask * L, token_embed]``
    3. vision-language: ``[img_feat, token_embed]``

    Args:
        in_channels: Feature width of ``img_feat``; used as ``d_model``.
        out_channels: Number of output classes.
        nhead: Number of attention heads in each decoder layer.
        num_layers: Number of decoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward part.
        dropout: Attention / residual dropout rate of the decoder layers.
            NOTE: both positional encoders use a fixed dropout of 0.1
            regardless of this value.
        max_length: Maximum text length, excluding the stop token.
        ignore_index: Label value replaced by 0 before one-hot encoding of
            the ground-truth labels in pretraining mode.
        pretraining: If True, the language pass is fed the ground-truth
            labels (instead of the vision prediction) while training.
        detach: If True, the vision probabilities fed to the language pass
            are detached from the autograd graph.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels

        # ----------------------------------------------------------------------
        # decoder specifics
        # Projects per-position class distributions into the model width.
        self.proj = nn.Linear(out_channels, d_model, False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])

        # Learned placeholders standing in for the missing modality in the
        # vision-only and language-only passes.
        self.v_mask = self._small_uniform_param(d_model)
        self.l_mask = self._small_uniform_param(d_model)
        # Learned modality embeddings added to the image / token features.
        # The "embeding" spelling is kept on purpose: these are registered
        # parameter names, so renaming them would change the state_dict keys.
        self.v_embeding = self._small_uniform_param(d_model)
        self.l_embeding = self._small_uniform_param(d_model)

        self.cls = nn.Linear(d_model, out_channels)

    @staticmethod
    def _small_uniform_param(d_model):
        """Return a (1, 1, d_model) parameter drawn from U(-0.001, 0.001)."""
        param = nn.Parameter(torch.empty((1, 1, d_model)))
        torch.nn.init.uniform_(param, -0.001, 0.001)
        return param

    def forward_decoder(self, q, x, mask=None):
        """Run the queries through every decoder layer and classify them.

        Args:
            q: Query tensor, updated layer by layer.
            x: Memory tensor each layer cross-attends to.
            mask: Optional mask forwarded to every layer as ``cross_mask``.

        Returns:
            Classifier logits computed from the final queries.
        """
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits

    def forward(self, img_feat, data=None):
        """Decode text from image features.

        Args:
            img_feat: Image features of shape (B, L, C), with C == d_model.
            data: Only read when the module is training with
                ``pretraining=True``.  ``data[0]`` holds the integer label
                indices and ``data[-1]`` the label lengths.
                (presumably (B, T) and (B,) respectively -- TODO confirm
                against the data loader)

        Returns:
            In training mode, a dict with the logits of the three passes:
            ``{'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}``.
            In eval mode, the softmax of the vision-language logits.

        Raises:
            ValueError: If ``data`` is None while training in pretraining
                mode.
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # ----------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        # Queries carry positional information only.
        query = self.pos_encoder(zeros)

        # 1. vision decode: the language slots are filled with `l_mask`.
        v_embed = torch.cat((img_feat, self.l_mask.expand(B, T, -1)), dim=1)
        # Every sample is given the full length T here.
        # NOTE(review): the exact mask layout is defined by `_get_mask` in
        # abinet_decoder; it must be (B, 1, T, T)-compatible for the
        # concatenation below -- confirm there.
        padding_mask = _get_mask(self.max_length + zeros_len,
                                 self.max_length)
        # All-zero mask block for the L image positions.
        vis_attn_mask = torch.zeros((B, 1, T, L), device=img_feat.device)
        mask = torch.cat((vis_attn_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode
        if self.training and self.pretraining:
            if data is None:
                raise ValueError(
                    'BUSDecoder: `data` (labels and lengths) is required '
                    'when training with pretraining=True.')
            labels = data[0]
            # Replace ignored label slots with class 0 so one_hot is valid.
            tgt = torch.where(labels == self.ignore_index, 0, labels)
            tokens = F.one_hot(tgt, num_classes=self.out_channels).float()
            lengths = data[-1]
        else:
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
            # Honour the constructor flag: previously the tokens were always
            # detached, which silently ignored `detach=False`.
            if self.detach:
                tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (N, T, E)
        token_embed = token_embed + self.l_embeding

        # Mask the padding positions beyond each sequence's token length.
        padding_mask = _get_mask(lengths, self.max_length)
        mask = torch.cat((vis_attn_mask, padding_mask), 3)
        # The vision slots are filled with `v_mask`.
        l_embed = torch.cat((self.v_mask.expand(B, L, -1), token_embed),
                            dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision language decode
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        return F.softmax(vl_logits, -1)
|