# rec_mobilenet_v3.py — MobileNetV3 backbone for text recognition (PyTorch).
  1. import torch.nn as nn
  2. from .det_mobilenet_v3 import ConvBNLayer, ResidualUnit, make_divisible
  3. class MobileNetV3(nn.Module):
  4. def __init__(self,
  5. in_channels=3,
  6. model_name='small',
  7. scale=0.5,
  8. large_stride=None,
  9. small_stride=None,
  10. **kwargs):
  11. super(MobileNetV3, self).__init__()
  12. if small_stride is None:
  13. small_stride = [2, 2, 2, 2]
  14. if large_stride is None:
  15. large_stride = [1, 2, 2, 2]
  16. assert isinstance(
  17. large_stride,
  18. list), 'large_stride type must ' 'be list but got {}'.format(
  19. type(large_stride))
  20. assert isinstance(
  21. small_stride,
  22. list), 'small_stride type must ' 'be list but got {}'.format(
  23. type(small_stride))
  24. assert len(
  25. large_stride
  26. ) == 4, 'large_stride length must be ' '4 but got {}'.format(
  27. len(large_stride))
  28. assert len(
  29. small_stride
  30. ) == 4, 'small_stride length must be ' '4 but got {}'.format(
  31. len(small_stride))
  32. if model_name == 'large':
  33. cfg = [
  34. # k, exp, c, se, nl, s,
  35. [3, 16, 16, False, 'relu', large_stride[0]],
  36. [3, 64, 24, False, 'relu', (large_stride[1], 1)],
  37. [3, 72, 24, False, 'relu', 1],
  38. [5, 72, 40, True, 'relu', (large_stride[2], 1)],
  39. [5, 120, 40, True, 'relu', 1],
  40. [5, 120, 40, True, 'relu', 1],
  41. [3, 240, 80, False, 'hard_swish', 1],
  42. [3, 200, 80, False, 'hard_swish', 1],
  43. [3, 184, 80, False, 'hard_swish', 1],
  44. [3, 184, 80, False, 'hard_swish', 1],
  45. [3, 480, 112, True, 'hard_swish', 1],
  46. [3, 672, 112, True, 'hard_swish', 1],
  47. [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
  48. [5, 960, 160, True, 'hard_swish', 1],
  49. [5, 960, 160, True, 'hard_swish', 1],
  50. ]
  51. cls_ch_squeeze = 960
  52. elif model_name == 'small':
  53. cfg = [
  54. # k, exp, c, se, nl, s,
  55. [3, 16, 16, True, 'relu', (small_stride[0], 1)],
  56. [3, 72, 24, False, 'relu', (small_stride[1], 1)],
  57. [3, 88, 24, False, 'relu', 1],
  58. [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
  59. [5, 240, 40, True, 'hard_swish', 1],
  60. [5, 240, 40, True, 'hard_swish', 1],
  61. [5, 120, 48, True, 'hard_swish', 1],
  62. [5, 144, 48, True, 'hard_swish', 1],
  63. [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
  64. [5, 576, 96, True, 'hard_swish', 1],
  65. [5, 576, 96, True, 'hard_swish', 1],
  66. ]
  67. cls_ch_squeeze = 576
  68. else:
  69. raise NotImplementedError('mode[' + model_name +
  70. '_model] is not implemented!')
  71. supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
  72. assert scale in supported_scale, 'supported scales are {} but input scale is {}'.format(
  73. supported_scale, scale)
  74. inplanes = 16
  75. # conv1
  76. self.conv1 = ConvBNLayer(
  77. in_channels=in_channels,
  78. out_channels=make_divisible(inplanes * scale),
  79. kernel_size=3,
  80. stride=2,
  81. padding=1,
  82. groups=1,
  83. if_act=True,
  84. act='hard_swish',
  85. )
  86. i = 0
  87. block_list = []
  88. inplanes = make_divisible(inplanes * scale)
  89. for k, exp, c, se, nl, s in cfg:
  90. block_list.append(
  91. ResidualUnit(
  92. in_channels=inplanes,
  93. mid_channels=make_divisible(scale * exp),
  94. out_channels=make_divisible(scale * c),
  95. kernel_size=k,
  96. stride=s,
  97. use_se=se,
  98. act=nl,
  99. name='conv' + str(i + 2),
  100. ))
  101. inplanes = make_divisible(scale * c)
  102. i += 1
  103. self.blocks = nn.Sequential(*block_list)
  104. self.conv2 = ConvBNLayer(
  105. in_channels=inplanes,
  106. out_channels=make_divisible(scale * cls_ch_squeeze),
  107. kernel_size=1,
  108. stride=1,
  109. padding=0,
  110. groups=1,
  111. if_act=True,
  112. act='hard_swish',
  113. )
  114. self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  115. self.out_channels = make_divisible(scale * cls_ch_squeeze)
  116. def forward(self, x):
  117. x = self.conv1(x)
  118. x = self.blocks(x)
  119. x = self.conv2(x)
  120. x = self.pool(x)
  121. return x