import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module
    '''

    # The current version only supports inputs whose spatial size is a power
    # of 2 (e.g. 32, 64, 128). You can adapt it to other input sizes by
    # changing the padding or stride.
    
    def __init__(self,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 maxT=25,
                 depth=4,
                 num_channels=128):
        super(CAM, self).__init__()
        # cascade the multi-scale backbone features
        fpn = []
        for i in range(1, len(channels_list)):
            fpn.append(
                nn.Sequential(
                    nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
                              (strides_list[i - 1][0], strides_list[i - 1][1]),
                              1),
                    nn.BatchNorm2d(channels_list[i]),
                    nn.ReLU(True)))
        self.fpn = nn.Sequential(*fpn)
        # convolutional alignment: compute per-stage strides and kernel sizes
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
        strides = []
        conv_ksizes = []
        deconv_ksizes = []
        h, w = in_shape[0], in_shape[1]
        for i in range(0, depth // 2):
            # only downsample a dimension while it is still large enough
            stride = [2] if 2**(depth // 2 - i) <= h else [1]
            stride = stride + [2] if 2**(depth // 2 - i) <= w else stride + [1]
            strides.append(stride)
            conv_ksizes.append([3, 3])
            deconv_ksizes.append([_**2 for _ in stride])
        # downsampling convolutions
        convs = [
            nn.Sequential(
                nn.Conv2d(channels_list[-1], num_channels,
                          tuple(conv_ksizes[0]), tuple(strides[0]),
                          ((conv_ksizes[0][0] - 1) // 2,
                           (conv_ksizes[0][1] - 1) // 2)),
                nn.BatchNorm2d(num_channels), nn.ReLU(True))
        ]
        for i in range(1, depth // 2):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(num_channels, num_channels,
                              tuple(conv_ksizes[i]), tuple(strides[i]),
                              ((conv_ksizes[i][0] - 1) // 2,
                               (conv_ksizes[i][1] - 1) // 2)),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        self.convs = nn.Sequential(*convs)
        # upsampling deconvolutions, mirroring the conv stages
        deconvs = []
        for i in range(1, depth // 2):
            deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        num_channels, num_channels,
                        tuple(deconv_ksizes[depth // 2 - i]),
                        tuple(strides[depth // 2 - i]),
                        (int(deconv_ksizes[depth // 2 - i][0] / 4.),
                         int(deconv_ksizes[depth // 2 - i][1] / 4.))),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        # the last deconvolution outputs maxT attention maps in [0, 1]
        deconvs.append(
            nn.Sequential(
                nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
                                   tuple(strides[0]),
                                   (int(deconv_ksizes[0][0] / 4.),
                                    int(deconv_ksizes[0][1] / 4.))),
                nn.Sigmoid()))
        self.deconvs = nn.Sequential(*deconvs)

    def forward(self, input):
        # fuse the multi-scale features top-down
        x = input[0]
        for i in range(0, len(self.fpn)):
            x = self.fpn[i](x) + input[i + 1]
        # encoder-decoder with skip connections; the final stage yields one
        # attention map per decoding time step
        conv_feats = []
        for i in range(0, len(self.convs)):
            x = self.convs[i](x)
            conv_feats.append(x)
        for i in range(0, len(self.deconvs) - 1):
            x = self.deconvs[i](x)
            x = x + conv_feats[len(conv_feats) - 2 - i]
        x = self.deconvs[-1](x)
        return x
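

# ---------------------------------------------------------------------------
# Illustrative shape check (not part of the original module): a minimal
# sketch, assuming a backbone that returns four feature maps whose channel
# counts match the default `channels_list` and whose last map matches the
# default `in_shape` of 8 x 32. All tensors are random placeholders.
def _cam_shape_check():
    feats = [
        torch.randn(2, 64, 16, 64),   # halved to 8 x 32 by the first FPN conv
        torch.randn(2, 128, 8, 32),
        torch.randn(2, 256, 8, 32),
        torch.randn(2, 512, 8, 32),
    ]
    cam = CAM(maxT=25, depth=4, in_shape=[8, 32])
    attn = cam(feats)
    # one sigmoid attention map per time step, at the resolution of in_shape
    assert attn.shape == (2, 25, 8, 32)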


class CAMSimp(nn.Module):
    '''
    Simplified alignment module: a 1x1 convolution maps the feature channels
    directly to maxT sigmoid attention maps.
    '''

    def __init__(self, maxT=25, num_channels=128):
        super(CAMSimp, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
                                  nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        return x


class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder
    '''

    def __init__(self,
                 out_channels,
                 in_channels,
                 use_cam=True,
                 max_len=25,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 depth=4,
                 dropout=0.3,
                 **kwargs):
        super(DANDecoder, self).__init__()
        # special symbols: eos is index 0, bos and padding sit at the end of
        # the charset
        self.eos = 0
        self.bos = out_channels - 2
        self.ignore_index = out_channels - 1
        nchannel = in_channels
        self.nchannel = in_channels
        self.use_cam = use_cam
        if use_cam:
            self.cam = CAM(channels_list=channels_list,
                           strides_list=strides_list,
                           in_shape=in_shape,
                           maxT=max_len + 1,
                           depth=depth,
                           num_channels=nchannel)
        else:
            self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
        self.pre_lstm = nn.LSTM(nchannel, nchannel // 2, bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, out_channels - 2))
        self.char_embeddings = nn.Embedding(out_channels,
                                            embedding_dim=in_channels,
                                            padding_idx=out_channels - 1)

    def forward(self, inputs, data=None):
        A = self.cam(inputs)
        if isinstance(inputs, list):
            feature = inputs[-1]
        else:
            feature = inputs
        nB, nC, nH, nW = feature.shape
        nT = A.shape[1]
        # normalize each attention map so that its weights sum to 1
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # attention-weighted sum of the feature map: one glimpse per time step
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)  # (T, B, C)
        C, _ = self.pre_lstm(C)  # (T, B, C)
        C = F.dropout(C, p=0.3, training=self.training)
        if self.training:
            # teacher forcing: feed the ground-truth previous character at
            # every step
            text = data[0]
            text_length = data[-1]
            nsteps = int(text_length.max())
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(text[:, 0])
            for i in range(0, nsteps + 1):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                prev_emb = self.char_embeddings(text[:, i + 1])
            gru_res = self.generator(gru_res)
            return gru_res[:nsteps + 1, :, :].transpose(1, 0)
        else:
            # greedy decoding: feed back the previous prediction and stop once
            # every sequence in the batch has produced <eos>
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(
                torch.full((nB, ),
                           self.bos,
                           dtype=torch.int64,
                           device=feature.device))
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)

            for i in range(0, nT):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                mid_res = self.generator(hidden).argmax(-1)
                dec_seq[:, i] = mid_res
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
                prev_emb = self.char_embeddings(mid_res)
            gru_res = self.generator(gru_res)
            return F.softmax(gru_res.transpose(1, 0), -1)
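

# ---------------------------------------------------------------------------
# Illustrative inference sketch (an assumption-laden example, not part of the
# original file): the vocabulary size and feature shapes below are made up,
# and the feature maps are random placeholders standing in for a backbone.
def _dan_decoder_demo():
    num_classes = 40  # hypothetical charset size including eos/bos/padding
    decoder = DANDecoder(out_channels=num_classes, in_channels=512)
    decoder.eval()
    feats = [
        torch.randn(1, 64, 16, 64),
        torch.randn(1, 128, 8, 32),
        torch.randn(1, 256, 8, 32),
        torch.randn(1, 512, 8, 32),
    ]
    with torch.no_grad():
        probs = decoder(feats)  # (B, max_len + 1, num_classes - 2)
    preds = probs.argmax(-1)    # greedy character indices, eos is index 0
    print(preds.shape)          # torch.Size([1, 26])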