import torch
from torch import nn

from openrec.modeling.decoders import build_decoder
from openrec.modeling.encoders import build_encoder
from openrec.modeling.transforms import build_transform

__all__ = ['BaseRecognizer']


class BaseRecognizer(nn.Module):

    def __init__(self, config):
        """Base recognition model for OCR.

        Args:
            config (dict): Model configuration with optional 'Transform',
                'Encoder', and 'Decoder' sub-configs.
        """
        super(BaseRecognizer, self).__init__()
        in_channels = config.get('in_channels', 3)
        self.use_wd = config.get('use_wd', True)
        # build transform,
        # for rec, the transform can be TPS or None
        if 'Transform' not in config or config['Transform'] is None:
            self.use_transform = False
        else:
            self.use_transform = True
            config['Transform']['in_channels'] = in_channels
            self.transform = build_transform(config['Transform'])
            in_channels = self.transform.out_channels
        # build encoder (backbone)
        if 'Encoder' not in config or config['Encoder'] is None:
            self.use_encoder = False
        else:
            self.use_encoder = True
            config['Encoder']['in_channels'] = in_channels
            self.encoder = build_encoder(config['Encoder'])
            in_channels = self.encoder.out_channels
        # build decoder
        if 'Decoder' not in config or config['Decoder'] is None:
            self.use_decoder = False
        else:
            self.use_decoder = True
            config['Decoder']['in_channels'] = in_channels
            self.decoder = build_decoder(config['Decoder'])

    @torch.jit.ignore
    def no_weight_decay(self):
        # Collect the parameter names that the encoder/decoder mark as exempt
        # from weight decay; guard against components that were not built.
        if not self.use_wd:
            return set()
        no_weight_decay = set()
        if self.use_encoder and hasattr(self.encoder, 'no_weight_decay'):
            no_weight_decay.update(self.encoder.no_weight_decay())
        if self.use_decoder and hasattr(self.decoder, 'no_weight_decay'):
            no_weight_decay.update(self.decoder.no_weight_decay())
        return no_weight_decay

    def forward(self, x, data=None):
        if self.use_transform:
            x = self.transform(x)
        if self.use_encoder:
            x = self.encoder(x)
        if self.use_decoder:
            x = self.decoder(x, data=data)
        return x
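

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module) showing how
    # BaseRecognizer is driven purely by its config dict. Every stage
    # ('Transform', 'Encoder', 'Decoder') is optional; with all three left as
    # None, forward() is the identity. A real config would point each key at a
    # component registered in openrec; the exact component names depend on the
    # registries and are not assumed here.
    model = BaseRecognizer({
        'in_channels': 3,
        'Transform': None,
        'Encoder': None,
        'Decoder': None,
    })
    dummy = torch.randn(1, 3, 32, 128)  # (batch, channels, height, width)
    out = model(dummy)
    print(out.shape)  # torch.Size([1, 3, 32, 128]); no stage was built
    # no_weight_decay() lists parameter names to exclude from weight decay
    # when building optimizer parameter groups; it is empty here.
    print(model.no_weight_decay())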