12345678910111213141516171819202122232425262728293031323334353637383940 |
- import torch.nn.functional as F
- from torch import nn
class SRNLoss(nn.Module):
    """Weighted sum of three cross-entropy losses over a shared label tensor.

    ``forward`` receives three prediction tensors (unpacked as pvam, gsrm and
    vsfd), scores each one against the same flattened labels with
    ``F.cross_entropy`` and combines the results with fixed weights.

    Args:
        label_smoothing: Value forwarded unchanged to every
            ``F.cross_entropy`` call.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, label_smoothing=0.0, **kwargs):
        super().__init__()
        self.label_smoothing = label_smoothing

    def forward(self, preds, batch):
        """Compute the per-branch losses and their weighted total.

        Args:
            preds: Iterable of exactly three tensors, in the order
                (pvam, gsrm, vsfd). Presumably each is (N, num_classes)
                logits, since the labels are flattened to 1-D -- confirm
                against the model head.
            batch: Indexable whose element ``1`` is the label tensor; it is
                flattened to one dimension before use.

        Returns:
            dict with keys ``loss_pvam``, ``loss_gsrm``, ``loss_vsfd`` and
            ``loss``, where
            ``loss = 3.0 * loss_pvam + 0.15 * loss_gsrm + loss_vsfd``.
        """
        pvam_preds, gsrm_preds, vsfd_preds = preds
        target = batch[1].reshape([-1])
        # Targets equal to (size of the last pvam dimension + 1) are skipped
        # by cross_entropy; the same index is used for all three branches.
        skip_index = pvam_preds.shape[-1] + 1

        def branch_loss(logits):
            # One mean-reduced cross-entropy term with the shared settings.
            return F.cross_entropy(
                logits,
                target,
                reduction='mean',
                label_smoothing=self.label_smoothing,
                ignore_index=skip_index,
            )

        loss_dict = {
            'loss_pvam': branch_loss(pvam_preds),
            'loss_gsrm': branch_loss(gsrm_preds),
            'loss_vsfd': branch_loss(vsfd_preds),
        }
        # Operand order is kept as (pvam * 3.0 + gsrm * 0.15) + vsfd so the
        # floating-point result is unchanged.
        loss_dict['loss'] = (loss_dict['loss_pvam'] * 3.0 +
                             loss_dict['loss_gsrm'] * 0.15 +
                             loss_dict['loss_vsfd'])
        return loss_dict
|