  1. """This code is refer from:
  2. https://github.com/jjwei66/BUSNet
  3. """
import torch
import torch.nn as nn
import torch.nn.functional as F

from .nrtr_decoder import PositionalEncoding, TransformerBlock
from .abinet_decoder import _get_mask, _get_length


class BUSDecoder(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels

        # --------------------------------------------------------------------------
        # decoder specifics
        self.proj = nn.Linear(out_channels, d_model, bias=False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)
        # query-only decoder blocks: each layer cross-attends from the
        # positional queries to the provided key/value sequence
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        # learnable mask tokens used to blank out one modality during the
        # vision-only and language-only decoding passes
        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        # learnable modality embeddings added to the vision / language features
        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)

        self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits
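
    # forward() below runs three decoding passes over the same positional
    # queries: vision-only (language positions replaced by l_mask),
    # language-only (vision positions replaced by v_mask), and a fused
    # vision-language pass. In training mode all three logits are returned
    # so each branch can be supervised.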
    def forward(self, img_feat, data=None):
        """
        Args:
            img_feat: (N, L, E) visual features from the encoder.
            data: optional list of ground-truth tensors; data[0] holds the
                target token ids (N, T) and data[-1] the target lengths (N,).
                Only used when training with pretraining=True.
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # --------------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        query = self.pos_encoder(zeros)  # queries are purely positional
        # 1. vision decode: language positions are filled with the learnable
        # l_mask token, so the queries attend only to visual features
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)), dim=1)
        padding_mask = _get_mask(
            self.max_length + zeros_len,
            self.max_length)  # mask out padding beyond the token length; (B, 1, maxlen, maxlen)
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1,
                                                           1])  # (B, 1, maxlen, L)
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode: vision positions are filled with the learnable
        # v_mask token, so the queries attend only to the token embeddings
        if self.training and self.pretraining:
            # pretraining: feed ground-truth tokens
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            # otherwise bootstrap from the vision predictions
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
        if self.detach:
            tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (N, T, E)
        token_embed = token_embed + self.l_embeding
        padding_mask = _get_mask(
            lengths, self.max_length)  # mask out padding beyond the token length
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision-language decode: queries attend to both modalities jointly
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
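

# ------------------------------------------------------------------
# Minimal usage sketch (kept as a comment so importing this module has
# no side effects; the channel sizes below are illustrative assumptions,
# not values mandated by this module):
#
#     import torch
#     decoder = BUSDecoder(in_channels=512, out_channels=97)
#     decoder.eval()
#     img_feat = torch.randn(2, 256, 512)  # (N, L, E) encoder features
#     probs = decoder(img_feat)            # (N, max_length + 1, out_channels)
# ------------------------------------------------------------------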