Skip to content

Commit

Permalink
fix #8306
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Jan 20, 2025
1 parent e39bad9 commit 8ad5964
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"


def _get_ngc_private_base_url(repo: str) -> str:
Expand Down Expand Up @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
return name


def _get_all_download_files(request_url, headers=None) -> list[str]:
if not has_requests:
raise ValueError("requests package is required, please install it.")
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response.raise_for_status()
model_info = json.loads(response.text)

if not isinstance(model_info, dict) or "modelFiles" not in model_info:
raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")

model_files = model_info["modelFiles"]
return [f["path"] for f in model_files]


def _download_from_ngc(
download_path: Path,
filename: str,
Expand All @@ -229,12 +244,12 @@ def _download_from_ngc(
# ensure prefix is contained
filename = _add_ngc_prefix(filename, prefix=prefix)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
extract_path = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
filepath = download_path / filename
filepath.mkdir(parents=True, exist_ok=True)
for file in _get_all_download_files(url):
download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)


def _download_from_ngc_private(
Expand Down

0 comments on commit 8ad5964

Please sign in to comment.