# resnet31_rnn.py (3.9 KB)

  1. import torch.nn as nn
  2. def conv3x3(in_planes, out_planes, stride=1):
  3. """3x3 convolution with padding."""
  4. return nn.Conv2d(in_planes,
  5. out_planes,
  6. kernel_size=3,
  7. stride=stride,
  8. padding=1,
  9. bias=False)
  10. def conv1x1(in_planes, out_planes, stride=1):
  11. """1x1 convolution."""
  12. return nn.Conv2d(in_planes,
  13. out_planes,
  14. kernel_size=1,
  15. stride=stride,
  16. bias=False)
  17. class AsterBlock(nn.Module):
  18. def __init__(self, inplanes, planes, stride=1, downsample=None):
  19. super(AsterBlock, self).__init__()
  20. self.conv1 = conv1x1(inplanes, planes, stride)
  21. self.bn1 = nn.BatchNorm2d(planes)
  22. self.relu = nn.ReLU(inplace=True)
  23. self.conv2 = conv3x3(planes, planes)
  24. self.bn2 = nn.BatchNorm2d(planes)
  25. self.downsample = downsample
  26. self.stride = stride
  27. def forward(self, x):
  28. residual = x
  29. out = self.conv1(x)
  30. out = self.bn1(out)
  31. out = self.relu(out)
  32. out = self.conv2(out)
  33. out = self.bn2(out)
  34. if self.downsample is not None:
  35. residual = self.downsample(x)
  36. out += residual
  37. out = self.relu(out)
  38. return out
  39. class ResNet_ASTER(nn.Module):
  40. """For aster or crnn."""
  41. def __init__(self, in_channels, with_lstm=True, n_group=1):
  42. super(ResNet_ASTER, self).__init__()
  43. self.with_lstm = with_lstm
  44. self.n_group = n_group
  45. self.out_channels = 512
  46. if with_lstm:
  47. self.out_channels = 512
  48. self.layer0 = nn.Sequential(
  49. nn.Conv2d(in_channels,
  50. 32,
  51. kernel_size=(3, 3),
  52. stride=1,
  53. padding=1,
  54. bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))
  55. self.inplanes = 32
  56. self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
  57. self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
  58. self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
  59. self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
  60. self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
  61. if with_lstm:
  62. self.rnn = nn.LSTM(512,
  63. 256,
  64. bidirectional=True,
  65. num_layers=2,
  66. batch_first=True)
  67. self.out_planes = 2 * 256
  68. else:
  69. self.out_planes = 512
  70. for m in self.modules():
  71. if isinstance(m, nn.Conv2d):
  72. nn.init.kaiming_normal_(m.weight,
  73. mode='fan_out',
  74. nonlinearity='relu')
  75. elif isinstance(m, nn.BatchNorm2d):
  76. nn.init.constant_(m.weight, 1)
  77. nn.init.constant_(m.bias, 0)
  78. def _make_layer(self, planes, blocks, stride):
  79. downsample = None
  80. if stride != [1, 1] or self.inplanes != planes:
  81. downsample = nn.Sequential(conv1x1(self.inplanes, planes, stride),
  82. nn.BatchNorm2d(planes))
  83. layers = []
  84. layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
  85. self.inplanes = planes
  86. for _ in range(1, blocks):
  87. layers.append(AsterBlock(self.inplanes, planes))
  88. return nn.Sequential(*layers)
  89. def forward(self, x):
  90. x0 = self.layer0(x)
  91. x1 = self.layer1(x0)
  92. x2 = self.layer2(x1)
  93. x3 = self.layer3(x2)
  94. x4 = self.layer4(x3)
  95. x5 = self.layer5(x4)
  96. cnn_feat = x5.squeeze(2) # [N, c, w]
  97. cnn_feat = cnn_feat.transpose(2, 1).contiguous()
  98. if self.with_lstm:
  99. rnn_feat, _ = self.rnn(cnn_feat)
  100. return rnn_feat
  101. else:
  102. return cnn_feat