det_basic_loss.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import torch
  22. from torch import nn
  23. import torch.nn.functional as F
  24. class BalanceLoss(nn.Module):
  25. def __init__(
  26. self,
  27. balance_loss=True,
  28. main_loss_type='DiceLoss',
  29. negative_ratio=3,
  30. return_origin=False,
  31. eps=1e-6,
  32. **kwargs,
  33. ):
  34. """
  35. The BalanceLoss for Differentiable Binarization text detection
  36. args:
  37. balance_loss (bool): whether balance loss or not, default is True
  38. main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
  39. 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
  40. negative_ratio (int|float): float, default is 3.
  41. return_origin (bool): whether return unbalanced loss or not, default is False.
  42. eps (float): default is 1e-6.
  43. """
  44. super(BalanceLoss, self).__init__()
  45. self.balance_loss = balance_loss
  46. self.main_loss_type = main_loss_type
  47. self.negative_ratio = negative_ratio
  48. self.return_origin = return_origin
  49. self.eps = eps
  50. if self.main_loss_type == 'CrossEntropy':
  51. self.loss = nn.CrossEntropyLoss()
  52. elif self.main_loss_type == 'Euclidean':
  53. self.loss = nn.MSELoss()
  54. elif self.main_loss_type == 'DiceLoss':
  55. self.loss = DiceLoss(self.eps)
  56. elif self.main_loss_type == 'BCELoss':
  57. self.loss = BCELoss(reduction='none')
  58. elif self.main_loss_type == 'MaskL1Loss':
  59. self.loss = MaskL1Loss(self.eps)
  60. else:
  61. loss_type = [
  62. 'CrossEntropy',
  63. 'DiceLoss',
  64. 'Euclidean',
  65. 'BCELoss',
  66. 'MaskL1Loss',
  67. ]
  68. raise Exception(
  69. 'main_loss_type in BalanceLoss() can only be one of {}'.format(
  70. loss_type))
  71. def forward(self, pred, gt, mask=None):
  72. """
  73. The BalanceLoss for Differentiable Binarization text detection
  74. args:
  75. pred (variable): predicted feature maps.
  76. gt (variable): ground truth feature maps.
  77. mask (variable): masked maps.
  78. return: (variable) balanced loss
  79. """
  80. positive = gt * mask
  81. negative = (1 - gt) * mask
  82. positive_count = int(positive.sum())
  83. negative_count = int(
  84. min(negative.sum(), positive_count * self.negative_ratio))
  85. loss = self.loss(pred, gt, mask=mask)
  86. if not self.balance_loss:
  87. return loss
  88. positive_loss = positive * loss
  89. negative_loss = negative * loss
  90. negative_loss = negative_loss.reshape(-1)
  91. if negative_count > 0:
  92. negative_loss, _ = negative_loss.topk(negative_count)
  93. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
  94. positive_count + negative_count + self.eps)
  95. else:
  96. balance_loss = positive_loss.sum() / (positive_count + self.eps)
  97. if self.return_origin:
  98. return balance_loss, loss
  99. return balance_loss
  100. class DiceLoss(nn.Module):
  101. def __init__(self, eps=1e-6):
  102. super(DiceLoss, self).__init__()
  103. self.eps = eps
  104. def forward(self, pred, gt, mask, weights=None):
  105. """
  106. DiceLoss function.
  107. """
  108. assert pred.shape == gt.shape
  109. assert pred.shape == mask.shape
  110. if weights is not None:
  111. assert weights.shape == mask.shape
  112. mask = weights * mask
  113. intersection = torch.sum(pred * gt * mask)
  114. union = torch.sum(pred * mask) + torch.sum(gt * mask) + self.eps
  115. loss = 1 - 2.0 * intersection / union
  116. assert loss <= 1
  117. return loss
  118. class MaskL1Loss(nn.Module):
  119. def __init__(self, eps=1e-6):
  120. super(MaskL1Loss, self).__init__()
  121. self.eps = eps
  122. def forward(self, pred, gt, mask):
  123. """
  124. Mask L1 Loss
  125. """
  126. loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  127. loss = torch.mean(loss)
  128. return loss
  129. class BCELoss(nn.Module):
  130. def __init__(self, reduction='mean'):
  131. super(BCELoss, self).__init__()
  132. self.reduction = reduction
  133. def forward(self, input, label, mask=None, weight=None, name=None):
  134. loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
  135. return loss