# abinet_decoder.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock


class BCNLanguage(nn.Module):
    """Bidirectional cloze network (BCN) language model used by ABINet.

    It refines per-character class probabilities by embedding them and
    attending over the sequence with cross-attention-only transformer blocks.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=4,
        dim_feedforward=2048,
        dropout=0.0,
        max_length=25,
        detach=True,
        num_classes=37,
    ):
        super().__init__()
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.proj = nn.Linear(num_classes, d_model, False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0,
                                              dim=d_model,
                                              max_len=self.max_length)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is the
                number of classes
            lengths: (N,) valid sequence lengths
        """
        if self.detach:
            tokens = tokens.detach()
        embed = self.proj(tokens)  # (N, T, E)
        embed = self.token_encoder(embed)  # (N, T, E)
        mask = _get_mask(lengths, self.max_length)  # (N, 1, T, T)
        zeros = embed.new_zeros(*embed.shape)
        query = self.pos_encoder(zeros)
        for decoder_layer in self.decoder:
            query = decoder_layer(query, embed, cross_mask=mask)
        output = query  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return output, logits


def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    """Conv-BN-ReLU block used to downsample the key feature map."""
    return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c), nn.ReLU(True))


class DecoderUpsample(nn.Module):
    """Upsample to a target size, then apply a Conv-BN-ReLU block."""

    def __init__(self, in_c, out_c, k=3, s=1, p=1, mode='nearest') -> None:
        super().__init__()
        self.align_corners = None if mode == 'nearest' else True
        self.mode = mode
        # nn.Upsample(size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners),
        self.w = nn.Sequential(
            nn.Conv2d(in_c, out_c, k, s, p),
            nn.BatchNorm2d(out_c),
            nn.ReLU(True),
        )

    def forward(self, x, size):
        x = F.interpolate(x,
                          size=size,
                          mode=self.mode,
                          align_corners=self.align_corners)
        return self.w(x)


class PositionAttention(nn.Module):
    """Position attention: queries built from positional encodings attend over
    a key feature map refined by a small encoder-decoder (U-Net style) network.
    """

    def __init__(self,
                 max_length,
                 in_channels=512,
                 num_channels=64,
                 mode='nearest',
                 **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
        )
        self.k_decoder = nn.ModuleList([
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, in_channels, mode=mode),
        ])
        self.pos_encoder = PositionalEncoding(dropout=0,
                                              dim=in_channels,
                                              max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x, query=None):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)
        # calculate key vector: downsample with strided convs, then upsample
        # back to the input resolution with skip connections
        features = []
        size_decoder = []
        for i in range(len(self.k_encoder)):
            size_decoder.append(k.shape[2:])
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(len(self.k_decoder) - 1):
            k = self.k_decoder[i](k, size=size_decoder[-(i + 1)])
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k, size=size_decoder[0])  # (N, E, H, W)
        # calculate query vector
        # TODO q=f(q,k)
        zeros = x.new_zeros(
            (N, self.max_length, E)) if query is None else query  # (N, T, E)
        q = self.pos_encoder(zeros)  # (N, T, E)
        q = self.project(q)  # (N, T, E)
        # calculate attention
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E**0.5)
        attn_scores = F.softmax(attn_scores, dim=-1)
        # (N, E, H, W) -> (N, H, W, E) -> (N, (H*W), E)
        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)
        return attn_vecs, attn_scores.view(N, -1, H, W)


class ABINetDecoder(nn.Module):
    """ABINet decoder: a vision branch (transformer encoder + position
    attention), an iterative BCN language branch, and a gated alignment head
    that fuses visual and language features."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        self.max_length = max_length + 1  # additional stop token
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)
        self.iter_size = iter_size
        if iter_size > 0:
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # alignment
            self.w_att_align = nn.Linear(2 * d_model, d_model)
            self.cls_align = nn.Linear(d_model, self.out_channels)

    def forward(self, x, data=None):
        # x: (bs, c, h, w)
        x = x.permute([0, 2, 3, 1])  # bs, h, w, c
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # bs, h*w, c
        feature = self.pos_encoder(feature)  # bs, h*w, c
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)  # bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # bs, c, h, w
        v_feature, _ = self.decoder(feature)  # (bs[N], T, E)
        vis_logits = self.cls(v_feature)  # (bs[N], T, C)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []
        for _ in range(self.iter_size):
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(
                align_lengths, 2,
                self.max_length)  # TODO: move to language model
            l_feature, l_logits = self.language(tokens, lengths)
            all_l_res.append(l_logits)
            # alignment: gate between visual and language features
            fuse = torch.cat((l_feature, v_feature), -1)
            f_att = torch.sigmoid(self.w_att_align(fuse))
            output = f_att * v_feature + (1 - f_att) * l_feature
            align_logits = self.cls_align(output)
            align_lengths = _get_length(align_logits)
            all_a_res.append(align_logits)
        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)


def _get_length(logit):
    """Greedy decoder to obtain the sequence length from logits.

    The length is the index of the first predicted stop token (class 0) plus
    one; sequences with no stop token prediction get length 0.
    """
    out = logit.argmax(dim=-1) == 0
    non_zero_mask = out.int() != 0
    mask_max_values, mask_max_indices = torch.max(non_zero_mask.int(), dim=-1)
    mask_max_indices[mask_max_values == 0] = -1
    out = mask_max_indices + 1
    return out


def _get_mask(length, max_length):
    """Generate a square attention mask for each sequence in the batch.

    Masked positions are filled with float('-inf'), unmasked positions with
    0.0. Positions beyond each sequence's length are masked, and the diagonal
    is masked so a token cannot attend to itself (cloze-style).
    """
    length = length.unsqueeze(-1)
    N = length.size(0)
    grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
    zero_mask = torch.zeros([N, max_length],
                            dtype=torch.float32,
                            device=length.device)
    inf_mask = torch.full([N, max_length],
                          float('-inf'),
                          dtype=torch.float32,
                          device=length.device)
    diag_mask = torch.diag(
        torch.full([max_length],
                   float('-inf'),
                   dtype=torch.float32,
                   device=length.device),
        diagonal=0,
    )
    mask = torch.where(grid >= length, inf_mask, zero_mask)
    mask = mask.unsqueeze(1) + diag_mask
    return mask.unsqueeze(1)  # (N, 1, T, T)
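

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it instantiates the
# decoder on a dummy backbone feature map and runs one inference pass. The
# feature shape (2, 512, 8, 32) and the 37-class charset are assumptions
# chosen to satisfy the height/width constraints noted in forward().
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = ABINetDecoder(in_channels=512, out_channels=37)
    decoder.eval()
    dummy_feature = torch.randn(2, 512, 8, 32)  # (N, C, H, W) backbone output
    with torch.no_grad():
        probs = decoder(dummy_feature)  # (N, max_length + 1, num_classes)
    print(probs.shape)  # expected: torch.Size([2, 26, 37])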