base_recognizer.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import torch
  2. from torch import nn
  3. from openrec.modeling.decoders import build_decoder
  4. from openrec.modeling.encoders import build_encoder
  5. from openrec.modeling.transforms import build_transform
  6. __all__ = ['BaseRecognizer']
  7. class BaseRecognizer(nn.Module):
  8. def __init__(self, config):
  9. """the module for OCR.
  10. args:
  11. config (dict): the super parameters for module.
  12. """
  13. super(BaseRecognizer, self).__init__()
  14. in_channels = config.get('in_channels', 3)
  15. self.use_wd = config.get('use_wd', True)
  16. # build transfrom,
  17. # for rec, transfrom can be TPS,None
  18. if 'Transform' not in config or config['Transform'] is None:
  19. self.use_transform = False
  20. else:
  21. self.use_transform = True
  22. config['Transform']['in_channels'] = in_channels
  23. self.transform = build_transform(config['Transform'])
  24. in_channels = self.transform.out_channels
  25. # build backbone
  26. if 'Encoder' not in config or config['Encoder'] is None:
  27. self.use_encoder = False
  28. else:
  29. self.use_encoder = True
  30. config['Encoder']['in_channels'] = in_channels
  31. self.encoder = build_encoder(config['Encoder'])
  32. in_channels = self.encoder.out_channels
  33. # build decoder
  34. if 'Decoder' not in config or config['Decoder'] is None:
  35. self.use_decoder = False
  36. else:
  37. self.use_decoder = True
  38. config['Decoder']['in_channels'] = in_channels
  39. self.decoder = build_decoder(config['Decoder'])
  40. @torch.jit.ignore
  41. def no_weight_decay(self):
  42. if self.use_wd:
  43. if hasattr(self.encoder, 'no_weight_decay'):
  44. no_weight_decay = self.encoder.no_weight_decay()
  45. else:
  46. no_weight_decay = {}
  47. if hasattr(self.decoder, 'no_weight_decay'):
  48. no_weight_decay.update(self.decoder.no_weight_decay())
  49. return no_weight_decay
  50. else:
  51. return {}
  52. def forward(self, x, data=None):
  53. if self.use_transform:
  54. x = self.transform(x)
  55. if self.use_encoder:
  56. x = self.encoder(x)
  57. if self.use_decoder:
  58. x = self.decoder(x, data=data)
  59. return x