# rec_resnet_fpn.py
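"""ResNet backbone with an FPN-style fusion head.

A stride-2 stem conv feeds four residual stages (BasicBlock for 18/34
layers, BottleneckBlock for 50/101/152). The last three stage outputs are
concatenated top-down and squeezed by 1x1/3x3 convs into a single
512-channel feature map.
"""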

import numpy as np
import torch
import torch.nn as nn

class ConvBNLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel,
                 stride=1,
                 act='ReLU'):
        super(ConvBNLayer, self).__init__()
        self.act_flag = act
        # For stride == (1, 1) a 2x2 kernel with dilation 2 is used instead:
        # with padding (kernel - 1) // 2 it keeps the spatial size while
        # still covering a 3x3 receptive field.
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=2 if stride == (1, 1) else kernel,
                              stride=stride,
                              padding=(kernel - 1) // 2,
                              dilation=2 if stride == (1, 1) else 1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act_flag != 'None':
            x = self.act(x)
        return x

class Shortcut(nn.Module):
    def __init__(self, in_channels, out_channels, stride, is_first=False):
        super(Shortcut, self).__init__()
        self.use_conv = True
        if in_channels != out_channels or stride != 1 or is_first is True:
            if stride == (1, 1):
                # Pass an int stride so ConvBNLayer keeps a plain 1x1 kernel
                # rather than switching to its dilated 2x2 branch.
                self.conv = ConvBNLayer(in_channels, out_channels, 1, 1)
            else:
                self.conv = ConvBNLayer(in_channels, out_channels, 1, stride)
        else:
            self.use_conv = False

    def forward(self, x):
        if self.use_conv:
            x = self.conv(x)
        return x

class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(BottleneckBlock, self).__init__()
        self.conv0 = ConvBNLayer(in_channels, out_channels, kernel=1)
        self.conv1 = ConvBNLayer(out_channels,
                                 out_channels,
                                 kernel=3,
                                 stride=stride)
        self.conv2 = ConvBNLayer(out_channels,
                                 out_channels * 4,
                                 kernel=1,
                                 act='None')
        self.short = Shortcut(in_channels, out_channels * 4, stride=stride)
        self.out_channels = out_channels * 4
        self.relu = nn.ReLU(True)

    def forward(self, x):
        y = self.conv0(x)
        y = self.conv1(y)
        y = self.conv2(y)
        y = y + self.short(x)
        y = self.relu(y)
        return y

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, is_first):
        super(BasicBlock, self).__init__()
        self.conv0 = ConvBNLayer(in_channels,
                                 out_channels,
                                 kernel=3,
                                 stride=stride)
        self.conv1 = ConvBNLayer(out_channels,
                                 out_channels,
                                 kernel=3,
                                 act='None')
        self.short = Shortcut(in_channels, out_channels, stride, is_first)
        self.out_channels = out_channels
        self.relu = nn.ReLU(True)

    def forward(self, x):
        y = self.conv0(x)
        y = self.conv1(y)
        y = y + self.short(x)
        y = self.relu(y)
        return y

class ResNet_FPN(nn.Module):
    def __init__(self, in_channels=1, layers=50, **kwargs):
        super(ResNet_FPN, self).__init__()
        supported_layers = {
            18: {'depth': [2, 2, 2, 2], 'block_class': BasicBlock},
            34: {'depth': [3, 4, 6, 3], 'block_class': BasicBlock},
            50: {'depth': [3, 4, 6, 3], 'block_class': BottleneckBlock},
            101: {'depth': [3, 4, 23, 3], 'block_class': BottleneckBlock},
            152: {'depth': [3, 8, 36, 3], 'block_class': BottleneckBlock},
        }
        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
        num_filters = [64, 128, 256, 512]
        self.depth = supported_layers[layers]['depth']
        self.F = []
        self.conv = ConvBNLayer(in_channels=in_channels,
                                out_channels=64,
                                kernel=7,
                                stride=2)  # stem: e.g. 64x256 -> 32x128
        self.block_list = nn.ModuleList()
        in_ch = 64
        if layers >= 50:
            for block in range(len(self.depth)):
                for i in range(self.depth[block]):
                    self.block_list.append(
                        BottleneckBlock(
                            in_channels=in_ch,
                            out_channels=num_filters[block],
                            stride=stride_list[block] if i == 0 else 1))
                    in_ch = num_filters[block] * 4
        else:
            for block in range(len(self.depth)):
                for i in range(self.depth[block]):
                    basic_block = BasicBlock(
                        in_channels=in_ch,
                        out_channels=num_filters[block],
                        stride=stride_list[block] if i == 0 else 1,
                        is_first=block == i == 0)
                    in_ch = basic_block.out_channels
                    self.block_list.append(basic_block)
        out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
        self.base_block = nn.ModuleList()
        # These would resample `base` when stage outputs differ in spatial
        # size; with the stride_list above they never do, so both stay empty.
        self.conv_trans = []
        self.bn_block = []
        for i in [-2, -3]:
            in_channels = out_ch_list[i + 1] + out_ch_list[i]
            self.base_block.append(
                nn.Conv2d(in_channels, out_ch_list[i],
                          kernel_size=1))  # squeeze the concatenated channels
            self.base_block.append(
                nn.Conv2d(out_ch_list[i],
                          out_ch_list[i],
                          kernel_size=3,
                          padding=1))  # fuse the merged features
            self.base_block.append(
                nn.Sequential(nn.BatchNorm2d(out_ch_list[i]), nn.ReLU(True)))
        self.base_block.append(nn.Conv2d(out_ch_list[i], 512, kernel_size=1))
        self.out_channels = 512

    def forward(self, x):
        x = self.conv(x)
        fpn_list = []
        F = []
        # Cumulative block counts mark the last block of each stage.
        for i in range(len(self.depth)):
            fpn_list.append(np.sum(self.depth[:i + 1]))
        for i, block in enumerate(self.block_list):
            x = block(x)
            for number in fpn_list:
                if i + 1 == number:
                    F.append(x)  # keep each stage's output for fusion
        base = F[-1]
        j = 0
        for i, block in enumerate(self.base_block):
            if i % 3 == 0 and i < 6:
                j = j + 1
                b, c, h, w = F[-j - 1].size()
                if [h, w] != list(base.size()[2:]):
                    base = self.conv_trans[j - 1](base)
                    base = self.bn_block[j - 1](base)
                base = torch.cat([base, F[-j - 1]], dim=1)
            base = block(base)
        return base
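

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch): with the default stride_list,
    # H and W are each reduced 8x and the fused output has 512 channels.
    for n_layers in (18, 50):
        model = ResNet_FPN(in_channels=1, layers=n_layers).eval()
        with torch.no_grad():
            feats = model(torch.randn(2, 1, 64, 256))
        # expected: (2, 512, 8, 32) given 64x256 inputs and default strides
        print(n_layers, tuple(feats.shape))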