1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import copy
- from importlib import import_module
- from torch import nn
# Registry used by ``build_loss``: maps each loss class name to the
# package-relative module that is imported on demand to obtain that class.
# Insertion order is kept as-is because the keys are listed in the
# "not supported" error message.
name_to_module = {
    'ABINetLoss': '.abinet_loss',
    'ARLoss': '.ar_loss',
    'CDistNetLoss': '.cdistnet_loss',
    'CELoss': '.ce_loss',
    'CPPDLoss': '.cppd_loss',
    'CTCLoss': '.ctc_loss',
    'IGTRLoss': '.igtr_loss',
    'LISTERLoss': '.lister_loss',
    'LPVLoss': '.lpv_loss',
    'MGPLoss': '.mgp_loss',
    'PARSeqLoss': '.parseq_loss',
    'RobustScannerLoss': '.robustscanner_loss',
    'SEEDLoss': '.seed_loss',
    'SMTRLoss': '.smtr_loss',
    'SRNLoss': '.srn_loss',
    'VisionLANLoss': '.visionlan_loss',
    'CAMLoss': '.cam_loss',
}
def build_loss(config):
    """Instantiate the loss described by ``config``.

    ``config`` is deep-copied first, so the caller's mapping is never
    modified. Its ``'name'`` entry selects the class; every remaining entry
    is passed to the class constructor as a keyword argument.

    The class is resolved in two steps: if a module-level global with that
    name exists it is used directly; otherwise the name is looked up in
    ``name_to_module`` and the class is imported on demand from the
    corresponding package-relative module.

    Args:
        config: Mapping holding ``'name'`` plus constructor keyword
            arguments.

    Returns:
        A new instance of the selected class.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
        ValueError: If the name is neither a module-level global nor a key
            of ``name_to_module``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if module_name in globals():
        module_class = globals()[module_name]
    else:
        if module_name not in name_to_module:
            # Raise explicitly instead of using ``assert``: assertions are
            # removed under ``python -O``, which would turn this check into
            # an uninformative KeyError on the lookup below.
            raise ValueError(
                '{} is not supported. The losses in {} are supported'.format(
                    module_name, list(name_to_module)))
        module_path = name_to_module[module_name]
        # Relative import, resolved against the package containing this file.
        module = import_module(module_path, package=__package__)
        module_class = getattr(module, module_name)
    return module_class(**config)
class GTCLoss(nn.Module):
    """Weighted combination of a CTC loss and a configurable GTC loss.

    Both sub-losses are created through ``build_loss``. The CTC branch is
    always a ``'CTCLoss'``; the GTC branch is whatever ``gtc_loss``
    describes.

    Args:
        gtc_loss: Config mapping for the GTC branch, forwarded unchanged to
            ``build_loss``.
        gtc_weight: Multiplier applied to the GTC loss in the total.
        ctc_weight: Multiplier applied to the CTC loss in the total.
        zero_infinity: Forwarded to the CTC loss config.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 gtc_loss,
                 gtc_weight=1.0,
                 ctc_weight=1.0,
                 zero_infinity=True,
                 **kwargs):
        super().__init__()
        # Build the CTC loss dynamically via the factory.
        self.ctc_loss = build_loss(
            dict(name='CTCLoss', zero_infinity=zero_infinity))
        # Build the GTC loss from the caller-supplied config.
        self.gtc_loss = build_loss(gtc_loss)
        self.gtc_weight, self.ctc_weight = gtc_weight, ctc_weight

    def forward(self, predicts, batch):
        """Compute both sub-losses and their weighted sum.

        Args:
            predicts: Mapping with ``'ctc_pred'`` and ``'gtc_pred'`` entries.
            batch: Sequence whose last two items go to the CTC loss (after a
                leading ``None`` placeholder) and whose remaining items go
                to the GTC loss. It is concatenated to a list, so it is
                presumably a list; confirm against the data loader.

        Returns:
            Dict with ``'loss'`` (the weighted sum), ``'ctc_loss'`` and
            ``'gtc_loss'``. Each sub-loss must return a mapping that has a
            ``'loss'`` entry.
        """
        # CTC branch: a None placeholder followed by the last two batch items.
        ctc_pred = predicts['ctc_pred']
        ctc_inputs = [None] + batch[-2:]
        ctc_term = self.ctc_loss(ctc_pred, ctc_inputs)['loss']

        # GTC branch: everything except the last two batch items.
        gtc_pred = predicts['gtc_pred']
        gtc_inputs = batch[:-2]
        gtc_term = self.gtc_loss(gtc_pred, gtc_inputs)['loss']

        total = self.ctc_weight * ctc_term + self.gtc_weight * gtc_term
        return dict(loss=total, ctc_loss=ctc_term, gtc_loss=gtc_term)
|