# train_det.py
  1. import os
  2. import sys
  3. __dir__ = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(__dir__)
  5. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  6. from tools.engine.config import Config
  7. from tools.engine.trainer import Trainer
  8. from tools.utility import ArgsParser
  9. def parse_args():
  10. parser = ArgsParser()
  11. parser.add_argument(
  12. '--eval',
  13. action='store_true',
  14. default=True,
  15. help='Whether to perform evaluation in train',
  16. )
  17. args = parser.parse_args()
  18. return args
  19. def main():
  20. FLAGS = parse_args()
  21. cfg = Config(FLAGS.config)
  22. FLAGS = vars(FLAGS)
  23. opt = FLAGS.pop('opt')
  24. cfg.merge_dict(FLAGS)
  25. cfg.merge_dict(opt)
  26. trainer = Trainer(cfg,
  27. mode='train_eval' if FLAGS['eval'] else 'train',
  28. task='det')
  29. trainer.train()
  30. if __name__ == '__main__':
  31. main()