download_dataset.py 1.1 KB

1234567891011121314151617181920212223242526272829303132
  1. import os
  2. import sys
  3. __dir__ = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(__dir__)
  5. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  6. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..')))
  7. from engine import Config
  8. from utility import ArgsParser
  9. import download.utils
  10. from torchvision.datasets.utils import extract_archive
  11. def main(cfg):
  12. urls, filename_paths, check_validity = download.utils.get_dataset_info(cfg)
  13. for url, filename_path in zip(urls, filename_paths):
  14. print(f"Downloading {filename_path} from {url} . . .")
  15. download.utils.urlretrieve(url=url, filename=filename_path, check_validity=check_validity)
  16. if not filename_path.endswith(".mdb"):
  17. extract_archive(from_path=filename_path, to_path=cfg["root"], remove_finished=True)
  18. print("Downloads finished!")
  19. if __name__ == "__main__":
  20. FLAGS = ArgsParser().parse_args()
  21. cfg = Config(FLAGS.config)
  22. FLAGS = vars(FLAGS)
  23. opt = FLAGS.pop('opt')
  24. cfg.merge_dict(FLAGS)
  25. cfg.merge_dict(opt)
  26. main(cfg.cfg)