# rec_hgnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F
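
# PyTorch implementation of the PP-HGNet backbone. The recognition variant
# pools the final feature map down for OCR text recognition, while the `det`
# variant returns multi-scale feature maps for text detection.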


class ConvBNAct(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 use_act=True):
        super().__init__()
        self.use_act = use_act
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if self.use_act:
            self.act = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.use_act:
            x = self.act(x)
        return x
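

# Channel attention in the "effective SE" style: a single 1x1 conv over the
# globally pooled features produces the per-channel weights, with no
# reduction bottleneck.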
class ESEModule(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv(x)
        x = self.sigmoid(x)
        return x * identity
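

# HG block: a stack of 3x3 ConvBNAct layers; the block input and every
# intermediate output are concatenated, squeezed back to out_channels by a
# 1x1 conv, then reweighted by ESE attention. With identity=True the block
# input is added back as a residual shortcut.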
class HG_Block(nn.Module):
    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        layer_num,
        identity=False,
    ):
        super().__init__()
        self.identity = identity

        self.layers = nn.ModuleList()
        self.layers.append(
            ConvBNAct(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=1,
            ))
        for _ in range(layer_num - 1):
            self.layers.append(
                ConvBNAct(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    kernel_size=3,
                    stride=1,
                ))

        # feature aggregation: the input plus layer_num mid-channel feature
        # maps are concatenated, then squeezed to out_channels by a 1x1 conv
        total_channels = in_channels + layer_num * mid_channels
        self.aggregation_conv = ConvBNAct(
            in_channels=total_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
        )
        self.att = ESEModule(out_channels)

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation_conv(x)
        x = self.att(x)
        if self.identity:
            x += identity
        return x
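

# A stage is an optional depthwise strided downsample followed by block_num
# HG blocks; only the first block changes the channel count.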
class HG_Stage(nn.Module):
    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        block_num,
        layer_num,
        downsample=True,
        stride=[2, 1],
    ):
        super().__init__()
        # `self.downsample` is either False or the downsampling conv; the
        # per-axis strides allow asymmetric downsampling for recognition
        self.downsample = downsample
        if downsample:
            self.downsample = ConvBNAct(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=3,
                stride=stride,
                groups=in_channels,
                use_act=False,
            )

        blocks_list = [
            HG_Block(in_channels,
                     mid_channels,
                     out_channels,
                     layer_num,
                     identity=False)
        ]
        for _ in range(block_num - 1):
            blocks_list.append(
                HG_Block(out_channels,
                         mid_channels,
                         out_channels,
                         layer_num,
                         identity=True))
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x):
        if self.downsample:
            x = self.downsample(x)
        x = self.blocks(x)
        return x


class PPHGNet(nn.Module):
    """
    PPHGNet
    Args:
        stem_channels: list. Stem channel list of PPHGNet.
        stage_config: dict. The configuration of each stage of PPHGNet,
            such as the number of channels, stride, etc.
        layer_num: int. Number of ConvBNAct layers in each HG_Block.
        in_channels: int=3. Number of input image channels.
        det: bool=False. Build the detection variant, which returns
            multi-scale feature maps, instead of the recognition variant.
        out_indices: list. Indices of the stages whose outputs are returned
            when det=True. Defaults to [0, 1, 2, 3].
    Returns:
        model: nn.Module. Specific PPHGNet model depends on args.
    """

    def __init__(
        self,
        stem_channels,
        stage_config,
        layer_num,
        in_channels=3,
        det=False,
        out_indices=None,
    ):
        super().__init__()
        self.det = det
        self.out_indices = out_indices if out_indices is not None else [
            0, 1, 2, 3
        ]

        # stem: first conv downsamples by 2, the rest keep resolution
        stem_channels.insert(0, in_channels)
        self.stem = nn.Sequential(*[
            ConvBNAct(
                in_channels=stem_channels[i],
                out_channels=stem_channels[i + 1],
                kernel_size=3,
                stride=2 if i == 0 else 1,
            ) for i in range(len(stem_channels) - 1)
        ])
        if self.det:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # stages
        self.stages = nn.ModuleList()
        self.out_channels = []
        for block_id, k in enumerate(stage_config):
            (
                in_channels,
                mid_channels,
                out_channels,
                block_num,
                downsample,
                stride,
            ) = stage_config[k]
            self.stages.append(
                HG_Stage(
                    in_channels,
                    mid_channels,
                    out_channels,
                    block_num,
                    layer_num,
                    downsample,
                    stride,
                ))
            if block_id in self.out_indices:
                self.out_channels.append(out_channels)

        if not self.det:
            self.out_channels = stage_config['stage4'][2]

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.stem(x)
        if self.det:
            x = self.pool(x)
        out = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if self.det and i in self.out_indices:
                out.append(x)
        if self.det:
            return out

        # recognition head: pool the final feature map to height 1, width 40
        if self.training:
            x = F.adaptive_avg_pool2d(x, [1, 40])
        else:
            x = F.avg_pool2d(x, [3, 2])
        return x
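

# The factory functions below follow the PaddleClas/PaddleOCR configs for
# each model size. `pretrained` and `use_ssld` are kept for signature
# compatibility but are not acted on here; weight loading is left to the
# caller.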
def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
    """
    PPHGNet_tiny
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
            If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Module. Specific `PPHGNet_tiny` model depends on args.
    """
    stage_config = {
        # in_channels, mid_channels, out_channels, block_num, downsample, stride
        'stage1': [96, 96, 224, 1, False, [2, 1]],
        'stage2': [224, 128, 448, 1, True, [1, 2]],
        'stage3': [448, 160, 512, 2, True, [2, 1]],
        'stage4': [512, 192, 768, 1, True, [2, 1]],
    }
    model = PPHGNet(stem_channels=[48, 48, 96],
                    stage_config=stage_config,
                    layer_num=5,
                    **kwargs)
    return model


def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
    """
    PPHGNet_small
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
            If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
        det: bool=False. Whether to build the detection variant.
    Returns:
        model: nn.Module. Specific `PPHGNet_small` model depends on args.
    """
    stage_config_det = {
        # in_channels, mid_channels, out_channels, block_num, downsample, stride
        'stage1': [128, 128, 256, 1, False, 2],
        'stage2': [256, 160, 512, 1, True, 2],
        'stage3': [512, 192, 768, 2, True, 2],
        'stage4': [768, 224, 1024, 1, True, 2],
    }
    stage_config_rec = {
        # in_channels, mid_channels, out_channels, block_num, downsample, stride
        'stage1': [128, 128, 256, 1, True, [2, 1]],
        'stage2': [256, 160, 512, 1, True, [1, 2]],
        'stage3': [512, 192, 768, 2, True, [2, 1]],
        'stage4': [768, 224, 1024, 1, True, [2, 1]],
    }
    model = PPHGNet(stem_channels=[64, 64, 128],
                    stage_config=stage_config_det if det else stage_config_rec,
                    layer_num=6,
                    det=det,
                    **kwargs)
    return model


def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs):
    """
    PPHGNet_base
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
            If str, means the path of the pretrained model.
        use_ssld: bool=True. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Module. Specific `PPHGNet_base` model depends on args.
    """
    stage_config = {
        # in_channels, mid_channels, out_channels, block_num, downsample, stride
        'stage1': [160, 192, 320, 1, False, [2, 1]],
        'stage2': [320, 224, 640, 2, True, [1, 2]],
        'stage3': [640, 256, 960, 3, True, [2, 1]],
        'stage4': [960, 288, 1280, 2, True, [2, 1]],
    }
    model = PPHGNet(stem_channels=[96, 96, 160],
                    stage_config=stage_config,
                    layer_num=7,
                    **kwargs)
    return model
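

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module). The
    # expected shapes assume the standard input sizes noted below; other
    # input sizes change the pooled output.
    model = PPHGNet_tiny()
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 32, 320))  # typical 32x320 rec input
    print(feats.shape)  # torch.Size([1, 768, 1, 40])

    det_model = PPHGNet_small(det=True)
    det_model.eval()
    with torch.no_grad():
        maps = det_model(torch.randn(1, 3, 640, 640))
    print([m.shape[1] for m in maps])  # [256, 512, 768, 1024]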