# autostr_encoder.py

import ast
from collections import OrderedDict

import torch
import torch.nn as nn


class IdentityLayer(nn.Module):
    """Pass-through layer, used as the residual shortcut."""

    def __init__(self):
        super(IdentityLayer, self).__init__()

    def forward(self, x):
        return x

    @staticmethod
    def is_zero_layer():
        return False


class ZeroLayer(nn.Module):
    """Outputs zeros of the spatially downsampled shape; used to prune a path."""

    def __init__(self, stride):
        super(ZeroLayer, self).__init__()
        self.stride = stride

    def forward(self, x):
        n, c, h, w = x.shape
        h //= self.stride[0]
        w //= self.stride[1]
        return torch.zeros(n, c, h, w, device=x.device, requires_grad=False)

    @staticmethod
    def is_zero_layer():
        return True

    def get_flops(self, x):
        # A zero layer performs no multiply-adds.
        return 0, self.forward(x)


def get_same_padding(kernel_size):
    """Return the padding that preserves spatial size for a stride-1 conv,
    e.g. 3 -> 1 and (3, 5) -> (1, 2)."""
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size,
                      int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be an odd number'
    return kernel_size // 2


class MBInvertedConvLayer(nn.Module):
    """MobileNetV2-style inverted bottleneck:
    1x1 expand -> depthwise conv -> 1x1 project."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=(1, 1),
                 expand_ratio=6,
                 mid_channels=None):
        super(MBInvertedConvLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        feature_dim = round(
            self.in_channels *
            self.expand_ratio) if mid_channels is None else mid_channels

        if self.expand_ratio == 1:
            # No channel expansion needed; skip the 1x1 expand conv.
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict([
                    ('conv',
                     nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0,
                               bias=False)),
                    ('bn', nn.BatchNorm2d(feature_dim)),
                    ('act', nn.ReLU6(inplace=True)),
                ]))

        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim,
                           feature_dim,
                           kernel_size,
                           stride,
                           pad,
                           groups=feature_dim,
                           bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', nn.ReLU6(inplace=True)),
            ]))
        self.point_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))

    def forward(self, x):
        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x

    @staticmethod
    def is_zero_layer():
        return False


def conv_func_by_name(name):
    """Map an op name like '3x3_MBConv6' to a constructor taking
    (in_channels, out_channels, stride)."""
    name2ops = {
        'Identity': lambda in_C, out_C, S: IdentityLayer(),
        'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
    }
    # Register '{k}x{k}_MBConv{e}' for kernel sizes 3/5/7 and expand ratios
    # 1-6; the default arguments pin k and e inside each lambda.
    for k in (3, 5, 7):
        for e in range(1, 7):
            name2ops['%dx%d_MBConv%d' % (k, k, e)] = (
                lambda in_C, out_C, S, k=k, e=e:
                MBInvertedConvLayer(in_C, out_C, k, S, e))
    return name2ops[name]
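
# Example: conv_func_by_name('3x3_MBConv6')(32, 64, (2, 2)) builds an
# MBInvertedConvLayer with in_channels=32, out_channels=64, kernel_size=3,
# stride=(2, 2) and expand_ratio=6.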


def build_candidate_ops(candidate_ops, in_channels, out_channels, stride,
                        ops_order):
    """Instantiate one op per name in candidate_ops. ops_order is accepted
    for signature compatibility but unused here."""
    if candidate_ops is None:
        raise ValueError('please specify a candidate set')
    name2ops = {
        'Identity': lambda in_C, out_C, S: IdentityLayer(),
        'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
    }
    # add MBConv layers, same naming scheme as in conv_func_by_name
    for k in (3, 5, 7):
        for e in range(1, 7):
            name2ops['%dx%d_MBConv%d' % (k, k, e)] = (
                lambda in_C, out_C, S, k=k, e=e:
                MBInvertedConvLayer(in_C, out_C, k, S, e))
    return [
        name2ops[name](in_channels, out_channels, stride)
        for name in candidate_ops
    ]


class MobileInvertedResidualBlock(nn.Module):
    """Inverted-residual block: a conv path plus an optional identity shortcut."""

    def __init__(self, mobile_inverted_conv, shortcut):
        super(MobileInvertedResidualBlock, self).__init__()
        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x):
        if self.mobile_inverted_conv.is_zero_layer():
            # Pruned conv path: the block degenerates to its shortcut.
            res = x
        elif self.shortcut is None or self.shortcut.is_zero_layer():
            res = self.mobile_inverted_conv(x)
        else:
            conv_x = self.mobile_inverted_conv(x)
            skip_x = self.shortcut(x)
            res = skip_x + conv_x
        return res


class AutoSTREncoder(nn.Module):
    """Searched AutoSTR backbone: a stem conv, 15 searched inverted-residual
    cells in 5 stages, and an optional 2-layer bidirectional LSTM head."""

    def __init__(self,
                 in_channels,
                 out_dim=256,
                 with_lstm=True,
                 stride_stages='[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]',
                 n_cell_stages=[3, 3, 3, 3, 3],
                 conv_op_ids=[5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 4, 3, 4, 6, 6],
                 **kwargs):
        super().__init__()
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=(3, 3),
                      stride=1,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True))

        # stride_stages arrives as a string (e.g. from a config file);
        # ast.literal_eval parses the literal safely, unlike eval.
        stride_stages = ast.literal_eval(stride_stages)
        width_stages = [32, 64, 128, 256, 512]
        conv_candidates = [
            '5x5_MBConv1', '5x5_MBConv3', '5x5_MBConv6', '3x3_MBConv1',
            '3x3_MBConv3', '3x3_MBConv6', 'Zero'
        ]
        assert len(conv_op_ids) == sum(n_cell_stages)

        blocks = []
        input_channel = 32
        for width, n_cell, s in zip(width_stages, n_cell_stages,
                                    stride_stages):
            for i in range(n_cell):
                # Only the first cell of each stage downsamples.
                stride = s if i == 0 else (1, 1)
                block_i = len(blocks)
                conv_op = conv_func_by_name(
                    conv_candidates[conv_op_ids[block_i]])(input_channel,
                                                           width, stride)
                # An identity shortcut is only valid when the cell changes
                # neither resolution nor channel count.
                if stride == (1, 1) and input_channel == width:
                    shortcut = IdentityLayer()
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(conv_op, shortcut))
                input_channel = width
        self.out_channels = input_channel
        self.blocks = nn.ModuleList(blocks)

        self.with_lstm = with_lstm
        if with_lstm:
            self.rnn = nn.LSTM(input_channel,
                               out_dim // 2,
                               bidirectional=True,
                               num_layers=2,
                               batch_first=True)
            self.out_channels = out_dim

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        # The stage strides reduce height by 2**5 = 32, so an input of height
        # 32 collapses to height 1 here and can be squeezed into a sequence.
        cnn_feat = x.squeeze(dim=2)
        cnn_feat = cnn_feat.transpose(2, 1)  # (N, W, C): one feature per column
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        else:
            return cnn_feat
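

if __name__ == '__main__':
    # Minimal smoke test (an added sketch, not part of the original module).
    # The five stages use strides (2,2), (2,2), (2,1), (2,1), (2,1), so height
    # shrinks by 2**5 = 32 and width by 2**2 = 4; forward() therefore assumes
    # inputs of height exactly 32. Batch size 2 and width 100 are arbitrary.
    encoder = AutoSTREncoder(in_channels=1)
    images = torch.randn(2, 1, 32, 100)
    features = encoder(images)
    # 100 / 4 = 25 time steps, each a 256-dim (out_dim) bi-LSTM feature.
    print(features.shape)  # expected: torch.Size([2, 25, 256])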