mgp_decoder.py

'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

class TokenLearner(nn.Module):
    """Aggregate an input token sequence into ``out_token`` learned tokens.

    A 1x1-conv stack predicts one spatial attention map per output token;
    the maps are softmax-normalized and used to pool the input features.
    """

    def __init__(self, input_embed_dim, out_token=30):
        super().__init__()
        self.token_norm = nn.LayerNorm(input_embed_dim)
        self.tokenLearner = nn.Sequential(
            nn.Conv2d(input_embed_dim,
                      input_embed_dim,
                      kernel_size=(1, 1),
                      stride=1,
                      groups=8,
                      bias=False),
            nn.Conv2d(input_embed_dim,
                      out_token,
                      kernel_size=(1, 1),
                      stride=1,
                      bias=False))
        self.feat = nn.Conv2d(input_embed_dim,
                              input_embed_dim,
                              kernel_size=(1, 1),
                              stride=1,
                              groups=8,
                              bias=False)
        self.norm = nn.LayerNorm(input_embed_dim)

    def forward(self, x):
        x = self.token_norm(x)  # [bs, 257, 768]
        x = x.transpose(1, 2).unsqueeze(-1)  # [bs, 768, 257, 1]
        selected = self.tokenLearner(x)  # [bs, 27, 257, 1]
        selected = selected.flatten(2)  # [bs, 27, 257]
        # One attention distribution over the 257 input tokens per output token.
        selected = F.softmax(selected, dim=-1)
        feat = self.feat(x)  # [bs, 768, 257, 1]
        feat = feat.flatten(2).transpose(1, 2)  # [bs, 257, 768]
        # Attention-weighted pooling of the input features.
        x = torch.einsum('...si,...id->...sd', selected, feat)  # [bs, 27, 768]
        x = self.norm(x)
        return selected, x
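

def _token_learner_shape_check():
    # Illustrative smoke test (added here, not part of the original repo):
    # verifies the tensor shapes annotated in TokenLearner.forward, assuming
    # ViT-style input of 257 tokens with dim 768 and 27 learned output tokens.
    learner = TokenLearner(input_embed_dim=768, out_token=27)
    attn, tokens = learner(torch.randn(2, 257, 768))
    assert attn.shape == (2, 27, 257)    # one spatial map per output token
    assert tokens.shape == (2, 27, 768)  # pooled feature per output token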

class MGPDecoder(nn.Module):
    """Multi-granularity prediction decoder (MGP-STR): predicts the text at
    character, BPE, and WordPiece granularity from shared visual features."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 only_char=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        num_classes = out_channels
        embed_dim = in_channels
        # Two extra output positions for the begin/end special tokens.
        self.batch_max_length = max_len + 2
        self.char_tokenLearner = TokenLearner(embed_dim, self.batch_max_length)
        self.char_head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.only_char = only_char
        if not only_char:
            self.bpe_tokenLearner = TokenLearner(embed_dim,
                                                 self.batch_max_length)
            self.wp_tokenLearner = TokenLearner(embed_dim,
                                                self.batch_max_length)
            # 50257 = GPT-2 BPE vocabulary size; 30522 = BERT WordPiece
            # vocabulary size.
            self.bpe_head = nn.Linear(
                embed_dim, 50257) if num_classes > 0 else nn.Identity()
            self.wp_head = nn.Linear(
                embed_dim, 30522) if num_classes > 0 else nn.Identity()
    def forward(self, x, data=None):
        # char branch
        char_attn, x_char = self.char_tokenLearner(x)
        char_out = self.char_head(x_char)
        if not self.only_char:
            # bpe branch
            bpe_attn, x_bpe = self.bpe_tokenLearner(x)
            bpe_out = self.bpe_head(x_bpe)
            # wp branch
            wp_attn, x_wp = self.wp_tokenLearner(x)
            wp_out = self.wp_head(x_wp)
            # Training returns raw logits; inference returns probabilities.
            return [char_out, bpe_out, wp_out] if self.training else [
                F.softmax(char_out, -1),
                F.softmax(bpe_out, -1),
                F.softmax(wp_out, -1)
            ]
        return char_out if self.training else F.softmax(char_out, -1)
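

if __name__ == '__main__':
    # Illustrative usage sketch (not from the original repo). Assumes ViT
    # features of shape [batch, 257, 768] and a 38-class character set;
    # max_len=25 gives 27 output positions per branch.
    decoder = MGPDecoder(in_channels=768, out_channels=38, max_len=25)
    feats = torch.randn(2, 257, 768)

    decoder.train()
    char_logits, bpe_logits, wp_logits = decoder(feats)
    print(char_logits.shape)  # torch.Size([2, 27, 38])
    print(bpe_logits.shape)   # torch.Size([2, 27, 50257])
    print(wp_logits.shape)    # torch.Size([2, 27, 30522])

    decoder.eval()
    with torch.no_grad():
        char_probs, bpe_probs, wp_probs = decoder(feats)  # softmax outputs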