"""Loss factory: build loss modules by name, plus the combined GTCLoss."""
import copy
from importlib import import_module

from torch import nn

# Maps a loss class name to the sibling module (relative to this package)
# that defines it; the module is imported lazily inside build_loss.
name_to_module = {
    'ABINetLoss': '.abinet_loss',
    'ARLoss': '.ar_loss',
    'CDistNetLoss': '.cdistnet_loss',
    'CELoss': '.ce_loss',
    'CPPDLoss': '.cppd_loss',
    'CTCLoss': '.ctc_loss',
    'IGTRLoss': '.igtr_loss',
    'LISTERLoss': '.lister_loss',
    'LPVLoss': '.lpv_loss',
    'MGPLoss': '.mgp_loss',
    'PARSeqLoss': '.parseq_loss',
    'RobustScannerLoss': '.robustscanner_loss',
    'SEEDLoss': '.seed_loss',
    'SMTRLoss': '.smtr_loss',
    'SRNLoss': '.srn_loss',
    'VisionLANLoss': '.visionlan_loss',
    'CAMLoss': '.cam_loss',
}


class UnsupportedLossError(ValueError, AssertionError):
    """Raised by ``build_loss`` when the requested loss name is unknown.

    Subclasses ``ValueError`` (the idiomatic type for a bad config value) and
    ``AssertionError`` (what the previous assert-based check raised), so
    existing ``except AssertionError`` handlers keep working.
    """


def build_loss(config):
    """Instantiate a loss module from a config mapping.

    Args:
        config: Mapping with a ``'name'`` key naming the loss class. Every
            other key is forwarded to that class's constructor as a keyword
            argument. The mapping is deep-copied, so the caller's object is
            never mutated.

    Returns:
        An instance of the named loss class.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        UnsupportedLossError: if the name is neither defined in this module
            nor listed in ``name_to_module``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Names defined directly in this file (e.g. GTCLoss) are resolved first.
    if module_name in globals():
        module_class = globals()[module_name]
    else:
        # Explicit raise instead of ``assert``: asserts are stripped under -O,
        # which would turn this check into an opaque KeyError below.
        if module_name not in name_to_module:
            raise UnsupportedLossError(
                '{} is not supported. The losses in {} are supported'.format(
                    module_name, list(name_to_module.keys())))
        module_path = name_to_module[module_name]
        module = import_module(module_path, package=__package__)
        module_class = getattr(module, module_name)
    return module_class(**config)


class GTCLoss(nn.Module):
    """Weighted sum of a CTC loss and a configurable GTC loss.

    Args:
        gtc_loss: Config mapping for the GTC branch, passed to ``build_loss``.
        gtc_weight: Multiplier applied to the GTC loss term.
        ctc_weight: Multiplier applied to the CTC loss term.
        zero_infinity: Forwarded to the ``CTCLoss`` constructor.
        **kwargs: Accepted and ignored (presumably so extra config keys do
            not break construction — confirm against callers).
    """

    def __init__(self,
                 gtc_loss,
                 gtc_weight=1.0,
                 ctc_weight=1.0,
                 zero_infinity=True,
                 **kwargs):
        super().__init__()
        # Build the CTC branch dynamically through the same registry.
        ctc_config = {'name': 'CTCLoss', 'zero_infinity': zero_infinity}
        self.ctc_loss = build_loss(ctc_config)
        # Build the GTC branch from the caller-supplied config.
        self.gtc_loss = build_loss(gtc_loss)
        self.gtc_weight = gtc_weight
        self.ctc_weight = ctc_weight

    def forward(self, predicts, batch):
        """Compute the combined loss.

        Args:
            predicts: Mapping with ``'ctc_pred'`` and ``'gtc_pred'`` entries.
            batch: Indexable sequence (list or tuple). Its last two entries
                go to the CTC loss behind a leading ``None`` placeholder. All
                earlier entries go to the GTC loss.

        Returns:
            Dict with the weighted total under ``'loss'`` and the unweighted
            branch losses under ``'ctc_loss'`` and ``'gtc_loss'``.
        """
        # ``[None, *...]`` works for tuple batches too; ``[None] + tuple``
        # raised TypeError. The None presumably fills a slot the CTC loss
        # does not read — TODO confirm against CTCLoss.forward.
        ctc_batch = [None, *batch[-2:]]
        ctc_loss = self.ctc_loss(predicts['ctc_pred'], ctc_batch)['loss']
        gtc_loss = self.gtc_loss(predicts['gtc_pred'], batch[:-2])['loss']
        return {
            'loss': self.ctc_weight * ctc_loss + self.gtc_weight * gtc_loss,
            'ctc_loss': ctc_loss,
            'gtc_loss': gtc_loss,
        }