From e70e59c4d52a6b6aca808e81bc28c2988a080751 Mon Sep 17 00:00:00 2001 From: jerome_Hsieh Date: Sun, 15 Dec 2024 19:04:38 +0800 Subject: [PATCH] Enhance download_and_extract Signed-off-by: jerome_Hsieh --- monai/apps/utils.py | 50 +++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 660c34699e..f79362afee 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -15,6 +15,7 @@ import json import logging import os +import re import shutil import sys import tarfile @@ -24,9 +25,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.error import ContentTooShortError, HTTPError, URLError -from urllib.parse import urlparse +from urllib.parse import unquote, urlparse from urllib.request import urlopen, urlretrieve +import requests + from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import @@ -298,6 +301,20 @@ def extractall( ) +def get_filename_from_url(data_url: str): + try: + response = requests.head(data_url, allow_redirects=True) + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename = re.findall("filename=(.+)", content_disposition) + return filename[0].strip('"').strip("'") + else: + filename = _basename(data_url) + return filename + except Exception as e: + raise Exception(f"Error processing URL: {e}") + + def download_and_extract( url: str, filepath: PathLike = "", @@ -327,18 +344,21 @@ def download_and_extract( be False. progress: whether to display progress bar. """ - url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes) - filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes) - if filepath not in ["", "."]: - if filepath_ext == "": - new_filepath = filepath + url_filename_ext - logger.warning( - f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" - ) - filepath = new_filepath - if filepath_ext and filepath_ext != url_filename_ext: - logger.warning(f"Expected extension {url_filename_ext}, but get {filepath_ext}, may cause unexpected errors!") with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or Path(tmp_dir, _basename(url)).resolve() - download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) - extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) + if not filepath: + filename = get_filename_from_url(url) + full_path = Path(tmp_dir, filename) + elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]: + filename = get_filename_from_url(url) + full_path = Path(os.path.join(filepath, filename)) + logger.warning(f"No compress file extension provided, downloading as: '{full_path}'") + else: + url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes) + filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes) + if filepath_ext != url_filename_ext: + raise ValueError( + f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}" + ) + full_path = Path(filepath) + download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress) + extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base)