# svtrv2_lnconv.py
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
  5. from openrec.modeling.common import DropPath, Identity, Mlp
  6. class ConvBNLayer(nn.Module):
  7. def __init__(
  8. self,
  9. in_channels,
  10. out_channels,
  11. kernel_size=3,
  12. stride=1,
  13. padding=0,
  14. bias=False,
  15. groups=1,
  16. act=nn.GELU,
  17. ):
  18. super().__init__()
  19. self.conv = nn.Conv2d(
  20. in_channels=in_channels,
  21. out_channels=out_channels,
  22. kernel_size=kernel_size,
  23. stride=stride,
  24. padding=padding,
  25. groups=groups,
  26. bias=bias,
  27. )
  28. self.norm = nn.BatchNorm2d(out_channels)
  29. self.act = act()
  30. def forward(self, inputs):
  31. out = self.conv(inputs)
  32. out = self.norm(out)
  33. out = self.act(out)
  34. return out
  35. class Attention(nn.Module):
  36. def __init__(
  37. self,
  38. dim,
  39. num_heads=8,
  40. qkv_bias=False,
  41. qk_scale=None,
  42. attn_drop=0.0,
  43. proj_drop=0.0,
  44. ):
  45. super().__init__()
  46. self.num_heads = num_heads
  47. self.dim = dim
  48. self.head_dim = dim // num_heads
  49. self.scale = qk_scale or self.head_dim**-0.5
  50. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  51. self.attn_drop = nn.Dropout(attn_drop)
  52. self.proj = nn.Linear(dim, dim)
  53. self.proj_drop = nn.Dropout(proj_drop)
  54. def forward(self, x):
  55. B, N, _ = x.shape
  56. qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
  57. self.head_dim).permute(2, 0, 3, 1, 4))
  58. q, k, v = qkv.unbind(0)
  59. attn = q @ k.transpose(-2, -1) * self.scale
  60. attn = attn.softmax(dim=-1)
  61. attn = self.attn_drop(attn)
  62. x = attn @ v
  63. x = x.transpose(1, 2).reshape(B, N, self.dim)
  64. x = self.proj(x)
  65. x = self.proj_drop(x)
  66. return x
  67. class Block(nn.Module):
  68. def __init__(
  69. self,
  70. dim,
  71. num_heads,
  72. mlp_ratio=4.0,
  73. qkv_bias=False,
  74. qk_scale=None,
  75. drop=0.0,
  76. attn_drop=0.0,
  77. drop_path=0.0,
  78. act_layer=nn.GELU,
  79. norm_layer=nn.LayerNorm,
  80. eps=1e-6,
  81. ):
  82. super().__init__()
  83. mlp_hidden_dim = int(dim * mlp_ratio)
  84. self.norm1 = norm_layer(dim, eps=eps)
  85. self.mixer = Attention(
  86. dim,
  87. num_heads=num_heads,
  88. qkv_bias=qkv_bias,
  89. qk_scale=qk_scale,
  90. attn_drop=attn_drop,
  91. proj_drop=drop,
  92. )
  93. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  94. self.norm2 = norm_layer(dim, eps=eps)
  95. self.mlp = Mlp(
  96. in_features=dim,
  97. hidden_features=mlp_hidden_dim,
  98. act_layer=act_layer,
  99. drop=drop,
  100. )
  101. def forward(self, x):
  102. x = self.norm1(x + self.drop_path(self.mixer(x)))
  103. x = self.norm2(x + self.drop_path(self.mlp(x)))
  104. return x
  105. class ConvBlock(nn.Module):
  106. def __init__(
  107. self,
  108. dim,
  109. num_heads,
  110. mlp_ratio=4.0,
  111. drop=0.0,
  112. drop_path=0.0,
  113. act_layer=nn.GELU,
  114. norm_layer=nn.LayerNorm,
  115. eps=1e-6,
  116. ):
  117. super().__init__()
  118. mlp_hidden_dim = int(dim * mlp_ratio)
  119. self.norm1 = norm_layer(dim, eps=eps)
  120. self.mixer = nn.Conv2d(dim, dim, 5, 1, 2, groups=num_heads)
  121. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  122. self.norm2 = norm_layer(dim, eps=eps)
  123. self.mlp = Mlp(
  124. in_features=dim,
  125. hidden_features=mlp_hidden_dim,
  126. act_layer=act_layer,
  127. drop=drop,
  128. )
  129. def forward(self, x):
  130. C, H, W = x.shape[1:]
  131. x = x + self.drop_path(self.mixer(x))
  132. x = self.norm1(x.flatten(2).transpose(1, 2))
  133. x = self.norm2(x + self.drop_path(self.mlp(x)))
  134. x = x.transpose(1, 2).reshape(-1, C, H, W)
  135. return x
  136. class FlattenTranspose(nn.Module):
  137. def forward(self, x):
  138. return x.flatten(2).transpose(1, 2)
  139. class SubSample2D(nn.Module):
  140. def __init__(
  141. self,
  142. in_channels,
  143. out_channels,
  144. stride=[2, 1],
  145. ):
  146. super().__init__()
  147. self.conv = nn.Conv2d(in_channels,
  148. out_channels,
  149. kernel_size=3,
  150. stride=stride,
  151. padding=1)
  152. self.norm = nn.LayerNorm(out_channels)
  153. def forward(self, x, sz):
  154. # print(x.shape)
  155. x = self.conv(x)
  156. C, H, W = x.shape[1:]
  157. x = self.norm(x.flatten(2).transpose(1, 2))
  158. x = x.transpose(1, 2).reshape(-1, C, H, W)
  159. return x, [H, W]
  160. class SubSample1D(nn.Module):
  161. def __init__(
  162. self,
  163. in_channels,
  164. out_channels,
  165. stride=[2, 1],
  166. ):
  167. super().__init__()
  168. self.conv = nn.Conv2d(in_channels,
  169. out_channels,
  170. kernel_size=3,
  171. stride=stride,
  172. padding=1)
  173. self.norm = nn.LayerNorm(out_channels)
  174. def forward(self, x, sz):
  175. C = x.shape[-1]
  176. x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
  177. x = self.conv(x)
  178. C, H, W = x.shape[1:]
  179. x = self.norm(x.flatten(2).transpose(1, 2))
  180. return x, [H, W]
  181. class IdentitySize(nn.Module):
  182. def forward(self, x, sz):
  183. return x, sz
  184. class SVTRStage(nn.Module):
  185. def __init__(self,
  186. dim=64,
  187. out_dim=256,
  188. depth=3,
  189. mixer=['Local'] * 3,
  190. sub_k=[2, 1],
  191. num_heads=2,
  192. mlp_ratio=4,
  193. qkv_bias=True,
  194. qk_scale=None,
  195. drop_rate=0.0,
  196. attn_drop_rate=0.0,
  197. drop_path=[0.1] * 3,
  198. norm_layer=nn.LayerNorm,
  199. act=nn.GELU,
  200. eps=1e-6,
  201. downsample=None,
  202. **kwargs):
  203. super().__init__()
  204. self.dim = dim
  205. conv_block_num = sum([1 if mix == 'Conv' else 0 for mix in mixer])
  206. self.blocks = nn.Sequential()
  207. for i in range(depth):
  208. if mixer[i] == 'Conv':
  209. self.blocks.append(
  210. ConvBlock(
  211. dim=dim,
  212. num_heads=num_heads,
  213. mlp_ratio=mlp_ratio,
  214. drop=drop_rate,
  215. act_layer=act,
  216. drop_path=drop_path[i],
  217. norm_layer=norm_layer,
  218. eps=eps,
  219. ))
  220. else:
  221. self.blocks.append(
  222. Block(
  223. dim=dim,
  224. num_heads=num_heads,
  225. mlp_ratio=mlp_ratio,
  226. qkv_bias=qkv_bias,
  227. qk_scale=qk_scale,
  228. drop=drop_rate,
  229. act_layer=act,
  230. attn_drop=attn_drop_rate,
  231. drop_path=drop_path[i],
  232. norm_layer=norm_layer,
  233. eps=eps,
  234. ))
  235. if i == conv_block_num - 1 and mixer[-1] != 'Conv':
  236. self.blocks.append(FlattenTranspose())
  237. if downsample:
  238. if mixer[-1] == 'Conv':
  239. self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
  240. elif mixer[-1] == 'Global':
  241. self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
  242. else:
  243. self.downsample = IdentitySize()
  244. def forward(self, x, sz):
  245. for blk in self.blocks:
  246. x = blk(x)
  247. x, sz = self.downsample(x, sz)
  248. return x, sz
  249. class ADDPosEmbed(nn.Module):
  250. def __init__(self, feat_max_size=[8, 32], embed_dim=768):
  251. super().__init__()
  252. pos_embed = torch.zeros(
  253. [1, feat_max_size[0] * feat_max_size[1], embed_dim],
  254. dtype=torch.float32)
  255. trunc_normal_(pos_embed, mean=0, std=0.02)
  256. self.pos_embed = nn.Parameter(
  257. pos_embed.transpose(1, 2).reshape(1, embed_dim, feat_max_size[0],
  258. feat_max_size[1]),
  259. requires_grad=True,
  260. )
  261. def forward(self, x):
  262. sz = x.shape[2:]
  263. x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
  264. return x
  265. class POPatchEmbed(nn.Module):
  266. """Image to Patch Embedding."""
  267. def __init__(
  268. self,
  269. in_channels=3,
  270. feat_max_size=[8, 32],
  271. embed_dim=768,
  272. use_pos_embed=False,
  273. flatten=False,
  274. ):
  275. super().__init__()
  276. self.patch_embed = nn.Sequential(
  277. ConvBNLayer(
  278. in_channels=in_channels,
  279. out_channels=embed_dim // 2,
  280. kernel_size=3,
  281. stride=2,
  282. padding=1,
  283. act=nn.GELU,
  284. bias=None,
  285. ),
  286. ConvBNLayer(
  287. in_channels=embed_dim // 2,
  288. out_channels=embed_dim,
  289. kernel_size=3,
  290. stride=2,
  291. padding=1,
  292. act=nn.GELU,
  293. bias=None,
  294. ),
  295. )
  296. if use_pos_embed:
  297. self.patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
  298. if flatten:
  299. self.patch_embed.append(FlattenTranspose())
  300. def forward(self, x):
  301. sz = x.shape[2:]
  302. x = self.patch_embed(x)
  303. return x, [sz[0] // 4, sz[1] // 4]
  304. class LastStage(nn.Module):
  305. def __init__(self, in_channels, out_channels, last_drop, out_char_num):
  306. super().__init__()
  307. self.last_conv = nn.Linear(
  308. in_channels, out_channels,
  309. bias=False) # self.num_features, self.out_channels, bias=False)
  310. self.hardswish = nn.Hardswish()
  311. self.dropout = nn.Dropout(p=last_drop)
  312. def forward(self, x, sz):
  313. x = x.reshape(-1, sz[0], sz[1], x.shape[-1])
  314. x = x.mean(1)
  315. x = self.last_conv(x)
  316. x = self.hardswish(x)
  317. x = self.dropout(x)
  318. return x, [1, sz[1]]
  319. class Feat2D(nn.Module):
  320. def __init__(self):
  321. super().__init__()
  322. def forward(self, x, sz):
  323. # b, L c
  324. # H W
  325. C = x.shape[-1]
  326. x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
  327. return x, sz
  328. # class LastStage(nn.Module):
  329. # def __init__(self, in_channels, out_channels, last_drop, out_char_num):
  330. # super().__init__()
  331. # self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
  332. # self.last_conv = nn.Conv2d(
  333. # in_channels=in_channels,
  334. # out_channels=out_channels,
  335. # kernel_size=1,
  336. # stride=1,
  337. # padding=0,
  338. # bias=False,
  339. # )
  340. # self.hardswish = nn.Hardswish()
  341. # self.dropout = nn.Dropout(p=last_drop)
  342. # def forward(self, x, sz):
  343. # # x = x.reshape(-1, sz[0], sz[1], x.shape[-1])
  344. # C = x.shape[-1]
  345. # x = self.avg_pool(x.transpose(1, 2).reshape(-1, C, sz[0], sz[1]))
  346. # x = self.last_conv(x)
  347. # sz = x.shape[2:]
  348. # x = self.hardswish(x)
  349. # x = self.dropout(x)
  350. # x = x.flatten(2).transpose(1, 2)
  351. # return x, sz
  352. class SVTRv2LNConv(nn.Module):
  353. def __init__(self,
  354. max_sz=[32, 128],
  355. in_channels=3,
  356. out_channels=192,
  357. out_char_num=25,
  358. depths=[3, 6, 3],
  359. dims=[64, 128, 256],
  360. mixer=[['Conv'] * 3, ['Conv'] * 3 + ['Global'] * 3,
  361. ['Global'] * 3],
  362. use_pos_embed=True,
  363. sub_k=[[1, 1], [2, 1], [1, 1]],
  364. num_heads=[2, 4, 8],
  365. mlp_ratio=4,
  366. qkv_bias=True,
  367. qk_scale=None,
  368. drop_rate=0.0,
  369. last_drop=0.1,
  370. attn_drop_rate=0.0,
  371. drop_path_rate=0.1,
  372. norm_layer=nn.LayerNorm,
  373. act=nn.GELU,
  374. last_stage=False,
  375. feat2d=False,
  376. eps=1e-6,
  377. **kwargs):
  378. super().__init__()
  379. num_stages = len(depths)
  380. self.num_features = dims[-1]
  381. feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
  382. self.pope = POPatchEmbed(
  383. in_channels=in_channels,
  384. feat_max_size=feat_max_size,
  385. embed_dim=dims[0],
  386. use_pos_embed=use_pos_embed,
  387. flatten=mixer[0][0] != 'Conv',
  388. )
  389. dpr = np.linspace(0, drop_path_rate,
  390. sum(depths)) # stochastic depth decay rule
  391. self.stages = nn.ModuleList()
  392. for i_stage in range(num_stages):
  393. stage = SVTRStage(
  394. dim=dims[i_stage],
  395. out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
  396. depth=depths[i_stage],
  397. mixer=mixer[i_stage],
  398. sub_k=sub_k[i_stage],
  399. num_heads=num_heads[i_stage],
  400. mlp_ratio=mlp_ratio,
  401. qkv_bias=qkv_bias,
  402. qk_scale=qk_scale,
  403. drop=drop_rate,
  404. attn_drop=attn_drop_rate,
  405. drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
  406. norm_layer=norm_layer,
  407. act=act,
  408. downsample=False if i_stage == num_stages - 1 else True,
  409. eps=eps,
  410. )
  411. self.stages.append(stage)
  412. self.out_channels = self.num_features
  413. self.last_stage = last_stage
  414. if last_stage:
  415. self.out_channels = out_channels
  416. self.stages.append(
  417. LastStage(self.num_features, out_channels, last_drop,
  418. out_char_num))
  419. if feat2d:
  420. self.stages.append(Feat2D())
  421. self.apply(self._init_weights)
  422. def _init_weights(self, m: nn.Module):
  423. if isinstance(m, nn.Linear):
  424. trunc_normal_(m.weight, mean=0, std=0.02)
  425. if isinstance(m, nn.Linear) and m.bias is not None:
  426. zeros_(m.bias)
  427. if isinstance(m, nn.LayerNorm):
  428. zeros_(m.bias)
  429. ones_(m.weight)
  430. if isinstance(m, nn.Conv2d):
  431. kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  432. @torch.jit.ignore
  433. def no_weight_decay(self):
  434. return {'patch_embed', 'downsample', 'pos_embed'}
  435. def forward(self, x):
  436. x, sz = self.pope(x)
  437. for stage in self.stages:
  438. x, sz = stage(x, sz)
  439. return x