# ctc_decoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.encoders.svtrnet import (
    Block,
    ConvBNLayer,
    kaiming_normal_,
    trunc_normal_,
    zeros_,
    ones_,
)


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x) (equivalent to nn.SiLU)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class EncoderWithSVTR(nn.Module):
    """SVTR-style mixer: a conv bottleneck feeds `depth` global-attention
    Blocks, whose output is fused with the input via a shortcut concat."""

    def __init__(
        self,
        in_channels,
        dims=64,  # XS
        depth=2,
        hidden_dims=120,
        use_guide=False,
        num_heads=8,
        qkv_bias=True,
        mlp_ratio=2.0,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        drop_path=0.0,
        kernel_size=[3, 3],
        qk_scale=None,
        use_pool=True,
    ):
        super().__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.use_pool = use_pool
        self.conv1 = ConvBNLayer(
            in_channels,
            in_channels // 8,
            kernel_size=kernel_size,
            padding=[kernel_size[0] // 2, kernel_size[1] // 2],
            act=Swish,
            bias=False)
        self.conv2 = ConvBNLayer(in_channels // 8,
                                 hidden_dims,
                                 kernel_size=1,
                                 act=Swish,
                                 bias=False)
        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=Swish,
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                eps=1e-05,
                prenorm=False,
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(hidden_dims,
                                 in_channels,
                                 kernel_size=1,
                                 act=Swish,
                                 bias=False)
        # last conv-nxn; its input is the concat of the input tensor and the
        # conv3 output tensor
        self.conv4 = ConvBNLayer(
            2 * in_channels,
            in_channels // 8,
            kernel_size=kernel_size,
            padding=[kernel_size[0] // 2, kernel_size[1] // 2],
            act=Swish,
            bias=False)
        self.conv1x1 = ConvBNLayer(in_channels // 8,
                                   dims,
                                   kernel_size=1,
                                   act=Swish,
                                   bias=False)
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        if isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def pool_h_2(self, x):
        # x: (B, C, H, W) -> (B, C, 1, W // 2)
        x = x.mean(dim=2, keepdim=True)          # collapse height
        x = F.avg_pool2d(x, kernel_size=(1, 2))  # halve width
        return x

    def forward(self, x):
        if self.use_pool:
            x = self.pool_h_2(x)
        # use guide: stop gradients from flowing back into the backbone
        if self.use_guide:
            z = x.detach()
        else:
            z = x
        # keep the input around for the shortcut concat
        h = z
        # reduce channel dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global blocks operate on a flattened (B, H*W, C) sequence
        B, C, H, W = z.shape
        z = z.flatten(2).transpose(1, 2).contiguous()
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)
        # last stage: restore the (B, C, H, W) layout
        z = z.reshape(-1, H, W, C).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))
        return z
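
# Usage sketch for EncoderWithSVTR (illustrative, not part of the original
# module; the sizes below are example assumptions, not values the code
# mandates). With the default use_pool=True, height is pooled to 1 and width
# is halved before the SVTR blocks run:
#
#   enc = EncoderWithSVTR(in_channels=256, dims=64, depth=2, hidden_dims=120)
#   feats = torch.randn(1, 256, 8, 64)  # (B, C, H, W) backbone feature map
#   out = enc(feats)                    # -> (1, 64, 1, 32)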


class CTCDecoder(nn.Module):
    """CTC head: an optional EncoderWithSVTR front-end followed by one or two
    linear layers mapping each width position to class logits."""

    def __init__(self,
                 in_channels,
                 out_channels=6625,
                 mid_channels=None,
                 return_feats=False,
                 svtr_encoder=None,
                 **kwargs):
        super().__init__()
        if svtr_encoder is not None:
            svtr_encoder['in_channels'] = in_channels
            self.svtr_encoder = EncoderWithSVTR(**svtr_encoder)
            in_channels = self.svtr_encoder.out_channels
        else:
            self.svtr_encoder = None
        if mid_channels is None:
            self.fc = nn.Linear(
                in_channels,
                out_channels,
            )
        else:
            self.fc1 = nn.Linear(
                in_channels,
                mid_channels,
            )
            self.fc2 = nn.Linear(
                mid_channels,
                out_channels,
            )
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, data=None):
        if self.svtr_encoder is not None:
            x = self.svtr_encoder(x)
        # (B, C, H, W) -> (B, H*W, C): one classification per timestep
        x = x.flatten(2).transpose(1, 2)
        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)
        if self.return_feats:
            result = (x, predicts)
        else:
            result = predicts
        # at inference, return per-timestep probabilities only
        if not self.training:
            predicts = F.softmax(predicts, dim=2)
            result = predicts
        return result
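

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, assuming openrec and its
    # svtrnet primitives are importable; the config and shapes below are
    # example values, not requirements of the module).
    decoder = CTCDecoder(
        in_channels=256,
        out_channels=6625,
        svtr_encoder=dict(dims=64, depth=2, hidden_dims=120),
    )
    decoder.eval()  # inference mode, so the head returns softmax probabilities
    x = torch.randn(1, 256, 8, 64)  # (B, C, H, W) backbone feature map
    with torch.no_grad():
        probs = decoder(x)
    print(probs.shape)  # expected: torch.Size([1, 32, 6625])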