rec_nrtr_mtb.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. from torch import nn
  3. class MTB(nn.Module):
  4. def __init__(self, cnn_num, in_channels):
  5. super(MTB, self).__init__()
  6. self.block = nn.Sequential()
  7. self.out_channels = in_channels
  8. self.cnn_num = cnn_num
  9. if self.cnn_num == 2:
  10. for i in range(self.cnn_num):
  11. self.block.add_module(
  12. 'conv_{}'.format(i),
  13. nn.Conv2d(
  14. in_channels=in_channels if i == 0 else 32 *
  15. (2**(i - 1)),
  16. out_channels=32 * (2**i),
  17. kernel_size=3,
  18. stride=2,
  19. padding=1,
  20. ),
  21. )
  22. self.block.add_module('relu_{}'.format(i), nn.ReLU())
  23. self.block.add_module('bn_{}'.format(i),
  24. nn.BatchNorm2d(32 * (2**i)))
  25. def forward(self, images):
  26. x = self.block(images)
  27. if self.cnn_num == 2:
  28. # (b, w, h, c)
  29. x = x.permute(0, 3, 2, 1)
  30. x_shape = x.shape
  31. x = torch.reshape(
  32. x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3]))
  33. return x