12345678910111213141516171819202122232425262728293031323334 |
- from torch import nn
class MGPLoss(nn.Module):
    """Multi-granularity prediction (MGP) loss.

    Sums a padding-aware cross-entropy over up to three prediction heads
    (``char``, ``dpe``, ``wp`` — presumably character / subword / wordpiece
    granularities; TODO confirm against the model that produces ``pred``).

    Args:
        only_char: When True, ``forward`` expects a single logits tensor and
            returns only the character loss; otherwise it expects a 3-tuple
            of logits and sums the three per-granularity losses.
        **kwargs: Accepted for config-driven construction and ignored.
    """

    def __init__(self, only_char=False, **kwargs):
        super().__init__()  # Python 3 zero-arg form of super(MGPLoss, self)
        # ignore_index=0: index 0 is excluded from the mean, i.e. treated as
        # padding — NOTE(review): assumes 0 is the pad id; verify with tokenizer.
        self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
        self.only_char = only_char

    def _token_loss(self, feats, tgt):
        """Mean CE between (B, T, C) logits and (B, T) integer targets.

        Both are flattened over the batch and time dims so the loss is
        averaged across all non-padding tokens of the batch.
        """
        return self.ce(feats.flatten(0, 1), tgt.flatten(0, 1))

    def forward(self, pred, batch):
        """Compute the loss dict.

        Args:
            pred: Character logits (``only_char=True``) or a 3-tuple of
                (char, dpe, wp) logits, each shaped (B, T, C).
            batch: Sequence whose slots 1..3 hold the (B, T) target tensors;
                slot 0 is unused here.

        Returns:
            Dict with key ``'loss'`` (and per-granularity losses when all
            three heads are active — see :meth:`forward_all`).
        """
        if self.only_char:
            return {'loss': self._token_loss(pred, batch[1])}
        return self.forward_all(pred, batch)

    def forward_all(self, pred, batch):
        """Sum the char/dpe/wp losses and also report each one separately."""
        char_feats, dpe_feats, wp_feats = pred
        char_loss = self._token_loss(char_feats, batch[1])
        dpe_loss = self._token_loss(dpe_feats, batch[2])
        wp_loss = self._token_loss(wp_feats, batch[3])
        return {
            'loss': char_loss + dpe_loss + wp_loss,
            'char_loss': char_loss,
            'dpe_loss': dpe_loss,
            'wp_loss': wp_loss,
        }
|