123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import torch
- from torch import nn
- from opendet.modeling.backbones import build_backbone
- from opendet.modeling.necks import build_neck
- from opendet.modeling.heads import build_head
- __all__ = ['BaseDetector']
class BaseDetector(nn.Module):
    """Detector assembled from an optional backbone, neck and head.

    A stage is built only when its config section is present and not
    ``None``; stages that were not built are skipped in ``forward``.
    """

    def __init__(self, config):
        """Build the configured stages and chain their channel counts.

        Args:
            config (dict): Model hyper-parameters. Keys read here:
                ``in_channels`` (default 3), ``use_wd`` (default True) and
                the optional ``Backbone``, ``Neck`` and ``Head`` sub-dicts.
                Each sub-dict that is present has its ``in_channels`` entry
                overwritten in place before it is handed to its builder.
        """
        super().__init__()
        in_channels = config.get('in_channels', 3)
        self.use_wd = config.get('use_wd', True)

        # build backbone
        backbone_cfg = config.get('Backbone')
        self.use_backbone = backbone_cfg is not None
        if self.use_backbone:
            backbone_cfg['in_channels'] = in_channels
            self.backbone = build_backbone(backbone_cfg)
            # The next stage consumes whatever the backbone emits.
            in_channels = self.backbone.out_channels

        # build neck
        neck_cfg = config.get('Neck')
        self.use_neck = neck_cfg is not None
        if self.use_neck:
            neck_cfg['in_channels'] = in_channels
            self.neck = build_neck(neck_cfg)
            in_channels = self.neck.out_channels

        # build head
        head_cfg = config.get('Head')
        self.use_head = head_cfg is not None
        if self.use_head:
            head_cfg['in_channels'] = in_channels
            self.head = build_head(head_cfg)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Collect the names reported by the backbone and head.

        Only the backbone and the head are consulted (the neck is not), and
        only when they exist and expose a ``no_weight_decay`` method.

        Returns:
            set: Union of the names reported by those stages; empty when
            ``use_wd`` is false or no stage reports any names.
        """
        if not self.use_wd:
            return set()

        # Accumulate into a fresh set so the collection returned by a
        # sub-module is never mutated, and so set-valued results from the
        # head can always be merged.
        names = set()
        for stage_name in ('backbone', 'head'):
            # A stage that was not configured was never assigned; plain
            # attribute access on nn.Module would raise AttributeError.
            stage = getattr(self, stage_name, None)
            if stage is not None and hasattr(stage, 'no_weight_decay'):
                names.update(stage.no_weight_decay())
        return names

    def forward(self, x, data=None):
        """Run ``x`` through every stage that was built, in order.

        Args:
            x: Input passed to the first available stage.
            data: Extra argument forwarded to the head only, as ``data=``.

        Returns:
            Output of the last stage that was built, or ``x`` unchanged when
            no stage was built.
        """
        if self.use_backbone:
            x = self.backbone(x)
        if self.use_neck:
            x = self.neck(x)
        if self.use_head:
            x = self.head(x, data=data)
        return x
|