nrtr_encoder.py 794 B

12345678910111213141516171819202122232425262728
  1. from torch import nn
  2. class NRTREncoder(nn.Module):
  3. def __init__(self, in_channels):
  4. super(NRTREncoder, self).__init__()
  5. self.out_channels = 512 # 64*H
  6. self.block = nn.Sequential(
  7. nn.Conv2d(
  8. in_channels=in_channels,
  9. out_channels=32,
  10. kernel_size=3,
  11. stride=2,
  12. padding=1,
  13. ), nn.ReLU(), nn.BatchNorm2d(32),
  14. nn.Conv2d(
  15. in_channels=32,
  16. out_channels=64,
  17. kernel_size=3,
  18. stride=2,
  19. padding=1,
  20. ), nn.ReLU(), nn.BatchNorm2d(64))
  21. def forward(self, images):
  22. x = self.block(images)
  23. x = x.permute(0, 3, 2, 1).flatten(2) # B, W, H*C
  24. return x