From ef4c73956b5ed6ec5f0bb88e3024de53b7e45a91 Mon Sep 17 00:00:00 2001 From: julianbollig Date: Thu, 21 Nov 2024 17:19:36 +0100 Subject: [PATCH] 29 - Implemented download functionality based on access status (also compatible with multiple downloads) Signed-off-by: julianbollig --- WebUI/src/assets/i18n/en-US.json | 12 ++++-- WebUI/src/components/DownloadDialog.vue | 51 +++++++++++++------------ WebUI/src/env.d.ts | 2 +- service/model_download_adpater.py | 21 +++++----- service/model_downloader.py | 21 +++++++--- service/web_api.py | 23 +++++------ 6 files changed, 72 insertions(+), 58 deletions(-) diff --git a/WebUI/src/assets/i18n/en-US.json b/WebUI/src/assets/i18n/en-US.json index 06315107..06acea2c 100644 --- a/WebUI/src/assets/i18n/en-US.json +++ b/WebUI/src/assets/i18n/en-US.json @@ -127,10 +127,14 @@ "DOWNLOADER_MODEL":"Model", "DOWNLOADER_INFO":"Info", "DOWNLOADER_FILE_SIZE":"Size", - "DOWNLOADER_GATED":"Gated", - "DOWNLOADER_GATED_INFO":"Some of the models you are trying to download are gated.", - "DOWNLOADER_GATED_TOKEN":"Please add your huggingface.co API token in the settings.", - "DOWNLOADER_GATED_ACCEPT":"Please make sure to visit the model info page and request access. If you have not been granted access to a gated model, the download will fail.", + "DOWNLOADER_GATED": "Gated", + "DOWNLOADER_GATED_TOKEN": "Check, if you have a valid huggingface.co API token added to your settings. ", + "DOWNLOADER_ACCESS_INFO_SINGLE": "You don't have access to the model you want to download", + "DOWNLOADER_GATED_ACCEPT_SINGLE": "The model is gated. Please make sure to visit the model info page and request access. ", + "DOWNLOADER_ACCESS_ACCEPT_SINGLE": "An inaccessible model cannot be downloaded.", + "DOWNLOADER_ACCESS_INFO": "You don't have access to some models you want to download", + "DOWNLOADER_GATED_ACCEPT": "Some of the models are gated. Please make sure to visit the model info page and request access. ", + "DOWNLOADER_ACCESS_ACCEPT": "Inaccessible models will not be downloaded.", "DOWNLOADER_REASON":"Reason", "DOWNLOADER_TERMS":"Visit", "DOWNLOADER_CONFLICT":"Another download task is currently in progress, and a new task cannot be started. You can cancel the current download task and start a new download task", diff --git a/WebUI/src/components/DownloadDialog.vue b/WebUI/src/components/DownloadDialog.vue index 329a0b19..5b78bfcf 100644 --- a/WebUI/src/components/DownloadDialog.vue +++ b/WebUI/src/components/DownloadDialog.vue @@ -43,12 +43,21 @@ -
- {{ languages.DOWNLOADER_GATED_INFO }} -
    -
  • {{ languages.DOWNLOADER_GATED_TOKEN }}
  • -
  • {{ languages.DOWNLOADER_GATED_ACCEPT }}
  • -
