visionlan_decoder.py

import torch
import torch.nn as nn

from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class Transformer_Encoder(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        n_position=256,
    ):
        super(Transformer_Encoder, self).__init__()
        self.pe = PositionalEncoding(dropout=dropout,
                                     dim=d_model,
                                     max_len=n_position)
        self.layer_stack = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner) for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, enc_output, src_mask):
        enc_output = self.pe(enc_output)  # position embedding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, self_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output

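# Usage sketch (illustrative, not part of the original file): the encoder runs
# on flattened visual features of shape (B, n_position, d_model), e.g.
#   enc = Transformer_Encoder(n_layers=3, n_head=8, d_model=512, n_position=256)
#   out = enc(torch.randn(2, 256, 512), src_mask=None)  # -> (2, 256, 512)
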
class PP_layer(nn.Module):

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        reading_order = torch.arange(self.character_len,
                                     dtype=torch.long,
                                     device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(
            enc_output.shape[0], -1)  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512
        # calculate attention
        t = self.w0(reading_order.transpose(1, 2))  # b,512,256
        t = self.active(t.transpose(1, 2) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(t.transpose(1, 2))  # b,25,256
        g_output = torch.bmm(t, enc_output)  # b,25,512
        return g_output

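# Usage sketch (illustrative): PP_layer pools the length-n_position visual
# sequence into one vector per character slot via learned reading-order queries:
#   pp = PP_layer(n_dim=512, N_max_character=25, n_position=256)
#   g = pp(torch.randn(2, 256, 512))  # -> (2, 25, 512)
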
class Prediction(nn.Module):

    def __init__(
        self,
        n_dim=512,
        n_class=37,
        N_max_character=25,
        n_position=256,
    ):
        super(Prediction, self).__init__()
        self.pp = PP_layer(n_dim=n_dim,
                           N_max_character=N_max_character,
                           n_position=n_position)
        self.pp_share = PP_layer(n_dim=n_dim,
                                 N_max_character=N_max_character,
                                 n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, is_Train=False, use_mlm=True):
        if is_Train:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output

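# Usage sketch (illustrative): at inference only the VRM branch is evaluated
# and a single logits tensor comes back:
#   pred = Prediction(n_dim=512, n_class=37, N_max_character=26, n_position=256)
#   logits = pred(torch.randn(2, 256, 512), f_res=0, f_sub=0,
#                 is_Train=False, use_mlm=False)  # -> (2, 26, 37)
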
class MLM(nn.Module):
    """Architecture of MLM."""

    def __init__(
        self,
        n_dim=512,
        n_position=256,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
    ):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(input, src_mask=None)
        # position embedding layer
        pos_emb = self.pos_embedding(label_pos.long())
        pos_emb = self.w0_linear(torch.unsqueeze(pos_emb,
                                                 dim=2)).transpose(1, 2)
        # fuse position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = self.sigmoid(att_map_sub.transpose(1, 2))  # b,1,256
        # WCL
        # generate inputs for WCL
        f_res = input * (1 - att_map_sub.transpose(1, 2)
                         )  # second path with remaining string
        f_sub = input * (att_map_sub.transpose(1, 2)
                         )  # first path with occluded character
        # transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub

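# Usage sketch (illustrative): label_pos gives, per sample, the index of the
# character to occlude; MLM returns features for the remaining string (f_res),
# for the occluded character (f_sub), and the soft character mask:
#   mlm = MLM(n_dim=512, n_position=256, max_text_length=25)
#   f_res, f_sub, mask_c = mlm(torch.randn(2, 256, 512), torch.tensor([3, 7]))
#   # f_res, f_sub: (2, 256, 512); mask_c: (2, 1, 256)
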
class MLM_VRM(nn.Module):

    def __init__(
        self,
        n_layers=3,
        n_position=256,
        n_dim=512,
        n_head=8,
        dim_feedforward=2048,
        max_text_length=25,
        nclass=37,
    ):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(
            n_dim=n_dim,
            n_position=n_position,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
        )
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers,
            n_head=n_head,
            d_model=n_dim,
            d_inner=dim_feedforward,
            n_position=n_position,
        )
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,
            n_class=nclass,
        )  # N_max_character = 1 eos + 25 characters
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, input, label_pos, training_step, is_Train=False):
        nT = self.max_text_length
        b, c, h, w = input.shape
        input = input.reshape(b, c, -1)
        input = input.transpose(1, 2)
        if is_Train:
            if training_step == 'LF_1':
                # language-free step 1: VRM only, MLM disabled; only text_pre
                # is supervised, so the same tensor is returned three times
                f_res = 0
                f_sub = 0
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True,
                                                               use_mlm=False)
                return text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                input = self.SequenceModeling(input, src_mask=None)
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(input, label_pos)
                # use mask_c (1 for the occluded character, 0 for the
                # remaining characters) to occlude the input;
                # ratio controls how many samples in a batch are occluded
                ratio = 2
                character_mask = torch.zeros_like(mask_c)
                character_mask[0:b // ratio, :, :] = mask_c[0:b // ratio, :, :]
                input = input * (1 - character_mask.transpose(1, 2))
                # VRM
                # transformer unit for VRM
                input = self.SequenceModeling(input, src_mask=None)
                # prediction layer for MLM and VSR
                text_pre, text_rem, text_mas = self.Prediction(input,
                                                               f_res,
                                                               f_sub,
                                                               is_Train=True)
                return text_pre, text_rem, text_mas
        else:  # only VRM is used at test time
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(input, src_mask=None)
            C = self.Prediction(contextual_feature,
                                f_res,
                                f_sub,
                                is_Train=False,
                                use_mlm=False)
            C = C.transpose(1, 0)  # (nT + 1, b, nclass)
            # greedy decoding: copy per-step logits until every sample has
            # emitted its first eos (class 0) or nT steps have been read
            out_res = torch.zeros(nT, b, self.nclass).type_as(input.data)
            out_length = torch.zeros(b).type_as(input.data)
            now_step = 0
            while 0 in out_length and now_step < nT:
                tmp_result = C[now_step, :, :]
                out_res[now_step] = tmp_result
                tmp_result = tmp_result.topk(1)[1].squeeze(dim=1)
                for j in range(b):
                    if out_length[j] == 0 and tmp_result[j] == 0:
                        out_length[j] = now_step + 1
                now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nT
            # pack the variable-length per-sample logits into one flat
            # (sum(out_length), nclass) tensor
            start = 0
            output = torch.zeros(int(out_length.sum()),
                                 self.nclass).type_as(input.data)
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length
            return output, out_length

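# Usage sketch (illustrative): in the 'LA' step the first b // ratio samples of
# the batch are occluded with the MLM mask before visual reasoning. Shapes
# assume an 8x32 feature map so that h * w == n_position == 256:
#   head = MLM_VRM(nclass=38)
#   x = torch.randn(4, 512, 8, 32)
#   text_pre, text_rem, text_mas = head(x, torch.tensor([1, 2, 3, 4]), 'LA',
#                                       is_Train=True)
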
class VisionLANDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        n_head=None,
        training_step='LA',
        n_layers=3,
        n_position=256,
        max_text_length=25,
    ):
        super(VisionLANDecoder, self).__init__()
        self.training_step = training_step
        n_dim = in_channels
        dim_feedforward = n_dim * 4
        n_head = n_head if n_head is not None else n_dim // 32
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            max_text_length=max_text_length,
            nclass=out_channels + 1,
        )

    def forward(self, x, data=None):
        # MLM + VRM
        if self.training:
            label_pos = data[-2]
            text_pre, text_rem, text_mas = self.MLM_VRM(x,
                                                        label_pos,
                                                        self.training_step,
                                                        is_Train=True)
            return text_pre, text_rem, text_mas
        else:
            output, out_length = self.MLM_VRM(x,
                                              None,
                                              self.training_step,
                                              is_Train=False)
            return output, out_length

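if __name__ == '__main__':
    # Minimal inference smoke test (illustrative, not from the original file).
    # The 8x32 spatial size is an assumption chosen so h * w matches the
    # default n_position of 256; out_channels=37 is an assumed charset size
    # (the decoder adds one class internally for eos).
    decoder = VisionLANDecoder(in_channels=512, out_channels=37)
    decoder.eval()  # selects the test-time (VRM-only) path in forward
    feats = torch.randn(2, 512, 8, 32)
    with torch.no_grad():
        logits, lengths = decoder(feats)
    print(logits.shape, lengths)  # (sum(lengths), 38) and per-sample lengths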