123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- import os
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
- from collections.abc import Mapping
- import yaml
# Names exported by ``from <this module> import *``.
__all__ = ['Config']
class ArgsParser(ArgumentParser):
    """Command-line parser for a config file path plus ``-o key=value`` overrides.

    ``parse_args`` returns a namespace whose ``opt`` attribute has been
    converted from a list of ``key=value`` strings into a nested dict.
    """

    def __init__(self):
        super(ArgsParser,
              self).__init__(formatter_class=RawDescriptionHelpFormatter)
        # parse_args() below requires this option; without it every call
        # failed with AttributeError on ``args.config``.
        self.add_argument('-c', '--config', help='configuration file to use')
        self.add_argument('-o',
                          '--opt',
                          nargs='*',
                          help='set configuration options')
        self.add_argument('--local_rank')

    def parse_args(self, argv=None):
        """Parse ``argv`` like ``ArgumentParser.parse_args`` and expand ``--opt``.

        Raises:
            AssertionError: If no ``--config`` was given.
            ValueError: If an ``--opt`` entry contains no ``=``.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, 'Please specify --config=configure_file_path.'
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``['a.b=1', 'c=x']`` into ``{'a': {'b': 1}, 'c': 'x'}``.

        Each value is parsed as YAML, so ``1`` becomes an int, ``[1, 2]`` a
        list, and so on. Dotted keys build nested dicts. Several options that
        share a prefix are merged rather than overwriting one another.

        Args:
            opts (list[str] | None): Raw ``key=value`` strings.

        Returns:
            dict: Nested override dict (empty when ``opts`` is falsy).

        Raises:
            ValueError: If an entry contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "Invalid option '{}': expected key=value".format(s))
            k, v = s.split('=', 1)
            *parents, leaf = k.split('.')
            cur = config
            for key in parents:
                # Reuse an existing nested dict so that e.g. ``a.b.c=1`` and
                # ``a.b.d=2`` both survive; a non-dict value is replaced.
                child = cur.get(key)
                if not isinstance(child, dict):
                    child = {}
                    cur[key] = child
                cur = child
            # NOTE: yaml.Loader can construct arbitrary Python objects from
            # tags such as ``!!python/object``. Values come from the local
            # command line; do not route untrusted strings through here.
            cur[leaf] = yaml.load(v, Loader=yaml.Loader)
        return config
class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only reading is supported (``d.key`` returns ``d['key']``). It is not
    recursive: nested dicts stay plain dicts.
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Python calls this only after normal attribute lookup fails, so
        # real attributes and methods (e.g. ``d.items``) take precedence.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
- def _merge_dict(config, merge_dct):
- """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
- updating only top-level keys, dict_merge recurses down into dicts nested to
- an arbitrary depth, updating keys. The ``merge_dct`` is merged into
- ``dct``.
- Args:
- config: dict onto which the merge is executed
- merge_dct: dct merged into config
- Returns: dct
- """
- for key, value in merge_dct.items():
- sub_keys = key.split('.')
- key = sub_keys[0]
- if key in config and len(sub_keys) > 1:
- _merge_dict(config[key], {'.'.join(sub_keys[1:]): value})
- elif key in config and isinstance(config[key], dict) and isinstance(
- value, Mapping):
- _merge_dict(config[key], value)
- else:
- config[key] = value
- return config
def print_dict(cfg, print_func=print, delimiter=0):
    """Print ``cfg`` one key per line, sorted by key, indenting nested levels.

    Args:
        cfg (dict): Mapping to display.
        print_func (callable): Receives each formatted line.
        delimiter (int): Number of leading spaces for the current level.
            Nested dicts, and lists whose first element is a dict, are
            printed 4 spaces deeper.
    """
    indent = ' ' * delimiter
    for key, val in sorted(cfg.items()):
        nests_dict = isinstance(val, dict)
        nests_list = (isinstance(val, list) and len(val) >= 1
                      and isinstance(val[0], dict))
        if not (nests_dict or nests_list):
            print_func('{}{} : {}'.format(indent, key, val))
            continue
        print_func('{}{} : '.format(indent, str(key)))
        children = [val] if nests_dict else val
        for child in children:
            print_dict(child, print_func, delimiter + 4)
class Config(object):
    """YAML-backed configuration with inheritance through a base-file key.

    A config file may list other YAML files under ``BASE_KEY`` (``'_BASE_'``
    by default). Those files are loaded first (recursively) and merged in
    order. Values in the including file then override the merged base values.
    """

    def __init__(self, config_path, BASE_KEY='_BASE_'):
        """
        Args:
            config_path (str): Path to a ``.yml``/``.yaml`` config file.
            BASE_KEY (str): Top-level key that lists base config files.
        """
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    @staticmethod
    def _resolve_base_path(base_yml, file_path):
        """Resolve a base-config path relative to the including file.

        ``~`` is expanded. A relative path is interpreted relative to the
        directory containing ``file_path``.
        """
        base_yml = os.path.expanduser(base_yml)
        if not os.path.isabs(base_yml):
            base_yml = os.path.join(os.path.dirname(file_path), base_yml)
        return base_yml

    def _load_config_with_base(self, file_path):
        """Load ``file_path`` and recursively merge in its base configs.

        Args:
            file_path (str): Path of the config file to be loaded.

        Returns:
            dict: The merged config. ``BASE_KEY`` is removed, and a
            ``'filename'`` entry holding the file's name without its
            extension is added.

        Raises:
            AssertionError: If the extension is not ``.yml``/``.yaml``.
        """
        _, ext = os.path.splitext(file_path)
        assert ext in ['.yml', '.yaml'], 'only support yaml files for now'
        # Binary mode lets PyYAML detect the Unicode encoding (UTF-8/16/32)
        # itself instead of relying on the platform's default text encoding.
        # NOTE: yaml.Loader can construct arbitrary Python objects; only load
        # trusted config files.
        with open(file_path, 'rb') as f:
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        # An empty YAML document loads as None; treat it as an empty config.
        if file_cfg is None:
            file_cfg = {}

        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            base_ymls = file_cfg.pop(self.BASE_KEY)
            # Accept a single path as well as a list of paths; list() on a
            # bare string would otherwise iterate over its characters.
            if base_ymls is None:
                base_ymls = []
            elif isinstance(base_ymls, str):
                base_ymls = [base_ymls]
            all_base_cfg = AttrDict()
            for base_yml in base_ymls:
                base_cfg = self._load_config_with_base(
                    self._resolve_base_path(base_yml, file_path))
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            file_cfg = _merge_dict(all_base_cfg, file_cfg)

        file_cfg['filename'] = os.path.splitext(
            os.path.split(file_path)[-1])[0]
        return file_cfg

    def merge_dict(self, args):
        """Merge ``args`` (e.g. command-line overrides) into ``self.cfg``."""
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """Print the whole config, indenting nested keys, between banner lines."""
        print_func('----------- Config -----------')
        print_dict(self.cfg, print_func)
        print_func('---------------------------------------------')

    def save(self, p, cfg=None):
        """Write ``cfg`` (default: ``self.cfg``) to path ``p`` as block-style YAML.

        Key order is preserved (``sort_keys=False``).
        """
        if cfg is None:
            cfg = self.cfg
        with open(p, 'w') as f:
            yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)
|