__init__.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import torch.nn as nn
  2. from importlib import import_module
  3. __all__ = ['build_decoder']
  4. class_to_module = {
  5. 'ABINetDecoder': '.abinet_decoder',
  6. 'ASTERDecoder': '.aster_decoder',
  7. 'CDistNetDecoder': '.cdistnet_decoder',
  8. 'CPPDDecoder': '.cppd_decoder',
  9. 'RCTCDecoder': '.rctc_decoder',
  10. 'CTCDecoder': '.ctc_decoder',
  11. 'DANDecoder': '.dan_decoder',
  12. 'IGTRDecoder': '.igtr_decoder',
  13. 'LISTERDecoder': '.lister_decoder',
  14. 'LPVDecoder': '.lpv_decoder',
  15. 'MGPDecoder': '.mgp_decoder',
  16. 'NRTRDecoder': '.nrtr_decoder',
  17. 'PARSeqDecoder': '.parseq_decoder',
  18. 'RobustScannerDecoder': '.robustscanner_decoder',
  19. 'SARDecoder': '.sar_decoder',
  20. 'SMTRDecoder': '.smtr_decoder',
  21. 'SMTRDecoderNumAttn': '.smtr_decoder_nattn',
  22. 'SRNDecoder': '.srn_decoder',
  23. 'VisionLANDecoder': '.visionlan_decoder',
  24. 'MATRNDecoder': '.matrn_decoder',
  25. 'CAMDecoder': '.cam_decoder',
  26. 'OTEDecoder': '.ote_decoder',
  27. 'BUSDecoder': '.bus_decoder',
  28. 'DptrParseq': '.dptr_parseq_clip_b_decoder',
  29. }
  30. def build_decoder(config):
  31. module_name = config.pop('name')
  32. # Check if the class is defined in current module (e.g., GTCDecoder)
  33. if module_name in globals():
  34. module_class = globals()[module_name]
  35. else:
  36. if module_name not in class_to_module:
  37. raise ValueError(f'Unsupported decoder: {module_name}')
  38. module_str = class_to_module[module_name]
  39. # Dynamically import the module and get the class
  40. module = import_module(module_str, package=__package__)
  41. module_class = getattr(module, module_name)
  42. return module_class(**config)
  43. class GTCDecoder(nn.Module):
  44. def __init__(self,
  45. in_channels,
  46. gtc_decoder,
  47. ctc_decoder,
  48. detach=True,
  49. infer_gtc=False,
  50. out_channels=0,
  51. **kwargs):
  52. super(GTCDecoder, self).__init__()
  53. self.detach = detach
  54. self.infer_gtc = infer_gtc
  55. if infer_gtc:
  56. gtc_decoder['out_channels'] = out_channels[0]
  57. ctc_decoder['out_channels'] = out_channels[1]
  58. gtc_decoder['in_channels'] = in_channels
  59. ctc_decoder['in_channels'] = in_channels
  60. self.gtc_decoder = build_decoder(gtc_decoder)
  61. else:
  62. ctc_decoder['in_channels'] = in_channels
  63. ctc_decoder['out_channels'] = out_channels
  64. self.ctc_decoder = build_decoder(ctc_decoder)
  65. def forward(self, x, data=None):
  66. ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
  67. data=data)
  68. if self.training or self.infer_gtc:
  69. gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2),
  70. data=data)
  71. return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
  72. else:
  73. return ctc_pred
  74. class GTCDecoderTwo(nn.Module):
  75. def __init__(self,
  76. in_channels,
  77. gtc_decoder,
  78. ctc_decoder,
  79. infer_gtc=False,
  80. out_channels=0,
  81. **kwargs):
  82. super(GTCDecoderTwo, self).__init__()
  83. self.infer_gtc = infer_gtc
  84. gtc_decoder['out_channels'] = out_channels[0]
  85. ctc_decoder['out_channels'] = out_channels[1]
  86. gtc_decoder['in_channels'] = in_channels
  87. ctc_decoder['in_channels'] = in_channels
  88. self.gtc_decoder = build_decoder(gtc_decoder)
  89. self.ctc_decoder = build_decoder(ctc_decoder)
  90. def forward(self, x, data=None):
  91. x_ctc, x_gtc = x
  92. ctc_pred = self.ctc_decoder(x_ctc, data=data)
  93. if self.training or self.infer_gtc:
  94. gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
  95. data=data)
  96. return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
  97. else:
  98. return ctc_pred