import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RobustScannerDecoder(nn.Module):

    def __init__(self,
                 out_channels,  # 90 + unknown + start + padding
                 in_channels,
                 enc_outchannles=128,
                 hybrid_dec_rnn_layers=2,
                 hybrid_dec_dropout=0,
                 position_dec_rnn_layers=2,
                 max_len=25,
                 mask=True,
                 encode_value=False,
                 **kwargs):
        super(RobustScannerDecoder, self).__init__()
        start_idx = out_channels - 2
        padding_idx = out_channels - 1
        end_idx = 0

        # encoder module
        self.encoder = ChannelReductionEncoder(in_channels=in_channels,
                                               out_channels=enc_outchannles)
        self.max_text_length = max_len + 1
        self.mask = mask

        # decoder module
        self.decoder = Decoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_outchannles,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_len=max_len + 1,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            end_idx=end_idx,
            encode_value=encode_value)

    def forward(self, inputs, data=None):
        """
        data: [label, length, ..., valid_ratio]
        """
        out_enc = self.encoder(inputs)
        bs = out_enc.shape[0]
        valid_ratios = None
        word_positions = torch.arange(
            0, self.max_text_length,
            device=inputs.device).unsqueeze(0).tile([bs, 1])

        if self.mask:
            valid_ratios = data[-1]

        if self.training:
            max_len = data[1].max()
            label = data[0][:, :1 + max_len]  # label
            final_out = self.decoder(inputs, out_enc, label, valid_ratios,
                                     word_positions[:, :1 + max_len])
        else:
            final_out = self.decoder(inputs,
                                     out_enc,
                                     label=None,
                                     valid_ratios=valid_ratios,
                                     word_positions=word_positions,
                                     train_mode=False)
        return final_out


class BaseDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                valid_ratios=None,
                word_positions=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, valid_ratios,
                                      word_positions)
        return self.forward_test(feat, out_enc, valid_ratios, word_positions)


class ChannelReductionEncoder(nn.Module):
    """Change the channel number with a 1x1 convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ChannelReductionEncoder, self).__init__()
        self.layer = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        # Xavier-normal initialization of the 1x1 convolution weight.
        nn.init.xavier_normal_(self.layer.weight, gain=1.0)

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        return self.layer(feat)
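

def _example_channel_reduction():
    """Hedged usage sketch (not part of the original module): reduce a
    512-channel backbone feature map to 128 channels with the 1x1 convolution
    encoder above. All shapes are illustrative."""
    encoder = ChannelReductionEncoder(in_channels=512, out_channels=128)
    feat = torch.randn(2, 512, 8, 32)  # (N, C_in, H, W)
    out = encoder(feat)                # (N, C_out, H, W) == (2, 128, 8, 32)
    return out.shape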


def masked_fill(x, mask, value):
    # `dtype` (and `device`) must be passed as keyword arguments to torch.full.
    y = torch.full(x.shape, value, dtype=x.dtype, device=x.device)
    return torch.where(mask, y, x)


class DotProductAttentionLayer(nn.Module):

    def __init__(self, dim_model=None):
        super().__init__()
        self.scale = dim_model**-0.5 if dim_model is not None else 1.

    def forward(self, query, key, value, mask=None):
        query = query.permute(0, 2, 1)
        logits = query @ key * self.scale

        if mask is not None:
            n, seq_len = mask.size()
            mask = mask.view(n, 1, seq_len)
            logits = logits.masked_fill(mask, float('-inf'))

        weights = F.softmax(logits, dim=2)
        value = value.transpose(1, 2)
        glimpse = weights @ value
        glimpse = glimpse.permute(0, 2, 1).contiguous()
        return glimpse
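

def _example_dot_product_attention():
    """Hedged usage sketch (not part of the original module): query, key and
    value follow the (N, C, L) layout used throughout this file; `mask` marks
    padded key positions with True. All shapes are illustrative."""
    attn = DotProductAttentionLayer(dim_model=128)
    n, c, len_q, len_kv = 2, 128, 5, 64
    query = torch.randn(n, c, len_q)
    key = torch.randn(n, c, len_kv)
    value = torch.randn(n, c, len_kv)
    mask = torch.zeros(n, len_kv, dtype=torch.bool)
    mask[:, 48:] = True                      # e.g. last 16 key positions are padding
    glimpse = attn(query, key, value, mask)  # (N, C, L_q) == (2, 128, 5)
    return glimpse.shape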


class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        self.embedding = nn.Embedding(self.num_classes,
                                      self.dim_model,
                                      padding_idx=padding_idx)
        self.sequence_layer = nn.LSTM(input_size=dim_model,
                                      hidden_size=dim_model,
                                      num_layers=rnn_layers,
                                      batch_first=True,
                                      dropout=dropout)
        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of img.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        tgt_embedding = self.embedding(targets)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)
        query = query.permute(0, 2, 1).contiguous()
        key = out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)
        return out

    def forward_test(self, feat, out_enc, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of img.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        batch_size = feat.shape[0]
        decode_sequence = (torch.ones((batch_size, self.max_seq_len),
                                      dtype=torch.int64,
                                      device=feat.device) * self.start_idx)

        outputs = []
        for i in range(self.max_seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            max_idx = torch.argmax(step_out, dim=1, keepdim=False)
            if i < self.max_seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = torch.stack(outputs, 1)
        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): Valid length ratio of img.

        Returns:
            Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
            tokens at current time step.
        """
        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model

        query, _ = self.sequence_layer(embed)
        query = query.transpose(1, 2)
        key = torch.reshape(out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = key
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, mask)
        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        out = F.softmax(out, dim=-1)
        return out
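

def _example_sequence_attention_decoder():
    """Hedged usage sketch (not part of the original module): a teacher-forced
    training pass with toy shapes. The charset size, start/padding indices and
    sequence length are illustrative only."""
    decoder = SequenceAttentionDecoder(num_classes=93,
                                       dim_input=512,
                                       dim_model=128,
                                       max_seq_len=26,
                                       start_idx=91,
                                       padding_idx=92)
    feat = torch.randn(2, 512, 8, 32)      # (N, D_i, H, W)
    out_enc = torch.randn(2, 128, 8, 32)   # (N, D_m, H, W)
    targets = torch.randint(0, 91, (2, 26))
    logits = decoder.forward_train(feat, out_enc, targets, valid_ratios=None)
    return logits.shape                    # (N, T, C - 1) == (2, 26, 92)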


class PositionAwareLayer(nn.Module):

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()
        self.dim_model = dim_model
        self.rnn = nn.LSTM(input_size=dim_model,
                           hidden_size=dim_model,
                           num_layers=rnn_layers,
                           batch_first=True)
        self.mixer = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1), nn.ReLU(True),
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1))

    def forward(self, img_feature):
        n, c, h, w = img_feature.shape
        rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
        rnn_input = rnn_input.view(n * h, w, c)
        rnn_output, _ = self.rnn(rnn_input)
        rnn_output = rnn_output.view(n, h, w, c)
        rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
        out = self.mixer(rnn_output)
        return out
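

def _example_position_aware_layer():
    """Hedged usage sketch (not part of the original module): the layer keeps
    the (N, C, H, W) shape while mixing information along the width dimension
    with a row-wise LSTM followed by two 3x3 convolutions. Shapes are
    illustrative."""
    layer = PositionAwareLayer(dim_model=128, rnn_layers=2)
    feature_map = torch.randn(2, 128, 8, 32)
    out = layer(feature_map)  # same shape as the input: (2, 128, 8, 32)
    return out.shape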


class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask

        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
        self.position_aware_module = PositionAwareLayer(
            self.dim_model, rnn_layers)
        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        position_index_list = []
        for i in range(batch_size):
            # torch.arange (not the deprecated torch.range) with a proper
            # torch dtype; yields positions 0 .. length - 1.
            position_index = torch.arange(0, length, dtype=torch.int64)
            position_index_list.append(position_index)
        batch_position_index = torch.stack(position_index_list, dim=0)
        return batch_position_index

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = out_enc.view(n, c_enc, h * w)
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = torch.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.transpose(1, 2)  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)
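

def _example_position_attention_decoder():
    """Hedged usage sketch (not part of the original module): position queries
    do not depend on previously decoded characters, so a test-time pass only
    needs the position indices. Shapes and the charset size are illustrative."""
    decoder = PositionAttentionDecoder(num_classes=93,
                                       dim_input=512,
                                       dim_model=128,
                                       max_seq_len=26)
    feat = torch.randn(2, 512, 8, 32)     # (N, D_i, H, W)
    out_enc = torch.randn(2, 128, 8, 32)  # (N, D_m, H, W)
    positions = torch.arange(26).unsqueeze(0).tile((2, 1))
    logits = decoder.forward_test(feat, out_enc, None, positions)
    return logits.shape                   # (N, T, C - 1) == (2, 26, 92)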


class RobustScannerFusionLayer(nn.Module):

    def __init__(self, dim_model, dim=-1):
        super(RobustScannerFusionLayer, self).__init__()
        self.dim_model = dim_model
        self.dim = dim
        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)

    def forward(self, x0, x1):
        assert x0.shape == x1.shape
        fusion_input = torch.concat((x0, x1), self.dim)
        output = self.linear_layer(fusion_input)
        output = F.glu(output, self.dim)
        return output
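

def _example_fusion_layer():
    """Hedged usage sketch (not part of the original module): GLU halves the
    doubled channel dimension, so the fused output has the same size as each
    input glimpse. Shapes are illustrative."""
    fusion = RobustScannerFusionLayer(dim_model=512)
    hybrid_glimpse = torch.randn(2, 25, 512)
    position_glimpse = torch.randn(2, 25, 512)
    fused = fusion(hybrid_glimpse, position_glimpse)  # (2, 25, 512)
    return fused.shape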


class Decoder(BaseDecoder):
    """Decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        The hybrid and position sub-decoders are used with
        ``return_feature=True`` here, so only the fused prediction layer of
        this decoder produces logits, and its output size is :math:`C`.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 end_idx=0,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.end_idx = end_idx
        self.mask = mask

        # init hybrid decoder
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)

        # init position decoder
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)

        self.fusion_module = RobustScannerFusionLayer(
            self.dim_model if encode_value else dim_input)

        pred_num_classes = num_classes
        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
                                    pred_num_classes)

    def forward_train(self, feat, out_enc, target, valid_ratios,
                      word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): Valid length ratio of img.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
        """
        hybrid_glimpse = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        position_glimpse = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)

        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)

        out = self.prediction(fusion_out)
        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): Valid length ratio of img.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: The output sequence tensor of shape :math:`(N, T, C)`
            after the softmax over classes.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        decode_sequence = (torch.ones(
            (batch_size, seq_len), dtype=torch.int64, device=feat.device) *
                           self.start_idx)

        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)

        outputs = []
        for i in range(seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, i, valid_ratios)

            fusion_out = self.fusion_module(hybrid_glimpse_step,
                                            position_glimpse[:, i, :])

            char_out = self.prediction(fusion_out)
            char_out = F.softmax(char_out, -1)
            outputs.append(char_out)
            max_idx = torch.argmax(char_out, dim=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx
                if (decode_sequence == self.end_idx).any(dim=-1).all():
                    break

        outputs = torch.stack(outputs, 1)
        return outputs
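

if __name__ == '__main__':
    # Hedged smoke test (not part of the original module): run a greedy
    # inference pass of the full head with illustrative shapes. out_channels=93
    # assumes 90 characters + unknown + start + padding, mirroring the comment
    # in RobustScannerDecoder.__init__; masking is disabled so no `data` is
    # needed at test time.
    model = RobustScannerDecoder(out_channels=93, in_channels=512, mask=False)
    model.eval()
    backbone_feat = torch.randn(2, 512, 8, 32)  # (N, D_i, H, W)
    with torch.no_grad():
        preds = model(backbone_feat)
    print(preds.shape)  # (N, T, 93) with T <= max_len + 1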