import torch
import torch.nn as nn

from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
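
# Building blocks of the VisionLAN decoder: a plain Transformer encoder, a
# parallel per-character prediction head (PP_layer / Prediction), a masked
# language-aware module (MLM) that learns a character-wise occlusion mask,
# and MLM_VRM, which combines both branches during training and decodes
# greedily at test time.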
class Transformer_Encoder(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        n_position=256,
    ):
        super(Transformer_Encoder, self).__init__()
        self.pe = PositionalEncoding(dropout=dropout,
                                     dim=d_model,
                                     max_len=n_position)
        self.layer_stack = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner) for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, enc_output, src_mask):
        enc_output = self.pe(enc_output)  # position embedding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, self_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output
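
# PP_layer gathers one feature vector per character position: reading-order
# indices are embedded, fused with the visual sequence, and turned into an
# attention map over the n_position feature locations (shape comments below
# assume the defaults n_dim=512, N_max_character=25, n_position=256).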
class PP_layer(nn.Module):

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        reading_order = torch.arange(self.character_len,
                                     dtype=torch.long,
                                     device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(
            enc_output.shape[0], -1)  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512
        # calculate attention
        t = self.w0(reading_order.transpose(1, 2))  # b,512,256
        t = self.active(t.transpose(1, 2) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(t.transpose(1, 2))  # b,25,256
        g_output = torch.bmm(t, enc_output)  # b,25,512
        return g_output
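
# Prediction holds two parallel heads: `pp`/`w_vrm` decode the visual
# reasoning branch, while `pp_share`/`w_share` decode the two MLM outputs
# (remaining string and occluded character) during training.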
class Prediction(nn.Module):

    def __init__(
        self,
        n_dim=512,
        n_class=37,
        N_max_character=25,
        n_position=256,
    ):
        super(Prediction, self).__init__()
        self.pp = PP_layer(n_dim=n_dim,
                           N_max_character=N_max_character,
                           n_position=n_position)
        self.pp_share = PP_layer(n_dim=n_dim,
                                 N_max_character=N_max_character,
                                 n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, is_Train=False, use_mlm=True):
        if is_Train:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output
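
# MLM predicts a character-wise occlusion mask (att_map_sub) from the visual
# sequence and a character position label, then encodes the two masked views:
# f_res (the remaining string) and f_sub (the occluded character).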
class MLM(nn.Module):
    """Architecture of MLM."""

    def __init__(
        self,
        n_dim=512,
        n_position=256,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
    ):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)
        # position embedding layer
        pos_emb = self.pos_embedding(label_pos.long())
        pos_emb = self.w0_linear(torch.unsqueeze(pos_emb,
                                                 dim=2)).transpose(1, 2)
        # fuse the position embedding with the visual features V and generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = self.sigmoid(att_map_sub.transpose(1, 2))  # b,1,256
        # generate the two inputs for WCL
        f_res = input * (1 - att_map_sub.transpose(1, 2))  # remaining string (second path)
        f_sub = input * att_map_sub.transpose(1, 2)  # occluded character (first path)
        # transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub
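
# MLM_VRM ties the MLM branch and the visual reasoning branch together.
# Training behaviour depends on `training_step` ('LF_1', 'LF_2' or 'LA');
# at test time only the visual branch runs and the result is decoded greedily
# until the EOS class (index 0).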
class MLM_VRM(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_position=256,
        n_dim=512,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
        nclass=37,
    ):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(
            n_dim=n_dim,
            n_position=n_position,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
        )
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,
            n_class=nclass,
        )  # N_max_character = 1 eos + 25 characters
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, input, label_pos, training_step, is_Train=False):
        nT = self.max_text_length
        b, c, h, w = input.shape
        input = input.reshape(b, c, -1)
        input = input.transpose(1, 2)
        if is_Train:
            if training_step == 'LF_1':
                # language-free step: visual branch only, no MLM
                f_res = 0
                f_sub = 0
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True,
                                                               use_mlm=False)
                return text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                # use mask_c (1 for the occluded character, 0 for the remaining
                # characters) to occlude the input; ratio controls how many
                # samples in a batch are occluded
                ratio = 2
                character_mask = torch.zeros_like(mask_c)
                character_mask[0:b // ratio, :, :] = mask_c[0:b // ratio, :, :]
                input = input * (1 - character_mask.transpose(1, 2))
                # transformer unit for VRM
                input = self.SequenceModeling(input, src_mask=None)
                # prediction layer for MLM and VRM
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
        else:  # only the VRM branch is used at test time
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(input, src_mask=None)
            C = self.Prediction(contextual_feature,
                                f_res,
                                f_sub,
                                is_Train=False,
                                use_mlm=False)
            C = C.transpose(1, 0)  # (N_max_character, b, nclass)
            # greedy decoding: class 0 acts as EOS and fixes each sample's length
            out_res = torch.zeros(nT, b, self.nclass).type_as(input.data)
            out_length = torch.zeros(b).type_as(input.data)
            now_step = 0
            while 0 in out_length and now_step < nT:
                tmp_result = C[now_step, :, :]
                out_res[now_step] = tmp_result
                tmp_result = tmp_result.topk(1)[1].squeeze(dim=1)
                for j in range(b):
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nT
            # concatenate the per-sample predictions into one flat tensor
            start = 0
            output = torch.zeros(int(out_length.sum()),
                                 self.nclass).type_as(input.data)
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length
            return output, out_length
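
# VisionLANDecoder is a thin wrapper that derives the transformer width and
# head count from `in_channels` and dispatches to MLM_VRM. During training,
# `data[-2]` is expected to carry the character position label (label_pos).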
class VisionLANDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        n_head=None,
        training_step='LA',
        n_layers=3,
        n_position=256,
        max_text_length=25,
    ):
        super(VisionLANDecoder, self).__init__()
        self.training_step = training_step
        n_dim = in_channels
        dim_feedforward = n_dim * 4
        n_head = n_head if n_head is not None else n_dim // 32
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
            nclass=out_channels + 1,
        )

    def forward(self, x, data=None):
        # MLM + VRM
        if self.training:
            label_pos = data[-2]
            text_pre, text_rem, text_mas = self.MLM_VRM(x,
                                                        label_pos,
                                                        self.training_step,
                                                        is_Train=True)
            return text_pre, text_rem, text_mas
        else:
            output, out_length = self.MLM_VRM(x,
                                              None,
                                              self.training_step,
                                              is_Train=False)
            return output, out_length
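
# A minimal inference sketch (not part of the original module). It assumes the
# `openrec` package is importable and that the backbone feature map flattens to
# at most n_position (=256) locations, e.g. an 8x32 map; the channel and class
# counts below are illustrative only.
if __name__ == '__main__':
    decoder = VisionLANDecoder(in_channels=512, out_channels=37)
    decoder.eval()  # eval mode -> MLM_VRM runs the greedy test-time path
    x = torch.randn(2, 512, 8, 32)  # (batch, channels, height, width), h*w = 256
    with torch.no_grad():
        output, out_length = decoder(x)
    # `output` stacks the per-step class scores of all samples back to back;
    # `out_length` gives each sample's decoded length (EOS position, capped at
    # max_text_length=25).
    print(output.shape, out_length)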
|