cam_decoder.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import torch.nn as nn
  2. from .nrtr_decoder import NRTRDecoder
  3. class CAMDecoder(nn.Module):
  4. def __init__(
  5. self,
  6. in_channels,
  7. out_channels,
  8. nhead=None,
  9. num_encoder_layers=6,
  10. beam_size=0,
  11. num_decoder_layers=6,
  12. max_len=25,
  13. attention_dropout_rate=0.0,
  14. residual_dropout_rate=0.1,
  15. scale_embedding=True,
  16. ):
  17. super().__init__()
  18. self.decoder = NRTRDecoder(
  19. in_channels=in_channels,
  20. out_channels=out_channels,
  21. nhead=nhead,
  22. num_encoder_layers=num_encoder_layers,
  23. beam_size=beam_size,
  24. num_decoder_layers=num_decoder_layers,
  25. max_len=max_len,
  26. attention_dropout_rate=attention_dropout_rate,
  27. residual_dropout_rate=residual_dropout_rate,
  28. scale_embedding=scale_embedding,
  29. )
  30. def forward(self, x, data=None):
  31. dec_in = x['refined_feat']
  32. dec_output = self.decoder(dec_in, data=data)
  33. x['rec_output'] = dec_output
  34. if self.training:
  35. return x
  36. else:
  37. return dec_output