__init__.py

import copy

import torch
from torch import nn

__all__ = ['build_optimizer']
def param_groups_weight_decay(model: nn.Module,
                              weight_decay=1e-5,
                              no_weight_decay_list=()):
    """Split model parameters into two groups: one that receives weight decay
    and one that does not (biases, 1-D parameters such as norm weights, and
    any parameter whose name matches an entry in ``no_weight_decay_list``)."""
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith(
                '.bias') or any(nd in name for nd in no_weight_decay_list):
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {
            'params': no_decay,
            'weight_decay': 0.0
        },
        {
            'params': decay,
            'weight_decay': weight_decay
        },
    ]
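

# Illustrative sketch (not part of the original file): for a model with norm
# and bias parameters, plus a hypothetical 'pos_embed' parameter excluded by
# name, the helper above would produce two parameter groups roughly like:
#
#   groups = param_groups_weight_decay(model, weight_decay=0.05,
#                                      no_weight_decay_list=('pos_embed',))
#   # -> [{'params': [biases, norm weights, pos_embed, ...], 'weight_decay': 0.0},
#   #     {'params': [remaining 2-D+ weights, ...], 'weight_decay': 0.05}]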
def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
                    model):
    """Build a ``torch.optim`` optimizer and a learning-rate scheduler from
    config dicts. ``model`` may be an ``nn.Module`` or an iterable of
    parameters / parameter groups."""
    from . import lr
    config = copy.deepcopy(optim_config)
    if isinstance(model, nn.Module):
        # A model was passed in: extract its parameters and apply weight decay
        # only to the appropriate layers.
        weight_decay = config.get('weight_decay', 0.0)
        filter_bias_and_bn = (config.pop('filter_bias_and_bn')
                              if 'filter_bias_and_bn' in config else False)
        if weight_decay > 0.0 and filter_bias_and_bn:
            no_weight_decay = {}
            if hasattr(model, 'no_weight_decay'):
                no_weight_decay = model.no_weight_decay()
            parameters = param_groups_weight_decay(model, weight_decay,
                                                   no_weight_decay)
            # Decay is already set per parameter group, so disable the global value.
            config['weight_decay'] = 0.0
        else:
            parameters = model.parameters()
    else:
        # An iterable of parameters or parameter groups was passed in.
        parameters = model

    optim = getattr(torch.optim, config.pop('name'))(params=parameters,
                                                     **config)

    lr_config = copy.deepcopy(lr_scheduler_config)
    lr_config.update({
        'epochs': epochs,
        'step_each_epoch': step_each_epoch,
        'lr': config['lr']
    })
    lr_scheduler = getattr(lr,
                           lr_config.pop('name'))(**lr_config)(optimizer=optim)
    return optim, lr_scheduler
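

# Minimal usage sketch (an assumption, not part of the original file): 'AdamW'
# is a real torch.optim class, but the scheduler name and extra keys accepted
# by the local ``lr`` module are project-specific; 'CosineAnnealingLR' and
# 'warmup_epoch' below are hypothetical placeholders.
#
#   optim_config = {'name': 'AdamW', 'lr': 1e-3, 'weight_decay': 0.05,
#                   'filter_bias_and_bn': True}
#   lr_scheduler_config = {'name': 'CosineAnnealingLR', 'warmup_epoch': 2}
#   optimizer, scheduler = build_optimizer(optim_config, lr_scheduler_config,
#                                          epochs=20, step_each_epoch=500,
#                                          model=model)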