svtrnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. import numpy as np
  2. import torch
  3. from torch import nn
  4. from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
  5. from openrec.modeling.common import DropPath, Identity, Mlp
  6. class ConvBNLayer(nn.Module):
  7. def __init__(
  8. self,
  9. in_channels,
  10. out_channels,
  11. kernel_size=3,
  12. stride=1,
  13. padding=0,
  14. bias=False,
  15. groups=1,
  16. act=nn.GELU,
  17. ):
  18. super().__init__()
  19. self.conv = nn.Conv2d(
  20. in_channels=in_channels,
  21. out_channels=out_channels,
  22. kernel_size=kernel_size,
  23. stride=stride,
  24. padding=padding,
  25. groups=groups,
  26. bias=bias,
  27. )
  28. self.norm = nn.BatchNorm2d(out_channels)
  29. self.act = act()
  30. def forward(self, inputs):
  31. out = self.conv(inputs)
  32. out = self.norm(out)
  33. out = self.act(out)
  34. return out
  35. class ConvMixer(nn.Module):
  36. def __init__(
  37. self,
  38. dim,
  39. num_heads=8,
  40. HW=[8, 25],
  41. local_k=[3, 3],
  42. ):
  43. super().__init__()
  44. self.HW = HW
  45. self.dim = dim
  46. self.local_mixer = nn.Conv2d(dim,
  47. dim,
  48. local_k,
  49. 1, [local_k[0] // 2, local_k[1] // 2],
  50. groups=num_heads)
  51. def forward(self, x):
  52. h = self.HW[0]
  53. w = self.HW[1]
  54. x = x.transpose(1, 2).reshape([x.shape[0], self.dim, h, w])
  55. x = self.local_mixer(x)
  56. x = x.flatten(2).transpose(1, 2)
  57. return x
  58. class Attention(nn.Module):
  59. def __init__(
  60. self,
  61. dim,
  62. num_heads=8,
  63. mixer='Global',
  64. HW=None,
  65. local_k=[7, 11],
  66. qkv_bias=False,
  67. qk_scale=None,
  68. attn_drop=0.0,
  69. proj_drop=0.0,
  70. ):
  71. super().__init__()
  72. self.num_heads = num_heads
  73. self.dim = dim
  74. self.head_dim = dim // num_heads
  75. self.scale = qk_scale or self.head_dim**-0.5
  76. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  77. self.attn_drop = nn.Dropout(attn_drop)
  78. self.proj = nn.Linear(dim, dim)
  79. self.proj_drop = nn.Dropout(proj_drop)
  80. self.HW = HW
  81. if HW is not None:
  82. H = HW[0]
  83. W = HW[1]
  84. self.N = H * W
  85. self.C = dim
  86. if mixer == 'Local' and HW is not None:
  87. hk = local_k[0]
  88. wk = local_k[1]
  89. mask = torch.ones(H * W,
  90. H + hk - 1,
  91. W + wk - 1,
  92. dtype=torch.float32,
  93. requires_grad=False)
  94. for h in range(0, H):
  95. for w in range(0, W):
  96. mask[h * W + w, h:h + hk, w:w + wk] = 0.0
  97. mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1)
  98. mask[mask >= 1] = -np.inf
  99. self.register_buffer('mask', mask[None, None, :, :])
  100. self.mixer = mixer
  101. def forward(self, x):
  102. B, N, _ = x.shape
  103. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
  104. self.head_dim).permute(2, 0, 3, 1, 4)
  105. q, k, v = qkv.unbind(0)
  106. # x = F.scaled_dot_product_attention(
  107. # q, k, v,
  108. # attn_mask=mask,
  109. # dropout_p=self.attn_drop.p
  110. # )
  111. q = q * self.scale
  112. attn = q @ k.transpose(-2, -1)
  113. if self.mixer == 'Local':
  114. attn += self.mask
  115. attn = attn.softmax(dim=-1)
  116. attn = self.attn_drop(attn)
  117. x = attn @ v
  118. x = x.transpose(1, 2).reshape(B, N, self.dim)
  119. x = self.proj(x)
  120. x = self.proj_drop(x)
  121. return x
  122. class Block(nn.Module):
  123. def __init__(
  124. self,
  125. dim,
  126. num_heads,
  127. mixer='Global',
  128. local_mixer=[7, 11],
  129. HW=None,
  130. mlp_ratio=4.0,
  131. qkv_bias=False,
  132. qk_scale=None,
  133. drop=0.0,
  134. attn_drop=0.0,
  135. drop_path=0.0,
  136. act_layer=nn.GELU,
  137. norm_layer='nn.LayerNorm',
  138. eps=1e-6,
  139. prenorm=True,
  140. ):
  141. super().__init__()
  142. if isinstance(norm_layer, str):
  143. self.norm1 = eval(norm_layer)(dim, eps=eps)
  144. else:
  145. self.norm1 = norm_layer(dim)
  146. if mixer == 'Global' or mixer == 'Local':
  147. self.mixer = Attention(
  148. dim,
  149. num_heads=num_heads,
  150. mixer=mixer,
  151. HW=HW,
  152. local_k=local_mixer,
  153. qkv_bias=qkv_bias,
  154. qk_scale=qk_scale,
  155. attn_drop=attn_drop,
  156. proj_drop=drop,
  157. )
  158. elif mixer == 'Conv':
  159. self.mixer = ConvMixer(dim,
  160. num_heads=num_heads,
  161. HW=HW,
  162. local_k=local_mixer)
  163. else:
  164. raise TypeError('The mixer must be one of [Global, Local, Conv]')
  165. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
  166. if isinstance(norm_layer, str):
  167. self.norm2 = eval(norm_layer)(dim, eps=eps)
  168. else:
  169. self.norm2 = norm_layer(dim)
  170. mlp_hidden_dim = int(dim * mlp_ratio)
  171. self.mlp_ratio = mlp_ratio
  172. self.mlp = Mlp(
  173. in_features=dim,
  174. hidden_features=mlp_hidden_dim,
  175. act_layer=act_layer,
  176. drop=drop,
  177. )
  178. self.prenorm = prenorm
  179. def forward(self, x):
  180. if self.prenorm:
  181. x = self.norm1(x + self.drop_path(self.mixer(x)))
  182. x = self.norm2(x + self.drop_path(self.mlp(x)))
  183. else:
  184. x = x + self.drop_path(self.mixer(self.norm1(x)))
  185. x = x + self.drop_path(self.mlp(self.norm2(x)))
  186. return x
  187. class PatchEmbed(nn.Module):
  188. """Image to Patch Embedding."""
  189. def __init__(
  190. self,
  191. img_size=[32, 100],
  192. in_channels=3,
  193. embed_dim=768,
  194. sub_num=2,
  195. patch_size=[4, 4],
  196. mode='pope',
  197. ):
  198. super().__init__()
  199. num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] //
  200. (2**sub_num))
  201. self.img_size = img_size
  202. self.num_patches = num_patches
  203. self.embed_dim = embed_dim
  204. self.norm = None
  205. if mode == 'pope':
  206. if sub_num == 2:
  207. self.proj = nn.Sequential(
  208. ConvBNLayer(
  209. in_channels=in_channels,
  210. out_channels=embed_dim // 2,
  211. kernel_size=3,
  212. stride=2,
  213. padding=1,
  214. act=nn.GELU,
  215. bias=None,
  216. ),
  217. ConvBNLayer(
  218. in_channels=embed_dim // 2,
  219. out_channels=embed_dim,
  220. kernel_size=3,
  221. stride=2,
  222. padding=1,
  223. act=nn.GELU,
  224. bias=None,
  225. ),
  226. )
  227. if sub_num == 3:
  228. self.proj = nn.Sequential(
  229. ConvBNLayer(
  230. in_channels=in_channels,
  231. out_channels=embed_dim // 4,
  232. kernel_size=3,
  233. stride=2,
  234. padding=1,
  235. act=nn.GELU,
  236. bias=None,
  237. ),
  238. ConvBNLayer(
  239. in_channels=embed_dim // 4,
  240. out_channels=embed_dim // 2,
  241. kernel_size=3,
  242. stride=2,
  243. padding=1,
  244. act=nn.GELU,
  245. bias=None,
  246. ),
  247. ConvBNLayer(
  248. in_channels=embed_dim // 2,
  249. out_channels=embed_dim,
  250. kernel_size=3,
  251. stride=2,
  252. padding=1,
  253. act=nn.GELU,
  254. bias=None,
  255. ),
  256. )
  257. elif mode == 'linear':
  258. self.proj = nn.Conv2d(1,
  259. embed_dim,
  260. kernel_size=patch_size,
  261. stride=patch_size)
  262. self.num_patches = img_size[0] // patch_size[0] * img_size[
  263. 1] // patch_size[1]
  264. def forward(self, x):
  265. B, C, H, W = x.shape
  266. assert (
  267. H == self.img_size[0] and W == self.img_size[1]
  268. ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  269. x = self.proj(x).flatten(2).transpose(1, 2)
  270. return x
  271. class SubSample(nn.Module):
  272. def __init__(
  273. self,
  274. in_channels,
  275. out_channels,
  276. types='Pool',
  277. stride=[2, 1],
  278. sub_norm='nn.LayerNorm',
  279. act=None,
  280. ):
  281. super().__init__()
  282. self.types = types
  283. if types == 'Pool':
  284. self.avgpool = nn.AvgPool2d(kernel_size=[3, 5],
  285. stride=stride,
  286. padding=[1, 2])
  287. self.maxpool = nn.MaxPool2d(kernel_size=[3, 5],
  288. stride=stride,
  289. padding=[1, 2])
  290. self.proj = nn.Linear(in_channels, out_channels)
  291. else:
  292. self.conv = nn.Conv2d(in_channels,
  293. out_channels,
  294. kernel_size=3,
  295. stride=stride,
  296. padding=1)
  297. self.norm = eval(sub_norm)(out_channels)
  298. if act is not None:
  299. self.act = act()
  300. else:
  301. self.act = None
  302. def forward(self, x):
  303. if self.types == 'Pool':
  304. x1 = self.avgpool(x)
  305. x2 = self.maxpool(x)
  306. x = (x1 + x2) * 0.5
  307. out = self.proj(x.flatten(2).transpose(1, 2))
  308. else:
  309. x = self.conv(x)
  310. out = x.flatten(2).transpose(1, 2)
  311. out = self.norm(out)
  312. if self.act is not None:
  313. out = self.act(out)
  314. return out
  315. class SVTRNet(nn.Module):
  316. def __init__(
  317. self,
  318. img_size=[32, 100],
  319. in_channels=3,
  320. embed_dim=[64, 128, 256],
  321. depth=[3, 6, 3],
  322. num_heads=[2, 4, 8],
  323. mixer=['Local'] * 6 +
  324. ['Global'] * 6, # Local atten, Global atten, Conv
  325. local_mixer=[[7, 11], [7, 11], [7, 11]],
  326. patch_merging='Conv', # Conv, Pool, None
  327. sub_k=[[2, 1], [2, 1]],
  328. mlp_ratio=4,
  329. qkv_bias=True,
  330. qk_scale=None,
  331. drop_rate=0.0,
  332. last_drop=0.1,
  333. attn_drop_rate=0.0,
  334. drop_path_rate=0.1,
  335. norm_layer='nn.LayerNorm',
  336. sub_norm='nn.LayerNorm',
  337. eps=1e-6,
  338. out_channels=192,
  339. out_char_num=25,
  340. block_unit='Block',
  341. act='nn.GELU',
  342. last_stage=True,
  343. sub_num=2,
  344. prenorm=True,
  345. use_lenhead=False,
  346. feature2d=False,
  347. **kwargs,
  348. ):
  349. super().__init__()
  350. self.img_size = img_size
  351. self.embed_dim = embed_dim
  352. self.out_channels = out_channels
  353. self.prenorm = prenorm
  354. self.feature2d = feature2d
  355. patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
  356. self.patch_embed = PatchEmbed(
  357. img_size=img_size,
  358. in_channels=in_channels,
  359. embed_dim=embed_dim[0],
  360. sub_num=sub_num,
  361. )
  362. num_patches = self.patch_embed.num_patches
  363. self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
  364. self.hw = [
  365. [self.HW[0] // sub_k[0][0], self.HW[1] // sub_k[0][1]],
  366. [
  367. self.HW[0] // (sub_k[0][0] * sub_k[1][0]),
  368. self.HW[1] // (sub_k[0][1] * sub_k[1][1])
  369. ],
  370. ]
  371. # self.pos_embed = self.create_parameter(
  372. # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
  373. # self.add_parameter("pos_embed", self.pos_embed)
  374. self.pos_embed = nn.Parameter(
  375. torch.zeros([1, num_patches, embed_dim[0]], dtype=torch.float32),
  376. requires_grad=True,
  377. )
  378. self.pos_drop = nn.Dropout(p=drop_rate)
  379. Block_unit = eval(block_unit)
  380. dpr = np.linspace(0, drop_path_rate, sum(depth))
  381. self.blocks1 = nn.ModuleList([
  382. Block_unit(
  383. dim=embed_dim[0],
  384. num_heads=num_heads[0],
  385. mixer=mixer[0:depth[0]][i],
  386. HW=self.HW,
  387. local_mixer=local_mixer[0],
  388. mlp_ratio=mlp_ratio,
  389. qkv_bias=qkv_bias,
  390. qk_scale=qk_scale,
  391. drop=drop_rate,
  392. act_layer=eval(act),
  393. attn_drop=attn_drop_rate,
  394. drop_path=dpr[0:depth[0]][i],
  395. norm_layer=norm_layer,
  396. eps=eps,
  397. prenorm=prenorm,
  398. ) for i in range(depth[0])
  399. ])
  400. if patch_merging is not None:
  401. self.sub_sample1 = SubSample(
  402. embed_dim[0],
  403. embed_dim[1],
  404. sub_norm=sub_norm,
  405. stride=sub_k[0],
  406. types=patch_merging,
  407. )
  408. HW = self.hw[0]
  409. else:
  410. HW = self.HW
  411. self.patch_merging = patch_merging
  412. self.blocks2 = nn.ModuleList([
  413. Block_unit(
  414. dim=embed_dim[1],
  415. num_heads=num_heads[1],
  416. mixer=mixer[depth[0]:depth[0] + depth[1]][i],
  417. HW=HW,
  418. local_mixer=local_mixer[1],
  419. mlp_ratio=mlp_ratio,
  420. qkv_bias=qkv_bias,
  421. qk_scale=qk_scale,
  422. drop=drop_rate,
  423. act_layer=eval(act),
  424. attn_drop=attn_drop_rate,
  425. drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
  426. norm_layer=norm_layer,
  427. eps=eps,
  428. prenorm=prenorm,
  429. ) for i in range(depth[1])
  430. ])
  431. if patch_merging is not None:
  432. self.sub_sample2 = SubSample(
  433. embed_dim[1],
  434. embed_dim[2],
  435. sub_norm=sub_norm,
  436. stride=sub_k[1],
  437. types=patch_merging,
  438. )
  439. HW = self.hw[1]
  440. self.blocks3 = nn.ModuleList([
  441. Block_unit(
  442. dim=embed_dim[2],
  443. num_heads=num_heads[2],
  444. mixer=mixer[depth[0] + depth[1]:][i],
  445. HW=HW,
  446. local_mixer=local_mixer[2],
  447. mlp_ratio=mlp_ratio,
  448. qkv_bias=qkv_bias,
  449. qk_scale=qk_scale,
  450. drop=drop_rate,
  451. act_layer=eval(act),
  452. attn_drop=attn_drop_rate,
  453. drop_path=dpr[depth[0] + depth[1]:][i],
  454. norm_layer=norm_layer,
  455. eps=eps,
  456. prenorm=prenorm,
  457. ) for i in range(depth[2])
  458. ])
  459. self.last_stage = last_stage
  460. if last_stage:
  461. self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
  462. self.last_conv = nn.Conv2d(
  463. in_channels=embed_dim[2],
  464. out_channels=self.out_channels,
  465. kernel_size=1,
  466. stride=1,
  467. padding=0,
  468. bias=False,
  469. )
  470. self.hardswish = nn.Hardswish()
  471. self.dropout = nn.Dropout(p=last_drop)
  472. else:
  473. self.out_channels = embed_dim[2]
  474. if not prenorm:
  475. self.norm = eval(norm_layer)(embed_dim[-1], eps=eps)
  476. self.use_lenhead = use_lenhead
  477. if use_lenhead:
  478. self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
  479. self.hardswish_len = nn.Hardswish()
  480. self.dropout_len = nn.Dropout(p=last_drop)
  481. trunc_normal_(self.pos_embed, mean=0, std=0.02)
  482. self.apply(self._init_weights)
  483. def _init_weights(self, m):
  484. if isinstance(m, nn.Linear):
  485. trunc_normal_(m.weight, mean=0, std=0.02)
  486. if isinstance(m, nn.Linear) and m.bias is not None:
  487. zeros_(m.bias)
  488. if isinstance(m, nn.LayerNorm):
  489. zeros_(m.bias)
  490. ones_(m.weight)
  491. if isinstance(m, nn.Conv2d):
  492. kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  493. @torch.jit.ignore
  494. def no_weight_decay(self):
  495. return {'pos_embed', 'sub_sample1', 'sub_sample2', 'sub_sample3'}
  496. def forward_features(self, x):
  497. x = self.patch_embed(x)
  498. x = x + self.pos_embed
  499. x = self.pos_drop(x)
  500. for blk in self.blocks1:
  501. x = blk(x)
  502. if self.patch_merging is not None:
  503. x = self.sub_sample1(
  504. x.transpose(1, 2).reshape(-1, self.embed_dim[0], self.HW[0],
  505. self.HW[1]))
  506. for blk in self.blocks2:
  507. x = blk(x)
  508. if self.patch_merging is not None:
  509. x = self.sub_sample2(
  510. x.transpose(1, 2).reshape(-1, self.embed_dim[1], self.hw[0][0],
  511. self.hw[0][1]))
  512. for blk in self.blocks3:
  513. x = blk(x)
  514. if not self.prenorm:
  515. x = self.norm(x)
  516. return x
  517. def forward(self, x):
  518. x = self.forward_features(x)
  519. if self.feature2d:
  520. x = x.transpose(1, 2).reshape(-1, self.embed_dim[2], self.hw[1][0],
  521. self.hw[1][1])
  522. if self.use_lenhead:
  523. len_x = self.len_conv(x.mean(1))
  524. len_x = self.dropout_len(self.hardswish_len(len_x))
  525. if self.last_stage:
  526. x = self.avg_pool(
  527. x.transpose(1, 2).reshape(-1, self.embed_dim[2], self.hw[1][0],
  528. self.hw[1][1]))
  529. x = self.last_conv(x)
  530. x = self.hardswish(x)
  531. x = self.dropout(x)
  532. x = x.flatten(2).transpose(1, 2)
  533. if self.use_lenhead:
  534. return x, len_x
  535. return x