config.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import os
  2. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  3. from collections.abc import Mapping
  4. import yaml
  5. __all__ = ['Config']
  6. class ArgsParser(ArgumentParser):
  7. def __init__(self):
  8. super(ArgsParser,
  9. self).__init__(formatter_class=RawDescriptionHelpFormatter)
  10. self.add_argument('-o',
  11. '--opt',
  12. nargs='*',
  13. help='set configuration options')
  14. self.add_argument('--local_rank')
  15. def parse_args(self, argv=None):
  16. args = super(ArgsParser, self).parse_args(argv)
  17. assert args.config is not None, 'Please specify --config=configure_file_path.'
  18. args.opt = self._parse_opt(args.opt)
  19. return args
  20. def _parse_opt(self, opts):
  21. config = {}
  22. if not opts:
  23. return config
  24. for s in opts:
  25. s = s.strip()
  26. k, v = s.split('=', 1)
  27. if '.' not in k:
  28. config[k] = yaml.load(v, Loader=yaml.Loader)
  29. else:
  30. keys = k.split('.')
  31. if keys[0] not in config:
  32. config[keys[0]] = {}
  33. cur = config[keys[0]]
  34. for idx, key in enumerate(keys[1:]):
  35. if idx == len(keys) - 2:
  36. cur[key] = yaml.load(v, Loader=yaml.Loader)
  37. else:
  38. cur[key] = {}
  39. cur = cur[key]
  40. return config
  41. class AttrDict(dict):
  42. """Single level attribute dict, NOT recursive."""
  43. def __init__(self, **kwargs):
  44. super(AttrDict, self).__init__()
  45. super(AttrDict, self).update(kwargs)
  46. def __getattr__(self, key):
  47. if key in self:
  48. return self[key]
  49. raise AttributeError("object has no attribute '{}'".format(key))
  50. def _merge_dict(config, merge_dct):
  51. """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
  52. updating only top-level keys, dict_merge recurses down into dicts nested to
  53. an arbitrary depth, updating keys. The ``merge_dct`` is merged into
  54. ``dct``.
  55. Args:
  56. config: dict onto which the merge is executed
  57. merge_dct: dct merged into config
  58. Returns: dct
  59. """
  60. for key, value in merge_dct.items():
  61. sub_keys = key.split('.')
  62. key = sub_keys[0]
  63. if key in config and len(sub_keys) > 1:
  64. _merge_dict(config[key], {'.'.join(sub_keys[1:]): value})
  65. elif key in config and isinstance(config[key], dict) and isinstance(
  66. value, Mapping):
  67. _merge_dict(config[key], value)
  68. else:
  69. config[key] = value
  70. return config
  71. def print_dict(cfg, print_func=print, delimiter=0):
  72. """Recursively visualize a dict and indenting acrrording by the
  73. relationship of keys."""
  74. for k, v in sorted(cfg.items()):
  75. if isinstance(v, dict):
  76. print_func('{}{} : '.format(delimiter * ' ', str(k)))
  77. print_dict(v, print_func, delimiter + 4)
  78. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  79. print_func('{}{} : '.format(delimiter * ' ', str(k)))
  80. for value in v:
  81. print_dict(value, print_func, delimiter + 4)
  82. else:
  83. print_func('{}{} : {}'.format(delimiter * ' ', k, v))
  84. class Config(object):
  85. def __init__(self, config_path, BASE_KEY='_BASE_'):
  86. self.BASE_KEY = BASE_KEY
  87. self.cfg = self._load_config_with_base(config_path)
  88. def _load_config_with_base(self, file_path):
  89. """Load config from file.
  90. Args:
  91. file_path (str): Path of the config file to be loaded.
  92. Returns: global config
  93. """
  94. _, ext = os.path.splitext(file_path)
  95. assert ext in ['.yml', '.yaml'], 'only support yaml files for now'
  96. with open(file_path) as f:
  97. file_cfg = yaml.load(f, Loader=yaml.Loader)
  98. # NOTE: cfgs outside have higher priority than cfgs in _BASE_
  99. if self.BASE_KEY in file_cfg:
  100. all_base_cfg = AttrDict()
  101. base_ymls = list(file_cfg[self.BASE_KEY])
  102. for base_yml in base_ymls:
  103. if base_yml.startswith('~'):
  104. base_yml = os.path.expanduser(base_yml)
  105. if not base_yml.startswith('/'):
  106. base_yml = os.path.join(os.path.dirname(file_path),
  107. base_yml)
  108. with open(base_yml) as f:
  109. base_cfg = self._load_config_with_base(base_yml)
  110. all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
  111. del file_cfg[self.BASE_KEY]
  112. file_cfg = _merge_dict(all_base_cfg, file_cfg)
  113. file_cfg['filename'] = os.path.splitext(
  114. os.path.split(file_path)[-1])[0]
  115. return file_cfg
  116. def merge_dict(self, args):
  117. self.cfg = _merge_dict(self.cfg, args)
  118. def print_cfg(self, print_func=print):
  119. """Recursively visualize a dict and indenting acrrording by the
  120. relationship of keys."""
  121. print_func('----------- Config -----------')
  122. print_dict(self.cfg, print_func)
  123. print_func('---------------------------------------------')
  124. def save(self, p, cfg=None):
  125. if cfg is None:
  126. cfg = self.cfg
  127. with open(p, 'w') as f:
  128. yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)