rec_resnet_45.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import math
  2. import numpy as np
  3. import torch.nn as nn
  4. from openrec.modeling.common import Block
  5. def conv1x1(in_planes, out_planes, stride=1):
  6. return nn.Conv2d(in_planes,
  7. out_planes,
  8. kernel_size=1,
  9. stride=stride,
  10. bias=False)
  11. def conv3x3(in_planes, out_planes, stride=1):
  12. """3x3 convolution with padding."""
  13. return nn.Conv2d(in_planes,
  14. out_planes,
  15. kernel_size=3,
  16. stride=stride,
  17. padding=1,
  18. bias=False)
  19. class BasicBlock(nn.Module):
  20. expansion = 1
  21. def __init__(self, inplanes, planes, stride=1, downsample=None):
  22. super(BasicBlock, self).__init__()
  23. self.conv1 = conv1x1(inplanes, planes)
  24. self.bn1 = nn.BatchNorm2d(planes)
  25. self.relu = nn.ReLU(inplace=True)
  26. self.conv2 = conv3x3(planes, planes, stride)
  27. self.bn2 = nn.BatchNorm2d(planes)
  28. self.downsample = downsample
  29. self.stride = stride
  30. def forward(self, x):
  31. residual = x
  32. out = self.conv1(x)
  33. out = self.bn1(out)
  34. out = self.relu(out)
  35. out = self.conv2(out)
  36. out = self.bn2(out)
  37. if self.downsample is not None:
  38. residual = self.downsample(x)
  39. out += residual
  40. out = self.relu(out)
  41. return out
  42. class ResNet45(nn.Module):
  43. def __init__(
  44. self,
  45. in_channels=3,
  46. block=BasicBlock,
  47. layers=[3, 4, 6, 6, 3],
  48. strides=[2, 1, 2, 1, 1],
  49. last_stage=False,
  50. out_channels=256,
  51. trans_layer=0,
  52. out_dim=384,
  53. feat2d=True,
  54. return_list=False,
  55. ):
  56. super(ResNet45, self).__init__()
  57. self.inplanes = 32
  58. self.conv1 = nn.Conv2d(in_channels,
  59. 32,
  60. kernel_size=3,
  61. stride=1,
  62. padding=1,
  63. bias=False)
  64. self.bn1 = nn.BatchNorm2d(32)
  65. self.relu = nn.ReLU(inplace=True)
  66. self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
  67. self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
  68. self.layer3 = self._make_layer(block,
  69. 128,
  70. layers[2],
  71. stride=strides[2])
  72. self.layer4 = self._make_layer(block,
  73. 256,
  74. layers[3],
  75. stride=strides[3])
  76. self.layer5 = self._make_layer(block,
  77. 512,
  78. layers[4],
  79. stride=strides[4])
  80. self.out_channels = 512
  81. self.feat2d = feat2d
  82. self.return_list = return_list
  83. if trans_layer > 0:
  84. dpr = np.linspace(0, 0.1, trans_layer)
  85. blocks = [nn.Linear(512, out_dim)] + [
  86. Block(dim=out_dim,
  87. num_heads=out_dim // 32,
  88. mlp_ratio=4.0,
  89. qkv_bias=False,
  90. drop_path=dpr[i]) for i in range(trans_layer)
  91. ]
  92. self.trans_blocks = nn.Sequential(*blocks)
  93. dim = out_dim
  94. self.out_channels = out_dim
  95. else:
  96. self.trans_blocks = None
  97. dim = 512
  98. self.last_stage = last_stage
  99. if last_stage:
  100. self.out_channels = out_channels
  101. self.last_conv = nn.Linear(dim, self.out_channels, bias=False)
  102. self.hardswish = nn.Hardswish()
  103. self.dropout = nn.Dropout(p=0.1)
  104. for m in self.modules():
  105. if isinstance(m, nn.Conv2d):
  106. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  107. m.weight.data.normal_(0, math.sqrt(2.0 / n))
  108. elif isinstance(m, nn.BatchNorm2d):
  109. m.weight.data.fill_(1)
  110. m.bias.data.zero_()
  111. elif isinstance(m, nn.Linear):
  112. nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
  113. if isinstance(m, nn.Linear) and m.bias is not None:
  114. nn.init.zeros_(m.bias)
  115. def _make_layer(self, block, planes, blocks, stride=1):
  116. downsample = None
  117. if stride != 1 or self.inplanes != planes * block.expansion:
  118. downsample = nn.Sequential(
  119. nn.Conv2d(
  120. self.inplanes,
  121. planes * block.expansion,
  122. kernel_size=1,
  123. stride=stride,
  124. bias=False,
  125. ),
  126. nn.BatchNorm2d(planes * block.expansion),
  127. )
  128. layers = []
  129. layers.append(block(self.inplanes, planes, stride, downsample))
  130. self.inplanes = planes * block.expansion
  131. for i in range(1, blocks):
  132. layers.append(block(self.inplanes, planes))
  133. return nn.Sequential(*layers)
  134. def forward(self, x):
  135. x = self.conv1(x)
  136. x = self.bn1(x)
  137. x = self.relu(x)
  138. x = self.layer1(x)
  139. x2 = self.layer2(x)
  140. x3 = self.layer3(x2)
  141. x4 = self.layer4(x3)
  142. x5 = self.layer5(x4)
  143. if self.return_list:
  144. return [x2, x3, x4, x5]
  145. x = x5
  146. if self.trans_blocks is not None:
  147. B, C, H, W = x.shape
  148. x = self.trans_blocks(x.flatten(2, 3).transpose(1, 2))
  149. x = x.transpose(1, 2).reshape(B, -1, H, W)
  150. if self.last_stage:
  151. x = x.mean(2).transpose(1, 2)
  152. x = self.last_conv(x)
  153. x = self.hardswish(x)
  154. x = self.dropout(x)
  155. elif not self.feat2d:
  156. x = x.flatten(2).transpose(1, 2)
  157. return x