__init__.py 563 B

1234567891011121314151617181920
  1. import copy
  2. from importlib import import_module
  3. name_to_module = {
  4. 'DBLoss': '.db_loss',
  5. }
  6. def build_loss(config):
  7. config = copy.deepcopy(config)
  8. module_name = config.pop('name')
  9. assert module_name in name_to_module, Exception(
  10. '{} is not supported. The losses in {} are supportes'.format(
  11. module_name, list(name_to_module.keys())))
  12. module_path = name_to_module[module_name]
  13. module = import_module(module_path, package=__package__)
  14. module_class = getattr(module, module_name)
  15. return module_class(**config)