# mgp_loss.py
from torch import nn
  2. class MGPLoss(nn.Module):
  3. def __init__(self, only_char=False, **kwargs):
  4. super(MGPLoss, self).__init__()
  5. self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
  6. self.only_char = only_char
  7. def forward(self, pred, batch):
  8. if self.only_char:
  9. char_feats = pred
  10. char_tgt = batch[1].flatten(0, 1)
  11. char_loss = self.ce(char_feats.flatten(0, 1), char_tgt)
  12. return {'loss': char_loss}
  13. else:
  14. return self.forward_all(pred, batch)
  15. def forward_all(self, pred, batch):
  16. char_feats, dpe_feats, wp_feats = pred
  17. char_tgt = batch[1].flatten(0, 1)
  18. dpe_tgt = batch[2].flatten(0, 1)
  19. wp_tgt = batch[3].flatten(0, 1)
  20. char_loss = self.ce(char_feats.flatten(0, 1), char_tgt)
  21. dpe_loss = self.ce(dpe_feats.flatten(0, 1), dpe_tgt)
  22. wp_loss = self.ce(wp_feats.flatten(0, 1), wp_tgt)
  23. loss = char_loss + dpe_loss + wp_loss
  24. return {
  25. 'loss': loss,
  26. 'char_loss': char_loss,
  27. 'dpe_loss': dpe_loss,
  28. 'wp_loss': wp_loss
  29. }