cdistnet_loss.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. class CDistNetLoss(nn.Module):
  5. def __init__(self, smoothing=True, ignore_index=0, **kwargs):
  6. super(CDistNetLoss, self).__init__()
  7. if ignore_index >= 0 and not smoothing:
  8. self.loss_func = nn.CrossEntropyLoss(reduction='mean',
  9. ignore_index=ignore_index)
  10. self.smoothing = smoothing
  11. def forward(self, pred, batch):
  12. pred = pred['res']
  13. tgt = batch[1][:, 1:]
  14. pred = pred.reshape([-1, pred.shape[2]])
  15. tgt = tgt.reshape([-1])
  16. if self.smoothing:
  17. eps = 0.1
  18. n_class = pred.shape[1]
  19. one_hot = F.one_hot(tgt.long(), num_classes=pred.shape[1])
  20. torch.set_printoptions(profile='full')
  21. one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
  22. log_prb = F.log_softmax(pred, dim=1)
  23. non_pad_mask = torch.not_equal(
  24. tgt, torch.zeros(tgt.shape, dtype=tgt.dtype,
  25. device=tgt.device))
  26. loss = -(one_hot * log_prb).sum(dim=1)
  27. loss = loss.masked_select(non_pad_mask).mean()
  28. else:
  29. loss = self.loss_func(pred, tgt)
  30. return {'loss': loss}