__init__.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import copy
  2. from importlib import import_module
  3. from torch import nn
  4. name_to_module = {
  5. 'ABINetLoss': '.abinet_loss',
  6. 'ARLoss': '.ar_loss',
  7. 'CDistNetLoss': '.cdistnet_loss',
  8. 'CELoss': '.ce_loss',
  9. 'CPPDLoss': '.cppd_loss',
  10. 'CTCLoss': '.ctc_loss',
  11. 'IGTRLoss': '.igtr_loss',
  12. 'LISTERLoss': '.lister_loss',
  13. 'LPVLoss': '.lpv_loss',
  14. 'MGPLoss': '.mgp_loss',
  15. 'PARSeqLoss': '.parseq_loss',
  16. 'RobustScannerLoss': '.robustscanner_loss',
  17. 'SEEDLoss': '.seed_loss',
  18. 'SMTRLoss': '.smtr_loss',
  19. 'SRNLoss': '.srn_loss',
  20. 'VisionLANLoss': '.visionlan_loss',
  21. 'CAMLoss': '.cam_loss',
  22. }
  23. def build_loss(config):
  24. config = copy.deepcopy(config)
  25. module_name = config.pop('name')
  26. if module_name in globals():
  27. module_class = globals()[module_name]
  28. else:
  29. assert module_name in name_to_module, Exception(
  30. '{} is not supported. The losses in {} are supportes'.format(
  31. module_name, list(name_to_module.keys())))
  32. module_path = name_to_module[module_name]
  33. module = import_module(module_path, package=__package__)
  34. module_class = getattr(module, module_name)
  35. return module_class(**config)
  36. class GTCLoss(nn.Module):
  37. def __init__(self,
  38. gtc_loss,
  39. gtc_weight=1.0,
  40. ctc_weight=1.0,
  41. zero_infinity=True,
  42. **kwargs):
  43. super(GTCLoss, self).__init__()
  44. # 动态构建CTCLoss
  45. ctc_config = {'name': 'CTCLoss', 'zero_infinity': zero_infinity}
  46. self.ctc_loss = build_loss(ctc_config)
  47. # 构建GTC损失
  48. self.gtc_loss = build_loss(gtc_loss)
  49. self.gtc_weight = gtc_weight
  50. self.ctc_weight = ctc_weight
  51. def forward(self, predicts, batch):
  52. ctc_loss = self.ctc_loss(predicts['ctc_pred'],
  53. [None] + batch[-2:])['loss']
  54. gtc_loss = self.gtc_loss(predicts['gtc_pred'], batch[:-2])['loss']
  55. return {
  56. 'loss': self.ctc_weight * ctc_loss + self.gtc_weight * gtc_loss,
  57. 'ctc_loss': ctc_loss,
  58. 'gtc_loss': gtc_loss
  59. }