svtrv2.py

import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp

class ConvBNLayer(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        bias=False,
        groups=1,
        act=nn.GELU,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = act()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out

class ConvMixer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        local_k=[5, 5],
    ):
        super().__init__()
        # Note: the mixing kernel is hardcoded to 5x5 with padding 2;
        # `local_k` is currently unused here.
        self.local_mixer = nn.Conv2d(dim, dim, 5, 1, 2, groups=num_heads)

    def forward(self, x, mask=None):
        x = self.local_mixer(x)
        return x

class ConvMlp(nn.Module):

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
        groups=1,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, groups=groups)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = q @ k.transpose(-2, -1) * self.scale
        if mask is not None:
            attn += mask.unsqueeze(0)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, self.dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mixer='Global',
        local_k=[7, 11],
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        eps=1e-6,
    ):
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)
        if mixer == 'Global' or mixer == 'Local':
            self.norm1 = norm_layer(dim, eps=eps)
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
            self.norm2 = norm_layer(dim, eps=eps)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )
        elif mixer == 'Conv':
            self.norm1 = nn.BatchNorm2d(dim)
            self.mixer = ConvMixer(dim, num_heads=num_heads, local_k=local_k)
            self.norm2 = nn.BatchNorm2d(dim)
            self.mlp = ConvMlp(in_features=dim,
                               hidden_features=mlp_hidden_dim,
                               act_layer=act_layer,
                               drop=drop)
        else:
            raise TypeError('The mixer must be one of [Global, Local, Conv]')
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()

    def forward(self, x, mask=None):
        x = self.norm1(x + self.drop_path(self.mixer(x, mask=mask)))
        x = self.norm2(x + self.drop_path(self.mlp(x)))
        return x

class FlattenTranspose(nn.Module):
    """Flattens NCHW feature maps into a (B, H*W, C) token sequence."""

    def forward(self, x, mask=None):
        return x.flatten(2).transpose(1, 2)

class SVTRStage(nn.Module):

    def __init__(self,
                 feat_maxSize=[16, 128],
                 dim=64,
                 out_dim=256,
                 depth=3,
                 mixer=['Local'] * 3,
                 local_k=[7, 11],
                 sub_k=[2, 1],
                 num_heads=2,
                 mlp_ratio=4,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path=[0.1] * 3,
                 norm_layer=nn.LayerNorm,
                 act=nn.GELU,
                 eps=1e-6,
                 downsample=None,
                 **kwargs):
        super().__init__()
        self.dim = dim
        conv_block_num = sum([1 if mix == 'Conv' else 0 for mix in mixer])
        if conv_block_num == depth:
            self.mask = None
            conv_block_num = 0
            if downsample:
                self.sub_norm = nn.BatchNorm2d(out_dim, eps=eps)
        else:
            if 'Local' in mixer:
                mask = self.get_max2d_mask(feat_maxSize[0], feat_maxSize[1],
                                           local_k)
                self.register_buffer('mask', mask)
            else:
                self.mask = None
            if downsample:
                self.sub_norm = norm_layer(out_dim, eps=eps)

        self.blocks = nn.ModuleList()
        for i in range(depth):
            self.blocks.append(
                Block(
                    dim=dim,
                    num_heads=num_heads,
                    mixer=mixer[i],
                    local_k=local_k,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=act,
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer,
                    eps=eps,
                ))
            if i == conv_block_num - 1:
                # Switch from NCHW conv features to a token sequence after
                # the last Conv-mixer block of this stage.
                self.blocks.append(FlattenTranspose())

        if downsample:
            self.downsample = nn.Conv2d(dim,
                                        out_dim,
                                        kernel_size=3,
                                        stride=sub_k,
                                        padding=1)
        else:
            self.downsample = None

    def get_max2d_mask(self, H, W, local_k):
        # Build an additive attention mask for the maximum feature size:
        # positions outside an hk x wk neighbourhood of each query are -inf.
        hk, wk = local_k
        mask = torch.ones(H * W,
                          H + hk - 1,
                          W + wk - 1,
                          dtype=torch.float32,
                          requires_grad=False)
        for h in range(0, H):
            for w in range(0, W):
                mask[h * W + w, h:h + hk, w:w + wk] = 0.0
        mask = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2]  # .flatten(1)
        mask[mask >= 1] = -np.inf
        return mask.reshape(H, W, H, W)

    def get_2d_mask(self, H1, W1):
        # Crop the precomputed maximum-size mask down to the actual feature
        # size (H1, W1).
        if H1 == self.mask.shape[0] and W1 == self.mask.shape[1]:
            return self.mask.flatten(0, 1).flatten(1, 2).unsqueeze(0)
        h_slice = H1 // 2
        offset_h = H1 - 2 * h_slice
        w_slice = W1 // 2
        offset_w = W1 - 2 * w_slice
        mask1 = self.mask[:h_slice + offset_h, :w_slice, :H1, :W1]
        mask2 = self.mask[:h_slice + offset_h, -w_slice:, :H1, -W1:]
        mask3 = self.mask[-h_slice:, :(w_slice + offset_w), -H1:, :W1]
        mask4 = self.mask[-h_slice:, -(w_slice + offset_w):, -H1:, -W1:]
        mask_top = torch.concat([mask1, mask2], 1)
        mask_bott = torch.concat([mask3, mask4], 1)
        mask = torch.concat([mask_top.flatten(2), mask_bott.flatten(2)], 0)
        return mask.flatten(0, 1).unsqueeze(0)

    def forward(self, x, sz=None):
        if self.mask is not None:
            mask = self.get_2d_mask(sz[0], sz[1])
        else:
            mask = self.mask
        for blk in self.blocks:
            x = blk(x, mask=mask)
        if self.downsample is not None:
            if x.dim() == 3:
                x = x.transpose(1, 2).reshape(-1, self.dim, sz[0], sz[1])
                x = self.downsample(x)
                sz = x.shape[2:]
                x = x.flatten(2).transpose(1, 2)
            else:
                x = self.downsample(x)
                sz = x.shape[2:]
            x = self.sub_norm(x)
        return x, sz

class POPatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self,
                 in_channels=3,
                 feat_max_size=[8, 32],
                 embed_dim=768,
                 use_pos_embed=False,
                 flatten=False):
        super().__init__()
        self.patch_embed = nn.Sequential(
            ConvBNLayer(
                in_channels=in_channels,
                out_channels=embed_dim // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=False,
            ),
            ConvBNLayer(
                in_channels=embed_dim // 2,
                out_channels=embed_dim,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=False,
            ),
        )
        self.use_pos_embed = use_pos_embed
        self.flatten = flatten
        if use_pos_embed:
            pos_embed = torch.zeros(
                [1, feat_max_size[0] * feat_max_size[1], embed_dim],
                dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.pos_embed = nn.Parameter(
                pos_embed.transpose(1, 2).reshape(1, embed_dim,
                                                  feat_max_size[0],
                                                  feat_max_size[1]),
                requires_grad=True,
            )

    def forward(self, x):
        x = self.patch_embed(x)
        sz = x.shape[2:]
        if self.use_pos_embed:
            x = x + self.pos_embed[:, :, :sz[0], :sz[1]]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return x, sz

class SVTRv2(nn.Module):

    def __init__(self,
                 max_sz=[32, 128],
                 in_channels=3,
                 out_channels=192,
                 depths=[3, 6, 3],
                 dims=[64, 128, 256],
                 mixer=[['Local'] * 3, ['Local'] * 3 + ['Global'] * 3,
                        ['Global'] * 3],
                 use_pos_embed=True,
                 local_k=[[7, 11], [7, 11], [-1, -1]],
                 sub_k=[[1, 1], [2, 1], [1, 1]],
                 num_heads=[2, 4, 8],
                 mlp_ratio=4,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.0,
                 last_drop=0.1,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 act=nn.GELU,
                 last_stage=False,
                 eps=1e-6,
                 **kwargs):
        super().__init__()
        num_stages = len(depths)
        self.num_features = dims[-1]
        feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
        self.pope = POPatchEmbed(in_channels=in_channels,
                                 feat_max_size=feat_max_size,
                                 embed_dim=dims[0],
                                 use_pos_embed=use_pos_embed,
                                 flatten=mixer[0][0] != 'Conv')
        dpr = np.linspace(0, drop_path_rate,
                          sum(depths))  # stochastic depth decay rule

        self.stages = nn.ModuleList()
        for i_stage in range(num_stages):
            stage = SVTRStage(
                feat_maxSize=feat_max_size,
                dim=dims[i_stage],
                out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
                depth=depths[i_stage],
                mixer=mixer[i_stage],
                local_k=local_k[i_stage],
                sub_k=sub_k[i_stage],
                num_heads=num_heads[i_stage],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                norm_layer=norm_layer,
                act=act,
                downsample=False if i_stage == num_stages - 1 else True,
                eps=eps,
            )
            self.stages.append(stage)
            feat_max_size = [
                feat_max_size[0] // sub_k[i_stage][0],
                feat_max_size[1] // sub_k[i_stage][1]
            ]

        self.out_channels = self.num_features
        self.last_stage = last_stage
        if last_stage:
            self.out_channels = out_channels
            self.last_conv = nn.Linear(self.num_features,
                                       self.out_channels,
                                       bias=False)
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        if isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        if isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'patch_embed', 'downsample', 'pos_embed'}

    def forward(self, x):
        x, sz = self.pope(x)
        for stage in self.stages:
            x, sz = stage(x, sz)
        if self.last_stage:
            x = x.reshape(-1, sz[0], sz[1], self.num_features)
            x = x.mean(1)
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        return x
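

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module: it builds
    # the default SVTRv2 backbone and pushes a dummy 32x128 RGB image through
    # it. It assumes openrec.modeling.common (DropPath, Identity, Mlp) is
    # importable in this environment.
    model = SVTRv2(max_sz=[32, 128], in_channels=3, last_stage=True)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 32, 128)  # (batch, channels, height, width)
        feats = model(dummy)
    # With last_stage=True the backbone averages over the height axis and
    # returns a (batch, width_tokens, out_channels) sequence, (1, 32, 192)
    # here, suitable for a CTC-style recognition head.
    print(feats.shape)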