import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .abinet_decoder import PositionAttention
from .nrtr_decoder import PositionalEncoding, TransformerBlock


class Trans(nn.Module):
    """Stack of self-attention Transformer blocks used to refine the visual
    feature map between two position-attention stages."""

    def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
        super().__init__()
        self.d_model = dim
        self.nhead = nhead
        self.pos_encoder = PositionalEncoding(dropout=0.0,
                                              dim=self.d_model,
                                              max_len=512)
        self.transformer = nn.ModuleList([
            TransformerBlock(
                dim,
                nhead,
                dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])

    def forward(self, feature, attn_map=None, use_mask=False):
        n, c, h, w = feature.shape
        feature = feature.flatten(2).transpose(1, 2)  # (n, h*w, c)
        if use_mask:
            # attn_map: (n, t, h, w) attention maps from the previous
            # position-attention stage; its spatial size is assumed to match
            # the feature map.
            _, t, h, w = attn_map.shape
            location_mask = (attn_map.view(n, t, -1).transpose(1, 2) >
                             0.05).type(torch.float)  # (n, h*w, t)
            # Block self-attention between spatial positions whose thresholded
            # attention maps overlap on the same character slot.
            location_mask = location_mask.bmm(location_mask.transpose(
                1, 2))  # (n, h*w, h*w)
            location_mask = location_mask.new_zeros(
                (h * w, h * w)).masked_fill(location_mask > 0, float('-inf'))
            location_mask = location_mask.unsqueeze(1)  # (n, 1, h*w, h*w)
        else:
            location_mask = None
        feature = self.pos_encoder(feature)
        for layer in self.transformer:
            feature = layer(feature, self_mask=location_mask)
        feature = feature.transpose(1, 2).view(n, c, h, w)
        return feature, location_mask
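
# Illustrative shape sketch (not part of the original code; sizes are assumed):
# for feature of shape (2, 256, 8, 32) and attn_map of shape (2, 26, 8, 32),
# Trans(dim=256, nhead=8, dim_feedforward=1024, dropout=0.1, num_layers=2)
# returns a refined feature of shape (2, 256, 8, 32) and, when use_mask=True,
# a location mask of shape (2, 1, 256, 256), where 256 = 8 * 32.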


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LPVDecoder(nn.Module):
    """LPV decoder: cascaded position-attention stages interleaved with
    Transformer-based feature refinement, with one classifier per stage."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=3,
                 max_len=25,
                 use_mask=False,
                 dim_feedforward=1024,
                 nhead=8,
                 dropout=0.1,
                 trans_layer=2):
        super().__init__()
        self.use_mask = use_mask
        self.max_len = max_len
        attn_layer = PositionAttention(max_length=max_len + 1,
                                       mode='nearest',
                                       in_channels=in_channels,
                                       num_channels=in_channels // 8)
        trans_layer = Trans(dim=in_channels,
                            nhead=nhead,
                            dim_feedforward=dim_feedforward,
                            dropout=dropout,
                            num_layers=trans_layer)
        cls_layer = nn.Linear(in_channels, out_channels - 2)
        self.attention = _get_clones(attn_layer, num_layer)
        self.trans = _get_clones(trans_layer, num_layer - 1)
        self.cls = _get_clones(cls_layer, num_layer)
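
    # Stage layout (descriptive note): attention[0] and cls[0] operate directly
    # on the backbone features; each later stage i first refines the features
    # with trans[i - 1] (optionally masked by the previous attention maps) and
    # then re-runs position attention and classification:
    #   features -> trans[i - 1] -> attention[i] -> cls[i]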

    def forward(self, x, data=None):
        if data is not None:
            # data[1] holds the target lengths; decode only as far as the
            # longest label in the batch during training.
            max_len = data[1].max()
        else:
            max_len = self.max_len
        features = x  # (N, E, H, W)
        attn_vecs, attn_scores_map = self.attention[0](features)
        attn_vecs = attn_vecs[:, :max_len + 1, :]
        if not self.training:
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
            # Inference: return only the last stage's prediction as probabilities.
            return F.softmax(self.cls[-1](attn_vecs), -1)
        else:
            logits = []
            logit = self.cls[0](attn_vecs)  # (N, T, C)
            logits.append(logit)
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
                logit = self.cls[i](attn_vecs)  # (N, T, C)
                logits.append(logit)
            # Training: return per-stage logits so every stage is supervised.
            return logits
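

# Minimal smoke-test sketch (not part of the original module). It assumes the
# file sits inside its package so the relative imports of PositionAttention,
# PositionalEncoding and TransformerBlock resolve (e.g. run it with
# `python -m <package>.lpv_decoder`); the hyper-parameters, charset size and
# tensor sizes below are illustrative only.
if __name__ == '__main__':
    decoder = LPVDecoder(in_channels=256, out_channels=39)
    feats = torch.randn(2, 256, 8, 32)  # (N, E, H, W) backbone feature map

    # Inference path: a single softmax output from the last stage.
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)  # (2, max_len + 1, out_channels - 2)
    print(probs.shape)

    # Training path: data[1] is assumed to hold the per-sample target lengths;
    # one logit tensor per decoding stage is returned.
    decoder.train()
    lengths = torch.tensor([5, 12])
    logits = decoder(feats, data=[None, lengths])
    print([logit.shape for logit in logits])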