# utils.py - helpers for reading dataset download settings and fetching files.
import os
import ssl
import urllib
import urllib.request

from tqdm import tqdm
  5. def get_dataset_info(cfg):
  6. download_urls, filenames, check_validity = cfg["download_links"], cfg["filenames"], cfg["check_validity"]
  7. return download_urls, filenames, check_validity
  8. # Modified from torchvision as some datasets cant pass the certificate validity check:
  9. # https://github.com/pytorch/vision/blob/868a3b42f4bffe29e4414ad7e4c7d9d0b4690ecb/torchvision/datasets/utils.py#L27C1-L32C40
  10. def urlretrieve(url, filename, chunk_size=1024 * 32, check_validity=True):
  11. os.makedirs(os.path.dirname(filename), exist_ok=True)
  12. ctx = ssl.create_default_context()
  13. if not check_validity:
  14. ctx.check_hostname = False
  15. ctx.verify_mode = ssl.CERT_NONE
  16. request = urllib.request.Request(url)
  17. with urllib.request.urlopen(request, context=ctx) as response:
  18. with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar:
  19. while chunk := response.read(chunk_size):
  20. fh.write(chunk)
  21. pbar.update(len(chunk))