base_detector.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import torch
  2. from torch import nn
  3. from opendet.modeling.backbones import build_backbone
  4. from opendet.modeling.necks import build_neck
  5. from opendet.modeling.heads import build_head
  6. __all__ = ['BaseDetector']
  7. class BaseDetector(nn.Module):
  8. def __init__(self, config):
  9. """the module for OCR.
  10. args:
  11. config (dict): the super parameters for module.
  12. """
  13. super(BaseDetector, self).__init__()
  14. in_channels = config.get('in_channels', 3)
  15. self.use_wd = config.get('use_wd', True)
  16. # build backbone
  17. if 'Backbone' not in config or config['Backbone'] is None:
  18. self.use_backbone = False
  19. else:
  20. self.use_backbone = True
  21. config['Backbone']['in_channels'] = in_channels
  22. self.backbone = build_backbone(config['Backbone'])
  23. in_channels = self.backbone.out_channels
  24. # build neck
  25. if 'Neck' not in config or config['Neck'] is None:
  26. self.use_neck = False
  27. else:
  28. self.use_neck = True
  29. config['Neck']['in_channels'] = in_channels
  30. self.neck = build_neck(config['Neck'])
  31. in_channels = self.neck.out_channels
  32. # build head
  33. if 'Head' not in config or config['Head'] is None:
  34. self.use_head = False
  35. else:
  36. self.use_head = True
  37. config['Head']['in_channels'] = in_channels
  38. self.head = build_head(config['Head'])
  39. @torch.jit.ignore
  40. def no_weight_decay(self):
  41. if self.use_wd:
  42. if hasattr(self.backbone, 'no_weight_decay'):
  43. no_weight_decay = self.backbone.no_weight_decay()
  44. else:
  45. no_weight_decay = {}
  46. if hasattr(self.head, 'no_weight_decay'):
  47. no_weight_decay.update(self.head.no_weight_decay())
  48. return no_weight_decay
  49. else:
  50. return {}
  51. def forward(self, x, data=None):
  52. if self.use_backbone:
  53. x = self.backbone(x)
  54. if self.use_neck:
  55. x = self.neck(x)
  56. if self.use_head:
  57. x = self.head(x, data=data)
  58. return x