srn_loss.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import torch.nn.functional as F
  2. from torch import nn
  3. class SRNLoss(nn.Module):
  4. def __init__(self, label_smoothing=0.0, **kwargs):
  5. super(SRNLoss, self).__init__()
  6. self.label_smoothing = label_smoothing
  7. def forward(self, preds, batch):
  8. pvam_preds, gsrm_preds, vsfd_preds = preds
  9. label = batch[1].reshape([-1])
  10. ignore_index = pvam_preds.shape[-1] + 1
  11. loss_pvam = F.cross_entropy(pvam_preds,
  12. label,
  13. reduction='mean',
  14. label_smoothing=self.label_smoothing,
  15. ignore_index=ignore_index)
  16. loss_gsrm = F.cross_entropy(gsrm_preds,
  17. label,
  18. reduction='mean',
  19. label_smoothing=self.label_smoothing,
  20. ignore_index=ignore_index)
  21. loss_vsfd = F.cross_entropy(vsfd_preds,
  22. label,
  23. reduction='mean',
  24. label_smoothing=self.label_smoothing,
  25. ignore_index=ignore_index)
  26. loss_dict = {}
  27. loss_dict['loss_pvam'] = loss_pvam
  28. loss_dict['loss_gsrm'] = loss_gsrm
  29. loss_dict['loss_vsfd'] = loss_vsfd
  30. loss_dict['loss'] = loss_pvam * 3.0 + loss_gsrm * 0.15 + loss_vsfd
  31. return loss_dict