# robustscanner_loss.py
  1. from torch import nn
  2. class RobustScannerLoss(nn.Module):
  3. def __init__(self, **kwargs):
  4. super(RobustScannerLoss, self).__init__()
  5. ignore_index = kwargs.get('ignore_index', 38)
  6. self.loss_func = nn.CrossEntropyLoss(reduction='mean',
  7. ignore_index=ignore_index)
  8. def forward(self, pred, batch):
  9. pred = pred[:, :-1, :]
  10. label = batch[1][:, 1:].reshape([-1])
  11. inputs = pred.reshape([-1, pred.shape[2]])
  12. loss = self.loss_func(inputs, label)
  13. return {'loss': loss}