cam_loss.py

import torch
import torch.nn.functional as F

from .ar_loss import ARLoss


def BanlanceMultiClassCrossEntropyLoss(x_o, x_t):
    """Class-balanced multi-class cross-entropy between predicted per-class
    maps and a soft per-class target mask."""
    # x_o: predicted logits, [B, num_cls, H, W]
    B, num_cls, H, W = x_o.shape
    x_o = x_o.reshape(B, num_cls, H * W).permute(0, 2, 1)  # [B, H*W, num_cls]

    # Generate hard ground-truth labels from the soft mask.
    x_t[x_t > 0.5] = 1
    x_t[x_t <= 0.5] = 0
    fg_x_t = x_t.sum(-1)            # [B, H, W], number of active foreground classes per pixel
    x_t = x_t.argmax(-1)            # [B, H, W], foreground class index per pixel
    x_t[fg_x_t == 0] = num_cls - 1  # pixels with no foreground class become background
    x_t = x_t.reshape(B, H * W)

    # Per-class weights: background keeps weight 1, every foreground class is
    # weighted by the background/foreground pixel ratio of its sample.
    weight = torch.ones((B, num_cls)).type_as(x_o)
    num_bg = (x_t == (num_cls - 1)).sum(-1)  # [B]
    weight[:, :-1] = (num_bg / (H * W - num_bg + 1e-5)).unsqueeze(-1).expand(
        -1, num_cls - 1)

    # Weighted negative log-likelihood of the ground-truth class at each pixel.
    logit = F.log_softmax(x_o, dim=-1)  # [B, H*W, num_cls]
    logit = logit * weight.unsqueeze(1)
    loss = -logit.gather(2, x_t.unsqueeze(-1).long())
    return loss.mean()
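
# Worked weighting example (illustrative numbers only, not taken from the
# original code): with H * W = 64 pixels of which num_bg = 48 are background,
# each foreground class gets weight 48 / (64 - 48) ≈ 3.0 while background
# keeps weight 1.0, so the sparser foreground pixels contribute a comparable
# share of the total loss.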

class CAMLoss(ARLoss):

    def __init__(self, label_smoothing=0.1, loss_weight_binary=1.5, **kwargs):
        super(CAMLoss, self).__init__(label_smoothing=label_smoothing)
        self.label_smoothing = label_smoothing
        self.loss_weight_binary = loss_weight_binary

    def forward(self, pred, batch):
        # The last element of the batch is the binary segmentation mask; the
        # remaining elements are the inputs expected by the recognition loss.
        binary_mask = batch[-1]
        rec_loss = super().forward(pred['rec_output'], batch[:-1])['loss']
        loss_binary = self.loss_weight_binary * BanlanceMultiClassCrossEntropyLoss(
            pred['pred_binary'], binary_mask)
        return {
            'loss': rec_loss + loss_binary,
            'rec_loss': rec_loss,
            'loss_binary': loss_binary
        }
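

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). The tensor
# shapes below -- [B, num_cls, H, W] logits and a [B, H, W, num_cls - 1] soft
# foreground mask -- are assumptions inferred from the shape comments above,
# not a documented contract of the training pipeline. Because of the relative
# import at the top of the file, run it as a module inside its package
# (python -m <package>.cam_loss) rather than as a standalone script.
if __name__ == '__main__':
    B, num_cls, H, W = 2, 5, 8, 8
    pred_binary = torch.randn(B, num_cls, H, W)      # assumed raw per-class logits
    binary_mask = torch.rand(B, H, W, num_cls - 1)   # assumed soft foreground masks in [0, 1]
    print(BanlanceMultiClassCrossEntropyLoss(pred_binary, binary_mask).item())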