+
+ {{ languages.DOWNLOADER_ACCESS_INFO_SINGLE }} + + {{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}} + {{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT_SINGLE : ""}} + {{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT_SINGLE : ""}} + +
+
+ {{ languages.DOWNLOADER_ACCESS_INFO }} + + {{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}} + {{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT : ""}} + {{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT : ""}} +
- @@ -198,18 +207,26 @@ async function showConfirm(downList: DownloadModelParam[], success?: () => void, "Content-Type": "application/json" } }); + const accessResponse = await fetch(`${globalSetup.apiHost}/api/isAccessGranted`, { + method: "POST", + body: JSON.stringify([downList,models.hfToken]), + headers: { + "Content-Type": "application/json" + } + }); const sizeData = (await sizeResponse.json()) as ApiResponse & { sizeList: StringKV }; const gatedData = (await gatedResponse.json()) as ApiResponse & { gatedList: Record }; + const accessData = (await accessResponse.json()) as ApiResponse & { accessList: Record }; for (const item of downloadList.value) { item.size = sizeData.sizeList[`${item.repo_id}_${item.type}`] || ""; item.gated = gatedData.gatedList[item.repo_id] || false; + item.accessGranted = accessData.accessList[item.repo_id] || false; } downloadList.value = downloadList.value; sizeRequesting.value = false; } catch (ex) { fail && fail({ type: "error", error: ex }); } - check_model_access() } function getInfoUrl(repoId: string, type: number) { @@ -259,7 +276,7 @@ function getFunctionTip(type: number) { function download() { downloding = true; - allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.length}`; + allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.filter(item => item.accessGranted === true).length}`; percent.value = 0; completeCount.value = 0; abortController = new AbortController(); @@ -281,22 +298,6 @@ function download() { }) } -async function check_model_access() { - const response = await fetch(`${globalSetup.apiHost}/api/checkModelAccess`, { - method: "POST", - body: JSON.stringify([downloadList.value[0].repo_id, models.hfToken]), - headers: { - "Content-Type": "application/json" - } - }) - const data = await response.json() - console.log("Is URL valid:") - console.log(data.valid) - console.log(data.url) - console.log(data.status) - return data.valid -} - function cancelConfirm() { downloadReject && downloadReject({ type: "cancelConfrim" }); emits("close"); diff --git a/WebUI/src/env.d.ts b/WebUI/src/env.d.ts index 41b2f508..0dc1f2b8 100644 --- a/WebUI/src/env.d.ts +++ b/WebUI/src/env.d.ts @@ -300,7 +300,7 @@ type CheckModelExistParam = { type DownloadModelParam = CheckModelExistParam -type DownloadModelRender = { size: string, gated?: boolean } & CheckModelExistParam +type DownloadModelRender = { size: string, gated?: boolean, accessGranted?: boolean } & CheckModelExistParam type CheckModelExistResult = { exist: boolean diff --git a/service/model_download_adpater.py b/service/model_download_adpater.py index 15430c8a..7b4ab5db 100644 --- a/service/model_download_adpater.py +++ b/service/model_download_adpater.py @@ -115,16 +115,17 @@ def __start_download(self, list: list): break if self.has_error: break - if item["type"] == 4: - self.file_downloader.download_file( - realesrgan.ESRGAN_MODEL_URL, - os.path.join( - utils.get_model_path(item["type"]), - os.path.basename(realesrgan.ESRGAN_MODEL_URL), - ), - ) - else: - self.hf_downloader.download(item["repo_id"], item["type"]) + if item["accessGranted"]: + if item["type"] == 4: + self.file_downloader.download_file( + realesrgan.ESRGAN_MODEL_URL, + os.path.join( + utils.get_model_path(item["type"]), + os.path.basename(realesrgan.ESRGAN_MODEL_URL) + ), + ) + else: + self.hf_downloader.download(item["repo_id"], item["type"]) self.put_msg({"type": "allComplete"}) self.finish = True except Exception as ex: diff --git a/service/model_downloader.py b/service/model_downloader.py index 96a557ae..567d0da6 100644 --- a/service/model_downloader.py +++ b/service/model_downloader.py @@ -301,16 +301,27 @@ def init_download(self, file: HFDonloadItem): return response, fw - def is_token_valid(self, repo_id: str): + def is_access_granted(self, repo_id: str, model_type): + headers={} if (self.hf_token is not None): headers["Authorization"] = f"Bearer {self.hf_token}" - name = self.fs.ls(repo_id, detail=True)[0].get("name") - url = hf_hub_url(repo_id=repo_id, filename = path.basename(path.relpath(name, repo_id))) - response = requests.get(url, stream=True, verify=False, headers=headers) + self.file_queue = queue.Queue() + self.repo_id = repo_id + self.save_path = path.join(utils.get_model_path(model_type)) + self.save_path_tmp = path.abspath( + path.join(self.save_path, repo_id.replace("/", "---") + "_tmp") + ) + + file_list = list() + self.enum_file_list(file_list, repo_id, model_type) + self.build_queue(file_list) + file = self.file_queue.get_nowait() + + response = requests.get(file.url, stream=True, verify=False, headers=headers) - return response.status_code == 200, url, response.status_code + return response.status_code == 200 def download_model_file(self): diff --git a/service/web_api.py b/service/web_api.py index 1d5ca019..07c17ac0 100644 --- a/service/web_api.py +++ b/service/web_api.py @@ -195,19 +195,6 @@ def check_model_exist(): result_list.append({"repo_id": repo_id, "type": type, "exist": exist}) return jsonify({"code": 0, "message": "success", "exists": result_list}) -@app.route("/api/checkModelAccess", methods=["POST"]) -def checkModelAccess(): - repo_id, hf_token = request.get_json() - downloader = HFPlaygroundDownloader(hf_token) - valid, url, status = downloader.is_token_valid(repo_id) - return jsonify( - { - "valid": valid, - "url": url, - "status": status - } - ) - size_cache = dict() lock = threading.Lock() @@ -226,6 +213,16 @@ def is_model_gated(): } ) +@app.route("/api/isAccessGranted", methods=["POST"]) +def is_access_granted(): + list, hf_token = request.get_json() + downloader = HFPlaygroundDownloader(hf_token) + accessGranted = { item["repo_id"] : downloader.is_access_granted(item["repo_id"], item["type"]) for item in list } + return jsonify( + { + "accessList": accessGranted + } + ) @app.route("/api/getModelSize", methods=["POST"]) def get_model_size():