- """This code is refer from:
- https://github.com/byeonghu-na/MATRN
- """
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length
- from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class BaseSemanticVisual_backbone_feature(nn.Module):

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.0,
                 alignment_mask_example_prob=0.9,
                 alignment_mask_candidate_prob=0.9,
                 alignment_num_vis_mask=10,
                 max_length=25,
                 num_classes=37):
        super().__init__()
        self.mask_example_prob = alignment_mask_example_prob
        self.mask_candidate_prob = alignment_mask_candidate_prob
        self.num_vis_mask = alignment_num_vis_mask
        self.nhead = nhead
        self.d_model = d_model
        self.max_length = max_length + 1  # additional stop token

        self.model1 = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
                                                  dropout=0,
                                                  max_len=1024)
        self.model2_vis = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.cls_vis = nn.Linear(d_model, num_classes)
        self.cls_sem = nn.Linear(d_model, num_classes)
        self.w_att = nn.Linear(2 * d_model, d_model)
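
        # Learnable mask embedding used to blank out visual tokens during
        # training (visual feature masking in MATRN); see the masking branch
        # in forward().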
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)

        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
        """
        Args:
            l_feature: (N, T, E) language features, where N is the batch size,
                T the (padded) sequence length and E the model dimension
            v_feature: (N, E, H, W) visual feature map
            lengths_l: (N,) text lengths predicted by the language branch
            v_attn: (N, T, H, W) position-attention maps from the vision branch
        """
        N, E, H, W = v_feature.size()
        v_feature = v_feature.flatten(2, 3).transpose(1, 2)  # (N, H*W, E)
        v_attn = v_attn.flatten(2, 3)  # (N, T, H*W)
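
        # During training, with probability mask_example_prob per sample, pick
        # a random character position, take the visual positions it attends to
        # most strongly, and replace a random subset of them with the learnable
        # v_token. The alignment module then has to recover the masked visual
        # evidence from the semantic features.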
        if self.training:
            for idx, length in enumerate(lengths_l):
                if np.random.random() <= self.mask_example_prob:
                    l_idx = np.random.randint(int(length))
                    v_random_idx = v_attn[idx, l_idx].argsort(
                        descending=True).cpu().numpy()[:self.num_vis_mask]
                    v_random_idx = v_random_idx[np.random.random(
                        v_random_idx.shape) <= self.mask_candidate_prob]
                    v_feature[idx, v_random_idx] = self.v_token
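
        # Attach spatial position information to the language features: take
        # sinusoidal positional encodings over the flattened H*W grid and
        # project them onto each character through its attention map, so each
        # character token carries its image location into the joint attention.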
        zeros = v_feature.new_zeros((N, H * W, E))  # (N, H*W, E)
        base_pos = self.pos_encoder_tfm(zeros)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos)  # (N, T, E)
        l_feature = l_feature + base_pos

        # Joint self-attention over the concatenated visual and semantic tokens.
        sv_feature = torch.cat((v_feature, l_feature), dim=1)  # (N, H*W+T, E)
        for decoder_layer in self.model1:
            sv_feature = decoder_layer(sv_feature)  # (N, H*W+T, E)

        sv_to_v_feature = sv_feature[:, :H * W]  # (N, H*W, E)
        sv_to_s_feature = sv_feature[:, H * W:]  # (N, T, E)

        sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = _get_length(sv_to_v_logits)  # (N,)

        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = _get_length(sv_to_s_logits)  # (N,)
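
        # Gated fusion (as in the ABINet alignment block): a sigmoid gate
        # decides, per position and channel, how much to trust the enhanced
        # visual features versus the enhanced semantic features.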
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = _get_length(logits)

        return {
            'logits': logits,
            'pt_lengths': pt_lengths,
            'v_logits': sv_to_v_logits,
            'pt_v_lengths': pt_v_lengths,
            's_logits': sv_to_s_logits,
            'pt_s_lengths': pt_s_lengths,
            'name': 'alignment'
        }
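

# MATRN decoder: an ABINet-style vision branch (transformer encoder +
# position attention) and a BCN language model, followed by the multi-modal
# semantic-visual alignment module above, applied iteratively `iter_size`
# times.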
class MATRNDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        self.max_length = max_length + 1
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)

        self.iter_size = iter_size
        if iter_size > 0:
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # alignment
            self.semantic_visual = BaseSemanticVisual_backbone_feature(
                d_model=d_model,
                nhead=nhead,
                num_layers=2,
                dim_feedforward=dim_feedforward,
                max_length=max_length,
                num_classes=self.out_channels)

    def forward(self, x, data=None):
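        """
        Args:
            x: (bs, c, h, w) backbone feature map; c must equal in_channels.

        Returns:
            In training mode, a dict with the per-iteration alignment logits
            ('align'), language logits ('lang') and vision-branch logits
            ('vision'); in eval mode, softmax probabilities over the final
            alignment logits of shape (bs, max_length + 1, out_channels).
        """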
        # bs, c, h, w
        x = x.permute([0, 2, 3, 1])  # bs, h, w, c
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # bs, h*w, c
        feature = self.pos_encoder(feature)  # bs, h*w, c
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)
        # bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # bs, c, h, w

        # Vision branch: position attention yields per-character features and
        # the attention maps that are reused later by the alignment module.
        v_feature, v_attn_input = self.decoder(feature)  # (N, T, E), (N, T, H, W)
        vis_logits = self.cls(v_feature)  # (N, T, C)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []
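
        # Iterative refinement: feed the current predictions to the BCN
        # language model, fuse its output with the visual features in the
        # semantic-visual alignment module, and use the fused logits as the
        # input to the next iteration.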
        for _ in range(self.iter_size):
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(
                align_lengths, 2,
                self.max_length)  # TODO: move to language model
            l_feature, l_logits = self.language(tokens, lengths)
            all_l_res.append(l_logits)
            # alignment
            lengths_l = _get_length(l_logits)
            lengths_l.clamp_(2, self.max_length)
            a_res = self.semantic_visual(l_feature,
                                         feature,
                                         lengths_l=lengths_l,
                                         v_attn=v_attn_input)
            a_v_res = a_res['v_logits']
            all_a_res.append(a_v_res)
            a_s_res = a_res['s_logits']
            align_logits = a_res['logits']
            all_a_res.append(a_s_res)
            all_a_res.append(align_logits)
            align_lengths = a_res['pt_lengths']

        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)
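

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file). It assumes a
    # typical ABINet-style backbone output of shape (N, 512, 8, 32) for a
    # 32x128 input image and a 37-class charset; adjust to your config.
    decoder = MATRNDecoder(in_channels=512, out_channels=37, iter_size=3)
    decoder.eval()
    dummy_feature = torch.randn(2, 512, 8, 32)
    with torch.no_grad():
        probs = decoder(dummy_feature)
    # Expected: (2, 26, 37) = batch, max_length + 1 (stop token), num_classes.
    print(probs.shape)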