lpv_decoder.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from .abinet_decoder import PositionAttention
from .nrtr_decoder import PositionalEncoding, TransformerBlock


class Trans(nn.Module):
    """Stack of self-attention-only TransformerBlocks with positional encoding,
    applied to a flattened 2D feature map."""

    def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
        super().__init__()
        self.d_model = dim
        self.nhead = nhead
        self.pos_encoder = PositionalEncoding(dropout=0.0,
                                              dim=self.d_model,
                                              max_len=512)
        self.transformer = nn.ModuleList([
            TransformerBlock(
                dim,
                nhead,
                dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])

    def forward(self, feature, attn_map=None, use_mask=False):
        n, c, h, w = feature.shape
        feature = feature.flatten(2).transpose(1, 2)  # (N, HW, C)
        if use_mask:
            _, t, h, w = attn_map.shape
            # Binarize the attention maps: entry (i, k) is 1 if spatial
            # position i responds strongly to character slot k.
            location_mask = (attn_map.view(n, t, -1).transpose(1, 2) >
                             0.05).type(torch.float)  # (N, HW, T)
            # Entry (i, j) > 0 iff positions i and j share a character region.
            location_mask = location_mask.bmm(
                location_mask.transpose(1, 2))  # (N, HW, HW)
            # Block self-attention between positions of the same character
            # region; masked_fill broadcasts the (HW, HW) zeros against the
            # (N, HW, HW) boolean mask.
            location_mask = location_mask.new_zeros(
                (h * w, h * w)).masked_fill(location_mask > 0, float('-inf'))
            location_mask = location_mask.unsqueeze(1)  # (N, 1, HW, HW)
        else:
            location_mask = None
        feature = self.pos_encoder(feature)
        for layer in self.transformer:
            feature = layer(feature, self_mask=location_mask)
        feature = feature.transpose(1, 2).view(n, c, h, w)
        return feature, location_mask


def _get_clones(module, N):
    """Return N independent deep copies of ``module`` as an nn.ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LPVDecoder(nn.Module):
    """LPV decoder: a cascade of position-attention stages, each refined by a
    small transformer (``Trans``) and followed by its own classification head."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=3,
                 max_len=25,
                 use_mask=False,
                 dim_feedforward=1024,
                 nhead=8,
                 dropout=0.1,
                 trans_layer=2):
        super().__init__()
        self.use_mask = use_mask
        self.max_len = max_len
        attn_layer = PositionAttention(max_length=max_len + 1,
                                       mode='nearest',
                                       in_channels=in_channels,
                                       num_channels=in_channels // 8)
        trans_layer = Trans(dim=in_channels,
                            nhead=nhead,
                            dim_feedforward=dim_feedforward,
                            dropout=dropout,
                            num_layers=trans_layer)
        cls_layer = nn.Linear(in_channels, out_channels - 2)
        self.attention = _get_clones(attn_layer, num_layer)
        self.trans = _get_clones(trans_layer, num_layer - 1)
        self.cls = _get_clones(cls_layer, num_layer)

    def forward(self, x, data=None):
        # ``data[1]`` is expected to hold per-sample target lengths during
        # training; fall back to the configured maximum length otherwise.
        if data is not None:
            max_len = data[1].max()
        else:
            max_len = self.max_len
        features = x  # (N, E, H, W)
        attn_vecs, attn_scores_map = self.attention[0](features)
        attn_vecs = attn_vecs[:, :max_len + 1, :]
        if not self.training:
            # Inference: run all refinement stages and return the final
            # per-character probabilities.
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
            return F.softmax(self.cls[-1](attn_vecs), dim=-1)
        else:
            # Training: collect the logits of every refinement stage.
            logits = []
            logit = self.cls[0](attn_vecs)  # (N, T, C)
            logits.append(logit)
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
                logit = self.cls[i](attn_vecs)  # (N, T, C)
                logits.append(logit)
            return logits
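

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal smoke test, assuming the sibling ``abinet_decoder`` and
# ``nrtr_decoder`` modules are importable as part of the same package;
# in_channels=512 and out_channels=39 are placeholder values chosen only for
# illustration. Run with ``python -m <package>.lpv_decoder`` so the relative
# imports above resolve.
if __name__ == '__main__':
    decoder = LPVDecoder(in_channels=512, out_channels=39)
    decoder.eval()  # inference branch: returns per-character probabilities
    feats = torch.randn(2, 512, 8, 32)  # (N, E, H, W) backbone feature map
    with torch.no_grad():
        probs = decoder(feats)  # (N, max_len + 1, out_channels - 2)
    print(probs.shape)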