12345678910111213141516171819202122232425262728 |
- from torch import nn
class NRTREncoder(nn.Module):
    """Convolutional front end that turns an image batch into a sequence.

    Two stages of (3x3 conv with stride 2 and padding 1 -> ReLU -> BatchNorm2d)
    widen the channels from ``in_channels`` to 32 and then to 64, roughly
    halving height and width at each stage. The resulting feature map is then
    laid out as one feature vector per horizontal position.

    Args:
        in_channels: Number of channels in the input images.

    Attributes:
        out_channels: Advertised feature size per sequence step, fixed at
            512. This is 64 channels times the feature-map height, so it
            matches the real output only when that height is 8 (for example
            a 32-pixel-high input).
        block: The ``nn.Sequential`` holding the six layers described above.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 64 * H, where H is the feature-map height after both stride-2 convs.
        self.out_channels = 512

        # Consecutive pairs of this tuple are the (input, output) channel
        # counts of the two conv stages.
        widths = (in_channels, 32, 64)
        layers = []
        for stage_in, stage_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(
                    stage_in,
                    stage_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(stage_out),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, images):
        """Encode ``images`` into a width-major feature sequence.

        Args:
            images: 4-D tensor in (B, C, H, W) layout, as required by the
                ``nn.Conv2d`` layers in ``self.block``.

        Returns:
            3-D tensor of shape (B, W', H' * 64), where H' and W' are the
            spatial sizes of the feature map produced by ``self.block``.
        """
        feature_map = self.block(images)  # B, 64, H', W'
        # Swapping the channel and width axes gives B, W', H', 64; this is
        # the same reordering as permute(0, 3, 2, 1).
        width_major = feature_map.transpose(1, 3)
        # Merge the trailing height and channel axes: B, W', H' * 64.
        return width_major.flatten(start_dim=2)
|