  1. """
  2. This code is refer from:
  3. https://github.com/THU-MIG/RepViT
  4. """
  5. import torch.nn as nn
  6. import torch
  7. from torch.nn.init import constant_


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v
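
# Note: make_divisible differs from _make_divisible only in exposing the 10%
# rounding guard as `round_limit`; SEModule calls it with round_limit=0.0 to
# disable that check when computing the reduction channels.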


class SEModule(nn.Module):
    """SE Module as defined in the original SE-Nets, with a few additions.

    Additions include:
        * divisor can be specified to keep channels % div == 0 (default: 8)
        * reduction channels can be specified directly by arg (if rd_channels is set)
        * reduction channels can be specified by float rd_ratio (default: 1/16)
        * customizable activation layer (act_layer)
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=8,
        act_layer=nn.ReLU,
    ):
        super(SEModule, self).__init__()
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio,
                                         rd_divisor,
                                         round_limit=0.0)
        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * torch.sigmoid(x_se)
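
# Note: SEModule implements squeeze-and-excitation: global average pooling, a 1x1
# reduction conv, activation, a 1x1 expansion conv, and a sigmoid gate that
# rescales the input channels.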


class Conv2D_BN(nn.Sequential):

    def __init__(
        self,
        a,
        b,
        ks=1,
        stride=1,
        pad=0,
        dilation=1,
        groups=1,
        bn_weight_init=1,
        resolution=-10000,
    ):
        super().__init__()
        self.add_module(
            'c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups,
                           bias=False))
        self.add_module('bn', nn.BatchNorm2d(b))
        constant_(self.bn.weight, bn_weight_init)
        constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = nn.Conv2d(w.size(1) * self.c.groups,
                      w.size(0),
                      w.shape[2:],
                      stride=self.c.stride,
                      padding=self.c.padding,
                      dilation=self.c.dilation,
                      groups=self.c.groups,
                      device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
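
# Note: Conv2D_BN.fuse() folds the BatchNorm statistics into the convolution,
# i.e. w' = w * gamma / sqrt(running_var + eps) and
# b' = beta - running_mean * gamma / sqrt(running_var + eps), so the returned
# nn.Conv2d matches conv + BN in eval mode with a single layer.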


class Residual(torch.nn.Module):

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(
                x.size(0), 1, 1, 1, device=x.device).ge_(
                    self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

    @torch.no_grad()
    def fuse(self):
        if isinstance(self.m, Conv2D_BN):
            m = self.m.fuse()
            assert (m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, nn.Conv2d):
            m = self.m
            assert (m.groups != m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self
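
# Note: during training, Residual.forward applies stochastic depth by masking the
# branch output with a per-sample Bernoulli mask; Residual.fuse() absorbs the skip
# connection into the convolution by adding a zero-padded identity kernel.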


class RepVGGDW(nn.Module):

    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = nn.BatchNorm2d(ed)

    def forward(self, x):
        return self.bn((self.conv(x) + self.conv1(x)) + x)

    @torch.no_grad()
    def fuse(self):
        conv = self.conv.fuse()
        conv1 = self.conv1
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
        identity = nn.functional.pad(
            torch.ones(conv1_w.shape[0],
                       conv1_w.shape[1],
                       1,
                       1,
                       device=conv1_w.device), [1, 1, 1, 1])
        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b
        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv
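
# Note: RepVGGDW.fuse() re-parameterizes the training-time branches (3x3 depthwise
# conv + BN, 1x1 depthwise conv, identity) and the trailing BatchNorm into a single
# 3x3 depthwise convolution for inference.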


class RepViTBlock(nn.Module):

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se,
                 use_hs):
        super(RepViTBlock, self).__init__()
        self.identity = stride == 1 and inp == oup
        assert hidden_dim == 2 * inp
        if stride != 1:
            self.token_mixer = nn.Sequential(
                Conv2D_BN(inp,
                          inp,
                          kernel_size,
                          stride, (kernel_size - 1) // 2,
                          groups=inp),
                SEModule(inp, 0.25) if use_se else nn.Identity(),
                Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
            )
            self.channel_mixer = Residual(
                nn.Sequential(
                    # pw
                    Conv2D_BN(oup, 2 * oup, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
                ))
        else:
            assert self.identity
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SEModule(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(
                nn.Sequential(
                    # pw
                    Conv2D_BN(inp, hidden_dim, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
                ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))
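
# Note: each RepViTBlock splits spatial mixing (depthwise conv or RepVGGDW plus an
# optional SE module) from channel mixing (1x1 expand / GELU / 1x1 project with a
# residual). Both branches of the `use_hs` conditional resolve to nn.GELU, matching
# the upstream RepViT code.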


class RepViT(nn.Module):

    def __init__(self, cfgs, in_channels=3, out_indices=None):
        super(RepViT, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs
        # building first layer
        input_channel = self.cfgs[0][2]
        patch_embed = nn.Sequential(
            Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
            nn.GELU(),
            Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
        )
        layers = [patch_embed]
        # building inverted residual blocks
        block = RepViTBlock
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(
                block(input_channel, exp_size, output_channel, k, s, use_se,
                      use_hs))
            input_channel = output_channel
        self.features = nn.ModuleList(layers)
        self.out_indices = out_indices
        if out_indices is not None:
            self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
        else:
            self.out_channels = self.cfgs[-1][2]

    def forward(self, x):
        if self.out_indices is not None:
            return self.forward_det(x)
        return self.forward_rec(x)

    def forward_det(self, x):
        outs = []
        for i, f in enumerate(self.features):
            x = f(x)
            if i in self.out_indices:
                outs.append(x)
        return outs

    def forward_rec(self, x):
        for f in self.features:
            x = f(x)
        h = x.shape[2]
        x = nn.functional.avg_pool2d(x, [h, 2])
        return x
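
# Note: each cfg entry is (kernel_size k, expansion t, channels c, use_se, use_hs,
# stride s). With out_indices set, forward_det returns multi-scale feature maps for
# detection; otherwise forward_rec averages the height down to 1 and halves the
# width for recognition.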


def RepSVTR(in_channels=3):
    """
    Constructs the RepSVTR recognition backbone (a RepViT variant).
    """
    # k, t, c, SE, HS, s
    cfgs = [
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 192, 0, 1, (2, 1)],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 384, 0, 1, (2, 1)],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1],
    ]
    return RepViT(cfgs, in_channels=in_channels)


def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
    """
    Constructs the RepSVTR detection backbone, returning multi-scale features
    at the given out_indices.
    """
    # k, t, c, SE, HS, s
    cfgs = [
        [3, 2, 48, 1, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 96, 0, 0, 2],
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 192, 0, 1, 2],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 384, 0, 1, 2],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1],
    ]
    return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)
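

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the upstream file); the input sizes
    # below are illustrative, assuming typical recognition/detection resolutions.
    rec = RepSVTR(in_channels=3)
    rec.eval()
    with torch.no_grad():
        y = rec(torch.randn(1, 3, 32, 128))  # a dummy text-line image
    print('rec output:', tuple(y.shape))  # height pooled to 1, width halved

    det = RepSVTR_det(in_channels=3)
    det.eval()
    with torch.no_grad():
        feats = det(torch.randn(1, 3, 320, 320))
    print('det outputs:', [tuple(f.shape) for f in feats])  # multi-scale features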