# --------------------------------------------------------
# FocalNets -- Focal Modulation Networks
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.nn.init import trunc_normal_

from openrec.modeling.common import DropPath, Mlp
from openrec.modeling.encoders.svtrnet import ConvBNLayer


class FocalModulation(nn.Module):

    def __init__(self,
                 dim,
                 focal_window,
                 focal_level,
                 max_kh=None,
                 focal_factor=2,
                 bias=True,
                 proj_drop=0.0,
                 use_postln_in_modulation=False,
                 normalize_modulator=False):
        super().__init__()

        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator

        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            if max_kh is not None:
                k_h, k_w = [min(kernel_size, max_kh), kernel_size]
                kernel_size = [k_h, k_w]
                padding = [k_h // 2, k_w // 2]
            else:
                padding = kernel_size // 2
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim,
                              dim,
                              kernel_size=kernel_size,
                              stride=1,
                              groups=dim,
                              padding=padding,
                              bias=False),
                    nn.GELU(),
                ))
            self.kernel_sizes.append(kernel_size)
        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        C = x.shape[-1]

        # pre linear projection
        x = self.f(x).permute(0, 3, 1, 2).contiguous()
        q, ctx, self.gates = torch.split(x, (C, C, self.focal_level + 1), 1)

        # context aggregation
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx * self.gates[:, l:l + 1]
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level:]

        # normalize context
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        # focal modulation
        self.modulator = self.h(ctx_all)
        x_out = q * self.modulator
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)

        # post linear projection
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out

    def extra_repr(self) -> str:
        return f'dim={self.dim}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        flops += N * self.dim * (self.dim * 2 + (self.focal_level + 1))
        # focal convolution
        for k in range(self.focal_level):
            flops += N * (self.kernel_sizes[k]**2 + 1) * self.dim
        # global gating
        flops += N * 1 * self.dim
        # self.linear
        flops += N * self.dim * (self.dim + 1)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
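

# Usage sketch (added for illustration, not part of the original module):
# FocalModulation consumes channel-last features of shape (B, H, W, C) and
# returns a tensor of the same shape. The sizes below are arbitrary examples.
#
#   mod = FocalModulation(dim=96, focal_window=3, focal_level=2)
#   feats = torch.randn(2, 8, 32, 96)   # (B, H, W, C)
#   out = mod(feats)                    # -> (2, 8, 32, 96)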


class FocalNetBlock(nn.Module):
    r"""Focal Modulation Network Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at first focal level.
        use_layerscale (bool): Whether to use layerscale.
        layerscale_value (float): Initial layerscale value.
        use_postln (bool): Whether to use layernorm after modulation.
    """

    def __init__(
        self,
        dim,
        input_resolution=None,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        focal_level=1,
        focal_window=3,
        max_kh=None,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio

        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim,
            proj_drop=drop,
            focal_window=focal_window,
            focal_level=self.focal_level,
            max_kh=max_kh,
            use_postln_in_modulation=use_postln_in_modulation,
            normalize_modulator=normalize_modulator,
        )

        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)),
                                        requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)),
                                        requires_grad=True)

        self.H = None
        self.W = None

    def forward(self, x):
        H, W = self.H, self.W
        B, L, C = x.shape
        shortcut = x

        # Focal Modulation
        x = x if self.use_postln else self.norm1(x)
        x = x.view(B, H, W, C)
        x = self.modulation(x).view(B, H * W, C)
        x = x if not self.use_postln else self.norm1(x)

        # FFN
        x = shortcut + self.drop_path(self.gamma_1 * x)
        x = x + self.drop_path(self.gamma_2 * (self.norm2(
            self.mlp(x)) if self.use_postln else self.mlp(self.norm2(x))))

        return x

    def extra_repr(self) -> str:
        return (f'dim={self.dim}, input_resolution={self.input_resolution}, '
                f'mlp_ratio={self.mlp_ratio}')

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # focal modulation
        flops += self.modulation.flops(H * W)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
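

# Usage sketch (added for illustration): FocalNetBlock operates on flattened
# tokens of shape (B, H*W, C); the caller is expected to set blk.H and blk.W
# before the forward pass, as BasicLayer does below.
#
#   blk = FocalNetBlock(dim=96, focal_level=2, focal_window=3)
#   blk.H, blk.W = 8, 32
#   tokens = torch.randn(2, 8 * 32, 96)  # (B, H*W, C)
#   out = blk(tokens)                    # -> (2, 256, 96)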


class BasicLayer(nn.Module):
    """A basic Focal Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        out_dim (int | None): Embedding dimension produced by the downsample layer, if any.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        downsample_kernel (list): Patch size passed to the downsample layer. Default: []
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at first focal level.
        use_layerscale (bool): Whether to use layerscale.
        layerscale_value (float): Initial layerscale value.
        use_postln (bool): Whether to use layernorm after modulation.
    """

    def __init__(
        self,
        dim,
        out_dim,
        input_resolution,
        depth,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        downsample_kernel=[],
        use_checkpoint=False,
        focal_level=1,
        focal_window=1,
        use_conv_embed=False,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            FocalNetBlock(
                dim=dim,
                input_resolution=input_resolution,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                focal_level=focal_level,
                focal_window=focal_window,
                use_layerscale=use_layerscale,
                layerscale_value=layerscale_value,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
            ) for i in range(depth)
        ])

        if downsample is not None:
            self.downsample = downsample(
                img_size=input_resolution,
                patch_size=downsample_kernel,
                in_chans=dim,
                embed_dim=out_dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False,
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
            x, Ho, Wo = self.downsample(x)
        else:
            Ho, Wo = H, W
        return x, Ho, Wo

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
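

# Usage sketch (added for illustration): a BasicLayer stage takes tokens plus
# the current (H, W) grid and returns the (possibly downsampled) tokens along
# with the new grid size. The hyper-parameters below are arbitrary examples.
#
#   stage = BasicLayer(dim=96, out_dim=192, input_resolution=[8, 32], depth=2,
#                      downsample=PatchEmbed, downsample_kernel=[2, 1],
#                      focal_level=2, focal_window=3)
#   tokens = torch.randn(2, 8 * 32, 96)
#   tokens, H, W = stage(tokens, 8, 32)  # -> (2, 4 * 32, 192), H=4, W=32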


class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding.

    Args:
        img_size (tuple[int]): Image size. Default: (224, 224).
        patch_size (list[int]): Patch token size. Default: [4, 4].
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 img_size=(224, 224),
                 patch_size=[4, 4],
                 in_chans=3,
                 embed_dim=96,
                 use_conv_embed=False,
                 norm_layer=None,
                 is_stem=False):
        super().__init__()
        # patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7
                padding = 2
                stride = 4
            else:
                kernel_size = 3
                padding = 1
                stride = 2
            self.proj = nn.Conv2d(in_chans,
                                  embed_dim,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding)
        else:
            self.proj = nn.Conv2d(in_chans,
                                  embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        H, W = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x, H, W

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (
            self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
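

# Usage sketch (added for illustration): with use_conv_embed=False the
# projection is a strided conv whose kernel equals patch_size, so a
# (B, C, H, W) map is reduced by patch_size along each spatial axis.
#
#   embed = PatchEmbed(img_size=(8, 32), patch_size=[2, 1], in_chans=96,
#                      embed_dim=192, norm_layer=nn.LayerNorm)
#   x, H, W = embed(torch.randn(2, 96, 8, 32))  # x: (2, 4 * 32, 192), H=4, W=32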


class FocalSVTR(nn.Module):
    r"""Focal Modulation Networks (FocalNets).

    Args:
        img_size (int | tuple(int)): Input image size. Default: [32, 128]
        patch_size (int | tuple(int)): Patch size. Default: [4, 4]
        in_channels (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Focal Transformer layer.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [6, 6, 6]
        focal_windows (list): The focal window size at all stages. Default: [3, 3, 3]
        use_conv_embed (bool): Whether to use convolutional embedding. We noted that using convolutional embedding usually improves the performance, but we do not use it by default. Default: False
        use_layerscale (bool): Whether to use layerscale proposed in CaiT. Default: False
        layerscale_value (float): Value for layer scale. Default: 1e-4
        use_postln (bool): Whether to use layernorm after modulation (it helps stabilize training of large models).
    """

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 4],
        out_channels=256,
        out_char_num=25,
        in_channels=3,
        embed_dim=96,
        depths=[3, 6, 3],
        sub_k=[[2, 1], [2, 1], [1, 1]],
        last_stage=False,
        mlp_ratio=4.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        patch_norm=True,
        use_checkpoint=False,
        focal_levels=[6, 6, 6],
        focal_windows=[3, 3, 3],
        use_conv_embed=False,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
        feat2d=False,
        **kwargs,
    ):
        super().__init__()

        self.num_layers = len(depths)
        embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]

        self.feat2d = feat2d
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = embed_dim[-1]
        self.mlp_ratio = mlp_ratio

        self.patch_embed = nn.Sequential(
            ConvBNLayer(
                in_channels=in_channels,
                out_channels=embed_dim[0] // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=None,
            ),
            ConvBNLayer(
                in_channels=embed_dim[0] // 2,
                out_channels=embed_dim[0],
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=None,
            ),
        )
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.patches_resolution = patches_resolution
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=embed_dim[i_layer],
                out_dim=embed_dim[i_layer + 1] if
                (i_layer < self.num_layers - 1) else None,
                input_resolution=patches_resolution,
                depth=depths[i_layer],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchEmbed if
                (i_layer < self.num_layers - 1) else None,
                downsample_kernel=sub_k[i_layer],
                focal_level=focal_levels[i_layer],
                focal_window=focal_windows[i_layer],
                use_conv_embed=use_conv_embed,
                use_checkpoint=use_checkpoint,
                use_layerscale=use_layerscale,
                layerscale_value=layerscale_value,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
            )
            patches_resolution = [
                patches_resolution[0] // sub_k[i_layer][0],
                patches_resolution[1] // sub_k[i_layer][1]
            ]
            self.layers.append(layer)

        self.out_channels = self.num_features
        self.last_stage = last_stage
        if last_stage:
            self.out_channels = out_channels
            self.last_conv = nn.Linear(self.num_features,
                                       self.out_channels,
                                       bias=False)
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=0.1)
            # self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
            # self.last_conv = nn.Conv2d(
            #     in_channels=self.num_features,
            #     out_channels=self.out_channels,
            #     kernel_size=1,
            #     stride=1,
            #     padding=0,
            #     bias=False,
            # )
            # self.hardswish = nn.Hardswish()
            # self.dropout = nn.Dropout(p=0.1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'patch_embed', 'downsample'}

    def forward(self, x):
        if len(x.shape) == 5:
            x = x.flatten(0, 1)
        x = self.patch_embed(x)
        H, W = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        if self.feat2d:
            x = x.transpose(1, 2).reshape(-1, self.num_features, H, W)
        if self.last_stage:
            x = x.reshape(-1, H, W, self.num_features).mean(1)
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
            # x = self.avg_pool(x.transpose(1, 2).reshape(-1, self.num_features, H, W))
            # x = self.last_conv(x)
            # x = self.hardswish(x)
            # x = self.dropout(x)
            # x = x.flatten(2).transpose(1, 2)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[
            0] * self.patches_resolution[1] // (2**self.num_layers)
        return flops
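

# Usage sketch (added for illustration, shapes assume the defaults above):
# the two stride-2 ConvBNLayer stems reduce a 32x128 input to an 8x32 grid,
# stages 0 and 1 then halve the height via sub_k=[2, 1], and the final stage
# keeps the 2x32 grid at num_features=384 channels, so the encoder returns a
# (B, H*W, C) token sequence when feat2d and last_stage are left disabled.
#
#   model = FocalSVTR()
#   feats = model(torch.randn(1, 3, 32, 128))  # -> (1, 2 * 32, 384)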