Skip to content

Commit

Permalink
Enhance download_and_extract
Browse files Browse the repository at this point in the history
Signed-off-by: jerome_Hsieh <[email protected]>
  • Loading branch information
Jerome-Hsieh committed Dec 15, 2024
1 parent a9a0171 commit e70e59c
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
import re
import shutil
import sys
import tarfile
Expand All @@ -24,9 +25,11 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.parse import urlparse
from urllib.parse import unquote, urlparse
from urllib.request import urlopen, urlretrieve

import requests

from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

Expand Down Expand Up @@ -298,6 +301,20 @@ def extractall(
)


def get_filename_from_url(data_url: str) -> str:
    """
    Determine the filename that a download from ``data_url`` should be saved as.

    A HEAD request (following redirects) is sent to ``data_url``. If the response
    carries a ``Content-Disposition`` header with a non-empty ``filename=`` parameter,
    that value is returned with surrounding quotes removed. Otherwise the result of
    ``_basename(data_url)`` is returned.

    Args:
        data_url: source URL link of the file to be downloaded.

    Returns:
        The filename taken from the ``Content-Disposition`` header, or the basename
        of the URL when the header is absent or carries no usable filename.

    Raises:
        Exception: if the request or the filename parsing fails; the original error
            is chained as ``__cause__``.
    """
    try:
        response = requests.head(data_url, allow_redirects=True)
        content_disposition = response.headers.get("Content-Disposition")
        if content_disposition:
            # Capture only the value of the `filename=` parameter: either a quoted string,
            # or everything up to the next `;`, so that trailing parameters such as
            # `filename*=...` are not swallowed into the returned name.
            match = re.search(r"""filename=\s*("[^"]*"|'[^']*'|[^;]+)""", content_disposition)
            if match:
                filename = match.group(1).strip().strip('"').strip("'")
                if filename:
                    return filename
        # No header, or a header without a usable filename (e.g. a bare "inline"):
        # fall back to the name embedded in the URL instead of failing.
        return _basename(data_url)
    except Exception as e:
        raise Exception(f"Error processing URL: {e}") from e


def download_and_extract(
url: str,
filepath: PathLike = "",
Expand Down Expand Up @@ -327,18 +344,21 @@ def download_and_extract(
be False.
progress: whether to display progress bar.
"""
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
if filepath not in ["", "."]:
if filepath_ext == "":
new_filepath = filepath + url_filename_ext
logger.warning(
f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
)
filepath = new_filepath
if filepath_ext and filepath_ext != url_filename_ext:
logger.warning(f"Expected extension {url_filename_ext}, but get {filepath_ext}, may cause unexpected errors!")
with tempfile.TemporaryDirectory() as tmp_dir:
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
if not filepath:
filename = get_filename_from_url(url)
full_path = Path(tmp_dir, filename)
elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]:
filename = get_filename_from_url(url)
full_path = Path(os.path.join(filepath, filename))
logger.warning(f"No compress file extension provided, downloading as: '{full_path}'")
else:
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
if filepath_ext != url_filename_ext:
raise ValueError(
f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}"
)
full_path = Path(filepath)
download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base)

0 comments on commit e70e59c

Please sign in to comment.