rec_resnet_31.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import torch.nn as nn
# Public API of this module: only the ResNet31 backbone is exported by
# ``from <module> import *``; conv3x3 and BasicBlock are internal helpers.
__all__ = ['ResNet31']
  3. def conv3x3(in_channel, out_channel, stride=1):
  4. return nn.Conv2d(in_channel,
  5. out_channel,
  6. kernel_size=3,
  7. stride=stride,
  8. padding=1,
  9. bias=False)
  10. class BasicBlock(nn.Module):
  11. expansion = 1
  12. def __init__(self, in_channels, channels, stride=1, downsample=False):
  13. super().__init__()
  14. self.conv1 = conv3x3(in_channels, channels, stride)
  15. self.bn1 = nn.BatchNorm2d(channels)
  16. self.relu = nn.ReLU()
  17. self.conv2 = conv3x3(channels, channels)
  18. self.bn2 = nn.BatchNorm2d(channels)
  19. self.downsample = downsample
  20. if downsample:
  21. self.downsample = nn.Sequential(
  22. nn.Conv2d(in_channels,
  23. channels * self.expansion,
  24. 1,
  25. stride,
  26. bias=False),
  27. nn.BatchNorm2d(channels * self.expansion),
  28. )
  29. else:
  30. self.downsample = nn.Sequential()
  31. self.stride = stride
  32. def forward(self, x):
  33. residual = x
  34. out = self.conv1(x)
  35. out = self.bn1(out)
  36. out = self.relu(out)
  37. out = self.conv2(out)
  38. out = self.bn2(out)
  39. if self.downsample:
  40. residual = self.downsample(x)
  41. out += residual
  42. out = self.relu(out)
  43. return out
  44. class ResNet31(nn.Module):
  45. """
  46. Args:
  47. in_channels (int): Number of channels of input image tensor.
  48. layers (list[int]): List of BasicBlock number for each stage.
  49. channels (list[int]): List of out_channels of Conv2d layer.
  50. out_indices (None | Sequence[int]): Indices of output stages.
  51. last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
  52. """
  53. def __init__(
  54. self,
  55. in_channels=3,
  56. layers=[1, 2, 5, 3],
  57. channels=[64, 128, 256, 256, 512, 512, 512],
  58. out_indices=None,
  59. last_stage_pool=False,
  60. ):
  61. super(ResNet31, self).__init__()
  62. assert isinstance(in_channels, int)
  63. assert isinstance(last_stage_pool, bool)
  64. self.out_indices = out_indices
  65. self.last_stage_pool = last_stage_pool
  66. # conv 1 (Conv Conv)
  67. self.conv1_1 = nn.Conv2d(in_channels,
  68. channels[0],
  69. kernel_size=3,
  70. stride=1,
  71. padding=1)
  72. self.bn1_1 = nn.BatchNorm2d(channels[0])
  73. self.relu1_1 = nn.ReLU(inplace=True)
  74. self.conv1_2 = nn.Conv2d(channels[0],
  75. channels[1],
  76. kernel_size=3,
  77. stride=1,
  78. padding=1)
  79. self.bn1_2 = nn.BatchNorm2d(channels[1])
  80. self.relu1_2 = nn.ReLU(inplace=True)
  81. # conv 2 (Max-pooling, Residual block, Conv)
  82. self.pool2 = nn.MaxPool2d(kernel_size=2,
  83. stride=2,
  84. padding=0,
  85. ceil_mode=True)
  86. self.block2 = self._make_layer(channels[1], channels[2], layers[0])
  87. self.conv2 = nn.Conv2d(channels[2],
  88. channels[2],
  89. kernel_size=3,
  90. stride=1,
  91. padding=1)
  92. self.bn2 = nn.BatchNorm2d(channels[2])
  93. self.relu2 = nn.ReLU(inplace=True)
  94. # conv 3 (Max-pooling, Residual block, Conv)
  95. self.pool3 = nn.MaxPool2d(kernel_size=2,
  96. stride=2,
  97. padding=0,
  98. ceil_mode=True)
  99. self.block3 = self._make_layer(channels[2], channels[3], layers[1])
  100. self.conv3 = nn.Conv2d(channels[3],
  101. channels[3],
  102. kernel_size=3,
  103. stride=1,
  104. padding=1)
  105. self.bn3 = nn.BatchNorm2d(channels[3])
  106. self.relu3 = nn.ReLU(inplace=True)
  107. # conv 4 (Max-pooling, Residual block, Conv)
  108. self.pool4 = nn.MaxPool2d(kernel_size=(2, 1),
  109. stride=(2, 1),
  110. padding=0,
  111. ceil_mode=True)
  112. self.block4 = self._make_layer(channels[3], channels[4], layers[2])
  113. self.conv4 = nn.Conv2d(channels[4],
  114. channels[4],
  115. kernel_size=3,
  116. stride=1,
  117. padding=1)
  118. self.bn4 = nn.BatchNorm2d(channels[4])
  119. self.relu4 = nn.ReLU(inplace=True)
  120. # conv 5 ((Max-pooling), Residual block, Conv)
  121. self.pool5 = None
  122. if self.last_stage_pool:
  123. self.pool5 = nn.MaxPool2d(kernel_size=2,
  124. stride=2,
  125. padding=0,
  126. ceil_mode=True)
  127. self.block5 = self._make_layer(channels[4], channels[5], layers[3])
  128. self.conv5 = nn.Conv2d(channels[5],
  129. channels[5],
  130. kernel_size=3,
  131. stride=1,
  132. padding=1)
  133. self.bn5 = nn.BatchNorm2d(channels[5])
  134. self.relu5 = nn.ReLU(inplace=True)
  135. self.out_channels = channels[-1]
  136. def _make_layer(self, input_channels, output_channels, blocks):
  137. layers = []
  138. for _ in range(blocks):
  139. downsample = None
  140. if input_channels != output_channels:
  141. downsample = nn.Sequential(
  142. nn.Conv2d(
  143. input_channels,
  144. output_channels,
  145. kernel_size=1,
  146. stride=1,
  147. bias=False,
  148. ),
  149. nn.BatchNorm2d(output_channels),
  150. )
  151. layers.append(
  152. BasicBlock(input_channels,
  153. output_channels,
  154. downsample=downsample))
  155. input_channels = output_channels
  156. return nn.Sequential(*layers)
  157. def forward(self, x):
  158. x = self.conv1_1(x)
  159. x = self.bn1_1(x)
  160. x = self.relu1_1(x)
  161. x = self.conv1_2(x)
  162. x = self.bn1_2(x)
  163. x = self.relu1_2(x)
  164. outs = []
  165. for i in range(4):
  166. layer_index = i + 2
  167. pool_layer = getattr(self, 'pool{}'.format(layer_index))
  168. block_layer = getattr(self, 'block{}'.format(layer_index))
  169. conv_layer = getattr(self, 'conv{}'.format(layer_index))
  170. bn_layer = getattr(self, 'bn{}'.format(layer_index))
  171. relu_layer = getattr(self, 'relu{}'.format(layer_index))
  172. if pool_layer is not None:
  173. x = pool_layer(x)
  174. x = block_layer(x)
  175. x = conv_layer(x)
  176. x = bn_layer(x)
  177. x = relu_layer(x)
  178. outs.append(x)
  179. if self.out_indices is not None:
  180. return tuple([outs[i] for i in self.out_indices])
  181. return x