# svtrv2_lnconv_two33.py

import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp


class ConvBNLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        bias=False,
        groups=1,
        act=nn.GELU,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = act()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out
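
# Usage sketch (illustrative values, not used by the model): a stride-2
# ConvBNLayer halves both spatial dims.
#   layer = ConvBNLayer(3, 64, kernel_size=3, stride=2, padding=1)
#   layer(torch.randn(2, 3, 32, 128)).shape  # -> (2, 64, 16, 64)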


class Attention(nn.Module):
    """Standard multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, _ = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = q @ k.transpose(-2, -1) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, self.dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
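
# Usage sketch (illustrative): the mixer operates on flattened feature maps,
# e.g. an 8x32 map with dim=64 becomes a 256-token sequence.
#   attn = Attention(dim=64, num_heads=2)
#   attn(torch.randn(1, 8 * 32, 64)).shape  # -> (1, 256, 64)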


class Block(nn.Module):
    """Attention + MLP block with post-norm residuals (LayerNorm is applied
    after each residual sum rather than before the sub-layer)."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eps=1e-6,
    ):
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim, eps=eps)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim, eps=eps)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = self.norm1(x + self.drop_path(self.mixer(x)))
        x = self.norm2(x + self.drop_path(self.mlp(x)))
        return x
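
# Usage sketch (illustrative): Block expects flattened (B, N, C) tokens and
# never reshapes to 2D itself.
#   blk = Block(dim=64, num_heads=2)
#   blk(torch.randn(1, 256, 64)).shape  # -> (1, 256, 64)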


class FlattenBlockRe2D(Block):
    """A Block that accepts NCHW input: flattens to tokens, runs the parent
    Block, then restores the (B, C, H, W) layout."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0,
                 attn_drop=0,
                 drop_path=0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 eps=1e-6):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, eps)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = super().forward(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x
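
# Usage sketch (illustrative): same contract as ConvBlock, NCHW in and out.
#   blk = FlattenBlockRe2D(dim=64, num_heads=2)
#   blk(torch.randn(1, 64, 8, 32)).shape  # -> (1, 64, 8, 32)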


class ConvBlock(nn.Module):
    """Convolutional mixer block: a stack of grouped 2D convolutions replaces
    attention, with LayerNorm applied on the flattened token view."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eps=1e-6,
        num_conv=2,
        kernel_size=3,
    ):
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim, eps=eps)
        self.mixer = nn.Sequential(*[
            nn.Conv2d(
                dim, dim, kernel_size, 1, kernel_size // 2, groups=num_heads)
            for _ in range(num_conv)
        ])
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim, eps=eps)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        # Input is NCHW; the conv mixer runs in 2D, the norms and MLP on the
        # flattened (B, N, C) view, and the output is reshaped back to NCHW.
        C, H, W = x.shape[1:]
        x = x + self.drop_path(self.mixer(x))
        x = self.norm1(x.flatten(2).transpose(1, 2))
        x = self.norm2(x + self.drop_path(self.mlp(x)))
        x = x.transpose(1, 2).reshape(-1, C, H, W)
        return x
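
# Usage sketch (illustrative): ConvBlock keeps the NCHW layout end to end.
#   blk = ConvBlock(dim=64, num_heads=2, num_conv=2, kernel_size=3)
#   blk(torch.randn(1, 64, 8, 32)).shape  # -> (1, 64, 8, 32)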


class FlattenTranspose(nn.Module):
    """(B, C, H, W) -> (B, H*W, C)."""

    def forward(self, x):
        return x.flatten(2).transpose(1, 2)


class SubSample2D(nn.Module):
    """Strided-conv downsampling that keeps the NCHW layout."""

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=[2, 1],
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=3,
                              stride=stride,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, sz):
        x = self.conv(x)
        C, H, W = x.shape[1:]
        x = self.norm(x.flatten(2).transpose(1, 2))
        x = x.transpose(1, 2).reshape(-1, C, H, W)
        return x, [H, W]


class SubSample1D(nn.Module):
    """Strided-conv downsampling for token-sequence input: reshapes (B, N, C)
    back to 2D using `sz`, convolves, and returns the flattened result."""

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=[2, 1],
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=3,
                              stride=stride,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, sz):
        C = x.shape[-1]
        x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
        x = self.conv(x)
        C, H, W = x.shape[1:]
        x = self.norm(x.flatten(2).transpose(1, 2))
        return x, [H, W]
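
# Shape sketch (illustrative): with the default stride=[2, 1], both
# subsamplers halve the height and keep the width.
#   SubSample2D(64, 128)(torch.randn(1, 64, 8, 32), [8, 32])
#       # -> (1, 128, 4, 32), [4, 32]
#   SubSample1D(64, 128)(torch.randn(1, 8 * 32, 64), [8, 32])
#       # -> (1, 4 * 32, 128), [4, 32]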


class IdentitySize(nn.Module):
    """No-op stand-in for a downsampler: passes x and sz through unchanged."""

    def forward(self, x, sz):
        return x, sz


class SVTRStage(nn.Module):
    """One backbone stage: a sequence of mixer blocks chosen per-block by the
    `mixer` strings ('Conv', 'Global', 'FGlobal', 'FGlobalRe2D'), followed by
    an optional downsampling layer."""

    def __init__(self,
                 dim=64,
                 out_dim=256,
                 depth=3,
                 mixer=['Local'] * 3,
                 kernel_sizes=[3] * 3,
                 sub_k=[2, 1],
                 num_heads=2,
                 mlp_ratio=4,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path=[0.1] * 3,
                 norm_layer=nn.LayerNorm,
                 act=nn.GELU,
                 eps=1e-6,
                 num_conv=[2] * 3,
                 downsample=None,
                 **kwargs):
        super().__init__()
        self.dim = dim
        self.blocks = nn.Sequential()
        for i in range(depth):
            if mixer[i] == 'Conv':
                self.blocks.append(
                    ConvBlock(dim=dim,
                              kernel_size=kernel_sizes[i],
                              num_heads=num_heads,
                              mlp_ratio=mlp_ratio,
                              drop=drop_rate,
                              act_layer=act,
                              drop_path=drop_path[i],
                              norm_layer=norm_layer,
                              eps=eps,
                              num_conv=num_conv[i]))
            else:
                if mixer[i] == 'Global':
                    # 'Global' assumes the input is already a flattened
                    # (B, N, C) sequence; use 'FGlobal' to flatten it first.
                    block = Block
                elif mixer[i] == 'FGlobal':
                    block = Block
                    self.blocks.append(FlattenTranspose())
                elif mixer[i] == 'FGlobalRe2D':
                    block = FlattenBlockRe2D
                self.blocks.append(
                    block(
                        dim=dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        act_layer=act,
                        attn_drop=attn_drop_rate,
                        drop_path=drop_path[i],
                        norm_layer=norm_layer,
                        eps=eps,
                    ))
        if downsample:
            # Blocks that end in NCHW layout need the 2D subsampler; token
            # sequences use the 1D variant.
            if mixer[-1] == 'Conv' or mixer[-1] == 'FGlobalRe2D':
                self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
            else:
                self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
        else:
            self.downsample = IdentitySize()

    def forward(self, x, sz):
        for blk in self.blocks:
            x = blk(x)
        x, sz = self.downsample(x, sz)
        return x, sz
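
# Config sketch (illustrative): two conv mixers, then 'FGlobal' flattens the
# map before a final attention block; downsampling halves the height.
#   stage = SVTRStage(dim=64, out_dim=128, depth=3,
#                     mixer=['Conv', 'Conv', 'FGlobal'],
#                     kernel_sizes=[3] * 3, num_conv=[2] * 3,
#                     sub_k=[2, 1], num_heads=2, drop_path=[0.0] * 3,
#                     downsample=True)
#   x, sz = stage(torch.randn(1, 64, 8, 32), [8, 32])
#       # -> x: (1, 4 * 32, 128), sz: [4, 32]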


class ADDPosEmbed(nn.Module):
    """Adds a learned 2D positional embedding, cropped to the input size."""

    def __init__(self, feat_max_size=[8, 32], embed_dim=768):
        super().__init__()
        pos_embed = torch.zeros(
            [1, feat_max_size[0] * feat_max_size[1], embed_dim],
            dtype=torch.float32)
        trunc_normal_(pos_embed, mean=0, std=0.02)
        self.pos_embed = nn.Parameter(
            pos_embed.transpose(1, 2).reshape(1, embed_dim, feat_max_size[0],
                                              feat_max_size[1]),
            requires_grad=True,
        )

    def forward(self, x):
        sz = x.shape[2:]
        x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
        return x
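
# Usage sketch (illustrative): the stored embedding is cropped to the input,
# so any feature map no larger than feat_max_size works.
#   pe = ADDPosEmbed(feat_max_size=[8, 32], embed_dim=64)
#   pe(torch.randn(1, 64, 4, 16)).shape  # -> (1, 64, 4, 16)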


class POPatchEmbed(nn.Module):
    """Image to Patch Embedding: two stride-2 ConvBNLayers (overall stride 4),
    with optional positional embedding and flattening."""

    def __init__(self,
                 in_channels=3,
                 feat_max_size=[8, 32],
                 embed_dim=768,
                 use_pos_embed=False,
                 flatten=False,
                 bias=False):
        super().__init__()
        self.patch_embed = nn.Sequential(
            ConvBNLayer(
                in_channels=in_channels,
                out_channels=embed_dim // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=bias,
            ),
            ConvBNLayer(
                in_channels=embed_dim // 2,
                out_channels=embed_dim,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=bias,
            ),
        )
        if use_pos_embed:
            self.patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
        if flatten:
            self.patch_embed.append(FlattenTranspose())

    def forward(self, x):
        sz = x.shape[2:]
        x = self.patch_embed(x)
        return x, [sz[0] // 4, sz[1] // 4]
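
# Shape sketch (illustrative): a 32x128 image maps to an 8x32 feature grid
# (overall stride 4).
#   pe = POPatchEmbed(in_channels=3, embed_dim=64, use_pos_embed=True)
#   x, sz = pe(torch.randn(1, 3, 32, 128))  # x: (1, 64, 8, 32), sz: [8, 32]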


class LastStage(nn.Module):
    """Collapses the height dimension by mean pooling and projects to the
    output channel count."""

    def __init__(self, in_channels, out_channels, last_drop, out_char_num=0):
        super().__init__()
        self.last_conv = nn.Linear(in_channels, out_channels, bias=False)
        self.hardswish = nn.Hardswish()
        self.dropout = nn.Dropout(p=last_drop)

    def forward(self, x, sz):
        # (B, N, C) -> (B, H, W, C) -> mean over H -> (B, W, C)
        x = x.reshape(-1, sz[0], sz[1], x.shape[-1])
        x = x.mean(1)
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        return x, [1, sz[1]]
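
# Shape sketch (illustrative): height is pooled away and width becomes the
# sequence axis for the recognition head.
#   ls = LastStage(256, 192, last_drop=0.1)
#   x, sz = ls(torch.randn(1, 4 * 32, 256), [4, 32])
#       # -> x: (1, 32, 192), sz: [1, 32]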


class Feat2D(nn.Module):
    """Restores the (B, C, H, W) layout from a token sequence."""

    def __init__(self):
        super().__init__()

    def forward(self, x, sz):
        C = x.shape[-1]
        x = x.transpose(1, 2).reshape(-1, C, sz[0], sz[1])
        return x, sz


class SVTRv2LNConvTwo33(nn.Module):
    """SVTRv2 backbone: a patch embedding followed by three stages that mix
    convolutional and global-attention blocks."""

    def __init__(self,
                 max_sz=[32, 128],
                 in_channels=3,
                 out_channels=192,
                 depths=[3, 6, 3],
                 dims=[64, 128, 256],
                 mixer=[['Conv'] * 3, ['Conv'] * 3 + ['Global'] * 3,
                        ['Global'] * 3],
                 use_pos_embed=True,
                 sub_k=[[1, 1], [2, 1], [1, 1]],
                 num_heads=[2, 4, 8],
                 mlp_ratio=4,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.0,
                 last_drop=0.1,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 act=nn.GELU,
                 last_stage=False,
                 feat2d=False,
                 eps=1e-6,
                 num_convs=[[2] * 3, [2] * 3 + [3] * 3, [3] * 3],
                 kernel_sizes=[[3] * 3, [3] * 3 + [3] * 3, [3] * 3],
                 pope_bias=False,
                 **kwargs):
        super().__init__()
        num_stages = len(depths)
        self.num_features = dims[-1]
        feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
        self.pope = POPatchEmbed(in_channels=in_channels,
                                 feat_max_size=feat_max_size,
                                 embed_dim=dims[0],
                                 use_pos_embed=use_pos_embed,
                                 flatten=mixer[0][0] != 'Conv',
                                 bias=pope_bias)
        # Stochastic depth decay rule: drop-path rates grow linearly with
        # block index across the whole network.
        dpr = np.linspace(0, drop_path_rate, sum(depths))
        self.stages = nn.ModuleList()
        for i_stage in range(num_stages):
            stage = SVTRStage(
                dim=dims[i_stage],
                out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
                depth=depths[i_stage],
                mixer=mixer[i_stage],
                kernel_sizes=kernel_sizes[i_stage]
                if len(kernel_sizes[i_stage]) == len(mixer[i_stage]) else
                [3] * len(mixer[i_stage]),
                sub_k=sub_k[i_stage],
                num_heads=num_heads[i_stage],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                # Keyword names must match the SVTRStage signature: passing
                # these as `drop=`/`attn_drop=` would be swallowed by its
                # **kwargs and silently ignored.
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                norm_layer=norm_layer,
                act=act,
                downsample=False if i_stage == num_stages - 1 else True,
                eps=eps,
                num_conv=num_convs[i_stage] if len(num_convs[i_stage]) == len(
                    mixer[i_stage]) else [2] * len(mixer[i_stage]),
            )
            self.stages.append(stage)
        self.out_channels = self.num_features
        self.last_stage = last_stage
        if last_stage:
            self.out_channels = out_channels
            self.stages.append(
                LastStage(self.num_features, out_channels, last_drop))
        if feat2d:
            self.stages.append(Feat2D())
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        if isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'patch_embed', 'downsample', 'pos_embed'}

    def forward(self, x):
        # Merge an optional extra leading dim, e.g. (B, T, C, H, W) input.
        if len(x.shape) == 5:
            x = x.flatten(0, 1)
        x, sz = self.pope(x)
        for stage in self.stages:
            x, sz = stage(x, sz)
        return x
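

if __name__ == '__main__':
    # Smoke test (illustrative sketch, assuming openrec is importable): an
    # explicit 'FGlobal' entry flattens the feature map at the Conv -> Global
    # transition in stage 2. The 32x128 input becomes an 8x32 grid after patch
    # embedding, stage 2 halves the height (sub_k=[2, 1]), and the output is
    # a (B, 4 * 32, 256) token sequence.
    model = SVTRv2LNConvTwo33(
        mixer=[['Conv'] * 3, ['Conv'] * 3 + ['FGlobal'] + ['Global'] * 2,
               ['Global'] * 3])
    feats = model(torch.randn(1, 3, 32, 128))
    print(feats.shape)  # expected: torch.Size([1, 128, 256])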