12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- import torch.nn as nn
- from .nrtr_decoder import NRTRDecoder
class CAMDecoder(nn.Module):
    """Decoder head that runs an ``NRTRDecoder`` on ``x['refined_feat']``.

    The module owns a single ``NRTRDecoder`` instance (``self.decoder``) and
    adds no layers of its own: ``forward`` picks the ``'refined_feat'`` entry
    out of the incoming dict, decodes it, and records the result under
    ``'rec_output'``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        """Build the wrapped ``NRTRDecoder``.

        Every argument is forwarded unchanged, by keyword, to
        ``NRTRDecoder``; see that class for the meaning of each option.
        """
        super().__init__()
        self.decoder = NRTRDecoder(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )

    def forward(self, x, data=None):
        """Decode ``x['refined_feat']`` and store the result in ``x``.

        Args:
            x: Mapping that must contain the key ``'refined_feat'``. It is
                mutated in place: ``x['rec_output']`` is set to the decoder
                output.
            data: Optional extra input passed through to the wrapped decoder
                as its ``data`` keyword argument.

        Returns:
            When ``self.training`` is true, the (mutated) ``x`` mapping;
            otherwise only the decoder output.

        Raises:
            KeyError: If ``x`` has no ``'refined_feat'`` entry.
        """
        dec_in = x['refined_feat']
        dec_output = self.decoder(dec_in, data=data)
        # Always recorded on the dict, even though only the training path
        # returns the dict itself.
        x['rec_output'] = dec_output
        if self.training:
            return x
        return dec_output
|