dan_decoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module
    '''

    # The current version only supports inputs whose spatial size is a power
    # of 2, such as 32, 64, 128, etc. You can adapt it to any input size by
    # changing the padding or stride.
    def __init__(self,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 maxT=25,
                 depth=4,
                 num_channels=128):
        super(CAM, self).__init__()
        # cascade multiscale features
        fpn = []
        for i in range(1, len(channels_list)):
            fpn.append(
                nn.Sequential(
                    nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
                              (strides_list[i - 1][0], strides_list[i - 1][1]),
                              1), nn.BatchNorm2d(channels_list[i]),
                    nn.ReLU(True)))
        self.fpn = nn.Sequential(*fpn)
        # convolutional alignment
        # convs
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
        # in_shape = scales[-1]
        strides = []
        conv_ksizes = []
        deconv_ksizes = []
        h, w = in_shape[0], in_shape[1]
        for i in range(0, int(depth / 2)):
            stride = [2] if 2**(depth / 2 - i) <= h else [1]
            stride = stride + [2] if 2**(depth / 2 - i) <= w else stride + [1]
            strides.append(stride)
            conv_ksizes.append([3, 3])
            deconv_ksizes.append([_**2 for _ in stride])
        convs = [
            nn.Sequential(
                nn.Conv2d(channels_list[-1], num_channels,
                          tuple(conv_ksizes[0]), tuple(strides[0]),
                          (int((conv_ksizes[0][0] - 1) / 2),
                           int((conv_ksizes[0][1] - 1) / 2))),
                nn.BatchNorm2d(num_channels), nn.ReLU(True))
        ]
        for i in range(1, int(depth / 2)):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(num_channels, num_channels,
                              tuple(conv_ksizes[i]), tuple(strides[i]),
                              (int((conv_ksizes[i][0] - 1) / 2),
                               int((conv_ksizes[i][1] - 1) / 2))),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        self.convs = nn.Sequential(*convs)
        # deconvs
        deconvs = []
        for i in range(1, int(depth / 2)):
            deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        num_channels, num_channels,
                        tuple(deconv_ksizes[int(depth / 2) - i]),
                        tuple(strides[int(depth / 2) - i]),
                        (int(deconv_ksizes[int(depth / 2) - i][0] / 4.),
                         int(deconv_ksizes[int(depth / 2) - i][1] / 4.))),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        deconvs.append(
            nn.Sequential(
                nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
                                   tuple(strides[0]),
                                   (int(deconv_ksizes[0][0] / 4.),
                                    int(deconv_ksizes[0][1] / 4.))),
                nn.Sigmoid()))
        self.deconvs = nn.Sequential(*deconvs)

    def forward(self, input):
        # input: list of multiscale feature maps; input[-1] is the deepest one
        x = input[0]
        for i in range(0, len(self.fpn)):
            # fuse each FPN stage with the corresponding backbone feature
            x = self.fpn[i](x) + input[i + 1]
        conv_feats = []
        for i in range(0, len(self.convs)):
            x = self.convs[i](x)
            conv_feats.append(x)
        for i in range(0, len(self.deconvs) - 1):
            # upsample and add the skip connection from the matching conv stage
            x = self.deconvs[i](x)
            x = x + conv_feats[len(conv_feats) - 2 - i]
        x = self.deconvs[-1](x)
        # (B, maxT, H, W): one attention map per decoding step
        return x
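# Illustrative sketch, not part of the original module: with the default
# arguments above, CAM expects a list of four backbone feature maps whose
# channel counts follow channels_list and whose spatial sizes shrink
# according to strides_list, the last map matching in_shape. The batch size
# and the 16x64 input resolution below are assumptions for demonstration.
def _demo_cam():
    cam = CAM(channels_list=[64, 128, 256, 512],
              strides_list=[[2, 2], [1, 1], [1, 1]],
              in_shape=[8, 32],
              maxT=25,
              depth=4,
              num_channels=128)
    feats = [
        torch.randn(2, 64, 16, 64),   # stage 0
        torch.randn(2, 128, 8, 32),   # stage 1 (2x2 stride from stage 0)
        torch.randn(2, 256, 8, 32),   # stage 2
        torch.randn(2, 512, 8, 32),   # stage 3, spatial size == in_shape
    ]
    A = cam(feats)
    print(A.shape)  # (2, 25, 8, 32): one attention map per decoding step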
class CAMSimp(nn.Module):
    '''
    Simplified alignment module: a 1x1 convolution that projects the input
    feature map directly to maxT attention maps.
    '''

    def __init__(self, maxT=25, num_channels=128):
        super(CAMSimp, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
                                  nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        return x
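# Illustrative sketch, not part of the original module: CAMSimp skips the
# multiscale fusion and simply projects a single feature map to maxT
# attention maps with a 1x1 convolution, leaving the spatial size unchanged.
def _demo_cam_simp():
    cam = CAMSimp(maxT=25, num_channels=128)
    x = torch.randn(2, 128, 8, 32)  # assumed (batch, num_channels, H, W)
    A = cam(x)
    print(A.shape)  # (2, 25, 8, 32)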
class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder
    '''

    def __init__(self,
                 out_channels,
                 in_channels,
                 use_cam=True,
                 max_len=25,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 depth=4,
                 dropout=0.3,
                 **kwargs):
        super(DANDecoder, self).__init__()
        self.eos = 0
        self.bos = out_channels - 2
        self.ignore_index = out_channels - 1
        nchannel = in_channels
        self.nchannel = in_channels
        self.use_cam = use_cam
        if use_cam:
            self.cam = CAM(channels_list=channels_list,
                           strides_list=strides_list,
                           in_shape=in_shape,
                           maxT=max_len + 1,
                           depth=depth,
                           num_channels=nchannel)
        else:
            self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
        self.pre_lstm = nn.LSTM(nchannel,
                                int(nchannel / 2),
                                bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, out_channels - 2))
        self.char_embeddings = nn.Embedding(out_channels,
                                            embedding_dim=in_channels,
                                            padding_idx=out_channels - 1)

    def forward(self, inputs, data=None):
        A = self.cam(inputs)
        if isinstance(inputs, list):
            feature = inputs[-1]
        else:
            feature = inputs
        nB, nC, nH, nW = feature.shape
        nT = A.shape[1]
        # Normalize each attention map so its weights sum to 1
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # Attention-weighted sum of the feature map for every decoding step
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)  # T, B, C
        C, _ = self.pre_lstm(C)  # T, B, C
        C = F.dropout(C, p=0.3, training=self.training)
        if self.training:
            # Teacher forcing: condition each step on the ground-truth
            # previous character.
            text = data[0]
            text_length = data[-1]
            nsteps = int(text_length.max())
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(text[:, 0])
            for i in range(0, nsteps + 1):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                prev_emb = self.char_embeddings(text[:, i + 1])
            gru_res = self.generator(gru_res)
            return gru_res[:nsteps + 1, :, :].transpose(1, 0)
        else:
            # Greedy decoding: feed back the embedding of the predicted
            # character and stop once every sample has emitted <eos>.
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(
                torch.zeros(nB, dtype=torch.int64, device=feature.device) +
                self.bos)
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)
            for i in range(0, nT):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                mid_res = self.generator(hidden).argmax(-1)
                dec_seq[:, i] = mid_res.squeeze(0)
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
                prev_emb = self.char_embeddings(mid_res)
            gru_res = self.generator(gru_res)
            return F.softmax(gru_res.transpose(1, 0), -1)
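# Illustrative usage sketch, not part of the original module. The alphabet
# size (out_channels=40) and the batch size are assumptions; in_channels is
# set to 512 because the attention-weighted features of the last backbone
# stage (channels_list[-1]) are fed straight into pre_lstm.
if __name__ == '__main__':
    decoder = DANDecoder(out_channels=40,  # classes incl. <eos>/<bos>/<pad>
                         in_channels=512,
                         use_cam=True,
                         max_len=25)
    feats = [
        torch.randn(2, 64, 16, 64),
        torch.randn(2, 128, 8, 32),
        torch.randn(2, 256, 8, 32),
        torch.randn(2, 512, 8, 32),
    ]
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)  # greedy decoding branch (no ground truth)
    print(probs.shape)  # (2, max_len + 1, out_channels - 2) = (2, 26, 38)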