# ce_loss.py

import torch
import torch.nn.functional as F
from torch import nn
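
# Assumed batch layout (inferred from the usage below, not documented in the
# original file): batch[1] holds padded target indices where index 0 is PAD
# and column 0 is the SOS token; batch[2] holds per-sample target lengths.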


class CELoss(nn.Module):
    """Cross-entropy loss wrapper for text-recognition heads.

    Supports plain CE over cropped targets (NRTR-style), CE over the full
    padded sequence (ViTSTR-style, ``with_all=True``), per-branch CE for
    dict outputs (ABINet-style), and optional manual label smoothing.
    """

    def __init__(self,
                 smoothing=False,
                 with_all=False,
                 ignore_index=-1,
                 **kwargs):
        super(CELoss, self).__init__()
        if ignore_index >= 0:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean',
                                                 ignore_index=ignore_index)
        else:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.smoothing = smoothing
        self.with_all = with_all

    def forward(self, pred, batch):
        pred = pred['res']
        if isinstance(pred, dict):  # for ABINet: one loss per output branch
            loss = {}
            loss_sum = []
            for name, logits in pred.items():
                if isinstance(logits, list):
                    # Iterative branches return a list of logits; tile the
                    # targets so every iteration is supervised.
                    logit_num = len(logits)
                    all_tgt = torch.cat([batch[1]] * logit_num, 0)
                    all_logits = torch.cat(logits, 0)
                    flt_logits = all_logits.reshape([-1, all_logits.shape[2]])
                    flt_tgt = all_tgt.reshape([-1])
                else:
                    flt_logits = logits.reshape([-1, logits.shape[2]])
                    flt_tgt = batch[1].reshape([-1])
                loss[name + '_loss'] = self.loss_func(flt_logits, flt_tgt)
                loss_sum.append(loss[name + '_loss'])
            loss['loss'] = sum(loss_sum)
            return loss
        else:
            if self.with_all:  # for ViTSTR
                tgt = batch[1]
                pred = pred.reshape([-1, pred.shape[2]])
                tgt = tgt.reshape([-1])
                loss = self.loss_func(pred, tgt)
                return {'loss': loss}
            else:  # for NRTR: drop the SOS column and crop to the longest target
                max_len = batch[2].max()
                tgt = batch[1][:, 1:2 + max_len]
                pred = pred.reshape([-1, pred.shape[2]])
                tgt = tgt.reshape([-1])
                if self.smoothing:
                    # Manual label smoothing: the target class keeps 1 - eps,
                    # and eps is spread uniformly over the remaining
                    # n_class - 1 classes.
                    eps = 0.1
                    n_class = pred.shape[1]
                    one_hot = F.one_hot(tgt, n_class)
                    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
                    log_prb = F.log_softmax(pred, dim=1)
                    # Positions whose target is 0 (PAD) are excluded from the mean.
                    non_pad_mask = torch.not_equal(
                        tgt,
                        torch.zeros(tgt.shape,
                                    dtype=tgt.dtype,
                                    device=tgt.device))
                    loss = -(one_hot * log_prb).sum(dim=1)
                    loss = loss.masked_select(non_pad_mask).mean()
                else:
                    loss = self.loss_func(pred, tgt)
                return {'loss': loss}
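

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original file).
# Shapes and the charset size below are made up; the batch follows the layout
# assumed above. Note that the manual smoothing here spreads eps over the
# n_class - 1 non-target classes, whereas PyTorch's built-in
# nn.CrossEntropyLoss(label_smoothing=eps) spreads it over all n_class
# classes, so the two are close but not identical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    B, T, C = 2, 8, 37          # batch size, decode length, charset size (assumed)
    logits = torch.randn(B, T, C)
    targets = torch.randint(1, C, (B, T + 1))  # column 0 acts as SOS
    lengths = torch.tensor([5, 7])

    loss_fn = CELoss(smoothing=True)
    out = loss_fn({'res': logits}, [None, targets, lengths])
    print(out['loss'])  # a scalar tensor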