# abinet_loss.py
import torch
from torch import nn
  3. class ABINetLoss(nn.Module):
  4. def __init__(self,
  5. smoothing=False,
  6. ignore_index=100,
  7. align_weight=1.0,
  8. **kwargs):
  9. super(ABINetLoss, self).__init__()
  10. if ignore_index >= 0:
  11. self.loss_func = nn.CrossEntropyLoss(reduction='mean',
  12. ignore_index=ignore_index)
  13. else:
  14. self.loss_func = nn.CrossEntropyLoss(reduction='mean')
  15. self.smoothing = smoothing
  16. self.align_weight = align_weight
  17. def forward(self, pred, batch):
  18. loss = {}
  19. loss_sum = []
  20. for name, logits in pred.items():
  21. if isinstance(logits, list):
  22. logit_num = len(logits)
  23. if logit_num > 0:
  24. all_tgt = torch.cat([batch[1]] * logit_num, 0)
  25. all_logits = torch.cat(logits, 0)
  26. flt_logtis = all_logits.reshape([-1, all_logits.shape[2]])
  27. flt_tgt = all_tgt.reshape([-1])
  28. else:
  29. continue
  30. else:
  31. flt_logtis = logits.reshape([-1, logits.shape[2]])
  32. flt_tgt = batch[1].reshape([-1])
  33. loss[name + '_loss'] = self.loss_func(flt_logtis, flt_tgt) * (
  34. self.align_weight if name == 'align' else 1.0)
  35. loss_sum.append(loss[name + '_loss'])
  36. loss['loss'] = sum(loss_sum)
  37. return loss