123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import torch.nn as nn
- from importlib import import_module
- __all__ = ['build_decoder']
- class_to_module = {
- 'ABINetDecoder': '.abinet_decoder',
- 'ASTERDecoder': '.aster_decoder',
- 'CDistNetDecoder': '.cdistnet_decoder',
- 'CPPDDecoder': '.cppd_decoder',
- 'RCTCDecoder': '.rctc_decoder',
- 'CTCDecoder': '.ctc_decoder',
- 'DANDecoder': '.dan_decoder',
- 'IGTRDecoder': '.igtr_decoder',
- 'LISTERDecoder': '.lister_decoder',
- 'LPVDecoder': '.lpv_decoder',
- 'MGPDecoder': '.mgp_decoder',
- 'NRTRDecoder': '.nrtr_decoder',
- 'PARSeqDecoder': '.parseq_decoder',
- 'RobustScannerDecoder': '.robustscanner_decoder',
- 'SARDecoder': '.sar_decoder',
- 'SMTRDecoder': '.smtr_decoder',
- 'SMTRDecoderNumAttn': '.smtr_decoder_nattn',
- 'SRNDecoder': '.srn_decoder',
- 'VisionLANDecoder': '.visionlan_decoder',
- 'MATRNDecoder': '.matrn_decoder',
- 'CAMDecoder': '.cam_decoder',
- 'OTEDecoder': '.ote_decoder',
- 'BUSDecoder': '.bus_decoder',
- 'DptrParseq': '.dptr_parseq_clip_b_decoder',
- }
- def build_decoder(config):
- module_name = config.pop('name')
- # Check if the class is defined in current module (e.g., GTCDecoder)
- if module_name in globals():
- module_class = globals()[module_name]
- else:
- if module_name not in class_to_module:
- raise ValueError(f'Unsupported decoder: {module_name}')
- module_str = class_to_module[module_name]
- # Dynamically import the module and get the class
- module = import_module(module_str, package=__package__)
- module_class = getattr(module, module_name)
- return module_class(**config)
- class GTCDecoder(nn.Module):
- def __init__(self,
- in_channels,
- gtc_decoder,
- ctc_decoder,
- detach=True,
- infer_gtc=False,
- out_channels=0,
- **kwargs):
- super(GTCDecoder, self).__init__()
- self.detach = detach
- self.infer_gtc = infer_gtc
- if infer_gtc:
- gtc_decoder['out_channels'] = out_channels[0]
- ctc_decoder['out_channels'] = out_channels[1]
- gtc_decoder['in_channels'] = in_channels
- ctc_decoder['in_channels'] = in_channels
- self.gtc_decoder = build_decoder(gtc_decoder)
- else:
- ctc_decoder['in_channels'] = in_channels
- ctc_decoder['out_channels'] = out_channels
- self.ctc_decoder = build_decoder(ctc_decoder)
- def forward(self, x, data=None):
- ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
- data=data)
- if self.training or self.infer_gtc:
- gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2),
- data=data)
- return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
- else:
- return ctc_pred
- class GTCDecoderTwo(nn.Module):
- def __init__(self,
- in_channels,
- gtc_decoder,
- ctc_decoder,
- infer_gtc=False,
- out_channels=0,
- **kwargs):
- super(GTCDecoderTwo, self).__init__()
- self.infer_gtc = infer_gtc
- gtc_decoder['out_channels'] = out_channels[0]
- ctc_decoder['out_channels'] = out_channels[1]
- gtc_decoder['in_channels'] = in_channels
- ctc_decoder['in_channels'] = in_channels
- self.gtc_decoder = build_decoder(gtc_decoder)
- self.ctc_decoder = build_decoder(ctc_decoder)
- def forward(self, x, data=None):
- x_ctc, x_gtc = x
- ctc_pred = self.ctc_decoder(x_ctc, data=data)
- if self.training or self.infer_gtc:
- gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
- data=data)
- return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
- else:
- return ctc_pred
|