# cppd_loss.py — CPPD (character counting / position) loss for text recognition.
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. class CPPDLoss(nn.Module):
  5. def __init__(self,
  6. smoothing=False,
  7. ignore_index=100,
  8. pos_len=False,
  9. sideloss_weight=1.0,
  10. max_len=25,
  11. **kwargs):
  12. super(CPPDLoss, self).__init__()
  13. self.edge_ce = nn.CrossEntropyLoss(reduction='mean',
  14. ignore_index=ignore_index)
  15. self.char_node_ce = nn.CrossEntropyLoss(reduction='mean')
  16. if pos_len:
  17. self.pos_node_ce = nn.CrossEntropyLoss(reduction='mean',
  18. ignore_index=ignore_index)
  19. else:
  20. self.pos_node_ce = nn.BCEWithLogitsLoss(reduction='mean')
  21. self.smoothing = smoothing
  22. self.ignore_index = ignore_index
  23. self.pos_len = pos_len
  24. self.sideloss_weight = sideloss_weight
  25. self.max_len = max_len + 1
  26. def label_smoothing_ce(self, preds, targets):
  27. zeros_ = torch.zeros_like(targets)
  28. ignore_index_ = zeros_ + self.ignore_index
  29. non_pad_mask = torch.not_equal(targets, ignore_index_)
  30. tgts = torch.where(targets == ignore_index_, zeros_, targets)
  31. eps = 0.1
  32. n_class = preds.shape[1]
  33. one_hot = F.one_hot(tgts, preds.shape[1])
  34. one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
  35. log_prb = F.log_softmax(preds, dim=1)
  36. loss = -(one_hot * log_prb).sum(dim=1)
  37. loss = loss.masked_select(non_pad_mask).mean()
  38. return loss
  39. def forward(self, pred, batch):
  40. node_feats, edge_feats = pred
  41. node_tgt = batch[2]
  42. char_tgt = batch[1]
  43. # updated code
  44. char_num_label = torch.clip(node_tgt[:, :-self.max_len].flatten(0, 1),
  45. 0, node_feats[0].shape[-1] - 1)
  46. loss_char_node = self.char_node_ce(node_feats[0].flatten(0, 1),
  47. char_num_label)
  48. if self.pos_len:
  49. loss_pos_node = self.pos_node_ce(
  50. node_feats[1].flatten(0, 1),
  51. node_tgt[:, -self.max_len:].flatten(0, 1))
  52. else:
  53. loss_pos_node = self.pos_node_ce(
  54. node_feats[1].flatten(0, 1),
  55. node_tgt[:, -self.max_len:].flatten(0, 1).float())
  56. loss_node = loss_char_node + loss_pos_node
  57. # -----
  58. edge_feats = edge_feats.flatten(0, 1)
  59. char_tgt = char_tgt.flatten(0, 1)
  60. if self.smoothing:
  61. loss_edge = self.label_smoothing_ce(edge_feats, char_tgt)
  62. else:
  63. loss_edge = self.edge_ce(edge_feats, char_tgt)
  64. return {
  65. 'loss': self.sideloss_weight * loss_node + loss_edge,
  66. 'loss_node': self.sideloss_weight * loss_node,
  67. 'loss_edge': loss_edge,
  68. }