db_loss.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import torch
  22. from torch import nn
  23. from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
  24. class DBLoss(nn.Module):
  25. """
  26. Differentiable Binarization (DB) Loss Function
  27. args:
  28. param (dict): the super paramter for DB Loss
  29. """
  30. def __init__(self,
  31. balance_loss=True,
  32. main_loss_type='DiceLoss',
  33. alpha=5,
  34. beta=10,
  35. ohem_ratio=3,
  36. eps=1e-6,
  37. **kwargs):
  38. super(DBLoss, self).__init__()
  39. self.alpha = alpha
  40. self.beta = beta
  41. self.dice_loss = DiceLoss(eps=eps)
  42. self.l1_loss = MaskL1Loss(eps=eps)
  43. self.bce_loss = BalanceLoss(balance_loss=balance_loss,
  44. main_loss_type=main_loss_type,
  45. negative_ratio=ohem_ratio)
  46. def forward(self, predicts, labels):
  47. predict_maps = predicts['maps']
  48. label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
  49. 1:]
  50. shrink_maps = predict_maps[:, 0, :, :]
  51. threshold_maps = predict_maps[:, 1, :, :]
  52. binary_maps = predict_maps[:, 2, :, :]
  53. loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
  54. label_shrink_mask)
  55. loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
  56. label_threshold_mask)
  57. loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
  58. label_shrink_mask)
  59. loss_shrink_maps = self.alpha * loss_shrink_maps
  60. loss_threshold_maps = self.beta * loss_threshold_maps
  61. # CBN loss
  62. if 'distance_maps' in predicts.keys():
  63. # distance_maps = predicts['distance_maps']
  64. cbn_maps = predicts['cbn_maps']
  65. cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
  66. label_shrink_mask)
  67. else:
  68. # dis_loss = torch.tensor([0.])
  69. cbn_loss = torch.tensor([0.], device=predict_maps.device)
  70. loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
  71. losses = {
  72. 'loss': loss_all + cbn_loss,
  73. 'loss_shrink_maps': loss_shrink_maps,
  74. 'loss_threshold_maps': loss_threshold_maps,
  75. 'loss_binary_maps': loss_binary_maps,
  76. 'loss_cbn': cbn_loss
  77. }
  78. return losses