diff --git a/WebUI/src/assets/i18n/en-US.json b/WebUI/src/assets/i18n/en-US.json
index 06315107..06acea2c 100644
--- a/WebUI/src/assets/i18n/en-US.json
+++ b/WebUI/src/assets/i18n/en-US.json
@@ -127,10 +127,14 @@
"DOWNLOADER_MODEL":"Model",
"DOWNLOADER_INFO":"Info",
"DOWNLOADER_FILE_SIZE":"Size",
- "DOWNLOADER_GATED":"Gated",
- "DOWNLOADER_GATED_INFO":"Some of the models you are trying to download are gated.",
- "DOWNLOADER_GATED_TOKEN":"Please add your huggingface.co API token in the settings.",
- "DOWNLOADER_GATED_ACCEPT":"Please make sure to visit the model info page and request access. If you have not been granted access to a gated model, the download will fail.",
+ "DOWNLOADER_GATED": "Gated",
+ "DOWNLOADER_GATED_TOKEN": "Check, if you have a valid huggingface.co API token added to your settings. ",
+ "DOWNLOADER_ACCESS_INFO_SINGLE": "You don't have access to the model you want to download",
+ "DOWNLOADER_GATED_ACCEPT_SINGLE": "The model is gated. Please make sure to visit the model info page and request access. ",
+ "DOWNLOADER_ACCESS_ACCEPT_SINGLE": "An inaccessible model cannot be downloaded.",
+ "DOWNLOADER_ACCESS_INFO": "You don't have access to some models you want to download",
+ "DOWNLOADER_GATED_ACCEPT": "Some of the models are gated. Please make sure to visit the model info page and request access. ",
+ "DOWNLOADER_ACCESS_ACCEPT": "Inaccessible models will not be downloaded.",
"DOWNLOADER_REASON":"Reason",
"DOWNLOADER_TERMS":"Visit",
"DOWNLOADER_CONFLICT":"Another download task is currently in progress, and a new task cannot be started. You can cancel the current download task and start a new download task",
diff --git a/WebUI/src/components/DownloadDialog.vue b/WebUI/src/components/DownloadDialog.vue
index 329a0b19..5b78bfcf 100644
--- a/WebUI/src/components/DownloadDialog.vue
+++ b/WebUI/src/components/DownloadDialog.vue
@@ -43,12 +43,21 @@
-
-
{{ languages.DOWNLOADER_GATED_INFO }}
-
- {{ languages.DOWNLOADER_GATED_TOKEN }}
- {{ languages.DOWNLOADER_GATED_ACCEPT }}
-
+
+ {{ languages.DOWNLOADER_ACCESS_INFO_SINGLE }}
+
+ {{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}}
+ {{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT_SINGLE : ""}}
+ {{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT_SINGLE : ""}}
+
+
+
+ {{ languages.DOWNLOADER_ACCESS_INFO }}
+
+ {{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}}
+ {{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT : ""}}
+ {{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT : ""}}
+
{{
i18nState.COM_CANCEL
}}
- {{
i18nState.COM_CONFIRM
}}
@@ -198,18 +207,26 @@ async function showConfirm(downList: DownloadModelParam[], success?: () => void,
"Content-Type": "application/json"
}
});
+ const accessResponse = await fetch(`${globalSetup.apiHost}/api/isAccessGranted`, {
+ method: "POST",
+ body: JSON.stringify([downList,models.hfToken]),
+ headers: {
+ "Content-Type": "application/json"
+ }
+ });
const sizeData = (await sizeResponse.json()) as ApiResponse & { sizeList: StringKV };
const gatedData = (await gatedResponse.json()) as ApiResponse & { gatedList: Record };
+ const accessData = (await accessResponse.json()) as ApiResponse & { accessList: Record };
for (const item of downloadList.value) {
item.size = sizeData.sizeList[`${item.repo_id}_${item.type}`] || "";
item.gated = gatedData.gatedList[item.repo_id] || false;
+ item.accessGranted = accessData.accessList[item.repo_id] || false;
}
downloadList.value = downloadList.value;
sizeRequesting.value = false;
} catch (ex) {
fail && fail({ type: "error", error: ex });
}
- check_model_access()
}
function getInfoUrl(repoId: string, type: number) {
@@ -259,7 +276,7 @@ function getFunctionTip(type: number) {
function download() {
downloding = true;
- allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.length}`;
+ allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.filter(item => item.accessGranted === true).length}`;
percent.value = 0;
completeCount.value = 0;
abortController = new AbortController();
@@ -281,22 +298,6 @@ function download() {
})
}
-async function check_model_access() {
- const response = await fetch(`${globalSetup.apiHost}/api/checkModelAccess`, {
- method: "POST",
- body: JSON.stringify([downloadList.value[0].repo_id, models.hfToken]),
- headers: {
- "Content-Type": "application/json"
- }
- })
- const data = await response.json()
- console.log("Is URL valid:")
- console.log(data.valid)
- console.log(data.url)
- console.log(data.status)
- return data.valid
-}
-
function cancelConfirm() {
downloadReject && downloadReject({ type: "cancelConfrim" });
emits("close");
diff --git a/WebUI/src/env.d.ts b/WebUI/src/env.d.ts
index 41b2f508..0dc1f2b8 100644
--- a/WebUI/src/env.d.ts
+++ b/WebUI/src/env.d.ts
@@ -300,7 +300,7 @@ type CheckModelExistParam = {
type DownloadModelParam = CheckModelExistParam
-type DownloadModelRender = { size: string, gated?: boolean } & CheckModelExistParam
+type DownloadModelRender = { size: string, gated?: boolean, accessGranted?: boolean } & CheckModelExistParam
type CheckModelExistResult = {
exist: boolean
diff --git a/service/model_download_adpater.py b/service/model_download_adpater.py
index 15430c8a..7b4ab5db 100644
--- a/service/model_download_adpater.py
+++ b/service/model_download_adpater.py
@@ -115,16 +115,17 @@ def __start_download(self, list: list):
break
if self.has_error:
break
- if item["type"] == 4:
- self.file_downloader.download_file(
- realesrgan.ESRGAN_MODEL_URL,
- os.path.join(
- utils.get_model_path(item["type"]),
- os.path.basename(realesrgan.ESRGAN_MODEL_URL),
- ),
- )
- else:
- self.hf_downloader.download(item["repo_id"], item["type"])
+ if item["accessGranted"]:
+ if item["type"] == 4:
+ self.file_downloader.download_file(
+ realesrgan.ESRGAN_MODEL_URL,
+ os.path.join(
+ utils.get_model_path(item["type"]),
+ os.path.basename(realesrgan.ESRGAN_MODEL_URL)
+ ),
+ )
+ else:
+ self.hf_downloader.download(item["repo_id"], item["type"])
self.put_msg({"type": "allComplete"})
self.finish = True
except Exception as ex:
diff --git a/service/model_downloader.py b/service/model_downloader.py
index 96a557ae..567d0da6 100644
--- a/service/model_downloader.py
+++ b/service/model_downloader.py
@@ -301,16 +301,27 @@ def init_download(self, file: HFDonloadItem):
return response, fw
- def is_token_valid(self, repo_id: str):
+ def is_access_granted(self, repo_id: str, model_type):
+
headers={}
if (self.hf_token is not None):
headers["Authorization"] = f"Bearer {self.hf_token}"
- name = self.fs.ls(repo_id, detail=True)[0].get("name")
- url = hf_hub_url(repo_id=repo_id, filename = path.basename(path.relpath(name, repo_id)))
- response = requests.get(url, stream=True, verify=False, headers=headers)
+ self.file_queue = queue.Queue()
+ self.repo_id = repo_id
+ self.save_path = path.join(utils.get_model_path(model_type))
+ self.save_path_tmp = path.abspath(
+ path.join(self.save_path, repo_id.replace("/", "---") + "_tmp")
+ )
+
+ file_list = list()
+ self.enum_file_list(file_list, repo_id, model_type)
+ self.build_queue(file_list)
+ file = self.file_queue.get_nowait()
+
+ response = requests.get(file.url, stream=True, verify=False, headers=headers)
- return response.status_code == 200, url, response.status_code
+ return response.status_code == 200
def download_model_file(self):
diff --git a/service/web_api.py b/service/web_api.py
index 1d5ca019..07c17ac0 100644
--- a/service/web_api.py
+++ b/service/web_api.py
@@ -195,19 +195,6 @@ def check_model_exist():
result_list.append({"repo_id": repo_id, "type": type, "exist": exist})
return jsonify({"code": 0, "message": "success", "exists": result_list})
-@app.route("/api/checkModelAccess", methods=["POST"])
-def checkModelAccess():
- repo_id, hf_token = request.get_json()
- downloader = HFPlaygroundDownloader(hf_token)
- valid, url, status = downloader.is_token_valid(repo_id)
- return jsonify(
- {
- "valid": valid,
- "url": url,
- "status": status
- }
- )
-
size_cache = dict()
lock = threading.Lock()
@@ -226,6 +213,16 @@ def is_model_gated():
}
)
+@app.route("/api/isAccessGranted", methods=["POST"])
+def is_access_granted():
+ list, hf_token = request.get_json()
+ downloader = HFPlaygroundDownloader(hf_token)
+ accessGranted = { item["repo_id"] : downloader.is_access_granted(item["repo_id"], item["type"]) for item in list }
+ return jsonify(
+ {
+ "accessList": accessGranted
+ }
+ )
@app.route("/api/getModelSize", methods=["POST"])
def get_model_size():