Skip to content

Commit

Permalink
29 - Implemented download functionality based on access status (also …
Browse files Browse the repository at this point in the history
…compatible with multiple downloads)

Signed-off-by: julianbollig <[email protected]>
  • Loading branch information
julianbollig committed Nov 26, 2024
1 parent 3d11d37 commit ef4c739
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 58 deletions.
12 changes: 8 additions & 4 deletions WebUI/src/assets/i18n/en-US.json
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,14 @@
"DOWNLOADER_MODEL":"Model",
"DOWNLOADER_INFO":"Info",
"DOWNLOADER_FILE_SIZE":"Size",
"DOWNLOADER_GATED":"Gated",
"DOWNLOADER_GATED_INFO":"Some of the models you are trying to download are gated.",
"DOWNLOADER_GATED_TOKEN":"Please add your huggingface.co API token in the settings.",
"DOWNLOADER_GATED_ACCEPT":"Please make sure to visit the model info page and request access. If you have not been granted access to a gated model, the download will fail.",
"DOWNLOADER_GATED": "Gated",
"DOWNLOADER_GATED_TOKEN": "Check, if you have a valid huggingface.co API token added to your settings. ",
"DOWNLOADER_ACCESS_INFO_SINGLE": "You don't have access to the model you want to download",
"DOWNLOADER_GATED_ACCEPT_SINGLE": "The model is gated. Please make sure to visit the model info page and request access. ",
"DOWNLOADER_ACCESS_ACCEPT_SINGLE": "An inaccessible model cannot be downloaded.",
"DOWNLOADER_ACCESS_INFO": "You don't have access to some models you want to download",
"DOWNLOADER_GATED_ACCEPT": "Some of the models are gated. Please make sure to visit the model info page and request access. ",
"DOWNLOADER_ACCESS_ACCEPT": "Inaccessible models will not be downloaded.",
"DOWNLOADER_REASON":"Reason",
"DOWNLOADER_TERMS":"Visit",
"DOWNLOADER_CONFLICT":"Another download task is currently in progress, and a new task cannot be started. You can cancel the current download task and start a new download task",
Expand Down
51 changes: 26 additions & 25 deletions WebUI/src/components/DownloadDialog.vue
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@
</tr>
</tbody>
</table>
<div v-if="downloadList.some((i) => i.gated)" class="flex flex-col items-center gap-2 p-4 border border-red-600 bg-red-600/10 rounded-lg">
<span class="font-bold mx-4">{{ languages.DOWNLOADER_GATED_INFO }}</span>
<ul>
<li v-if="!models.hfTokenIsValid" class="text-left">{{ languages.DOWNLOADER_GATED_TOKEN }}</li>
<li class="text-left">{{ languages.DOWNLOADER_GATED_ACCEPT }}</li>
</ul>
<div v-if="downloadList.some((i) => i.gated && !i.accessGranted) && downloadList.length === 1" class="flex flex-col items-center gap-2 p-4 border border-red-600 bg-red-600/10 rounded-lg">
<span class="font-bold mx-4">{{ languages.DOWNLOADER_ACCESS_INFO_SINGLE }}</span>
<span class="text-left">
{{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}}
{{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT_SINGLE : ""}}
{{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT_SINGLE : ""}}
</span>
</div>
<div v-if="downloadList.some((i) => i.gated && !i.accessGranted) && downloadList.length > 1" class="flex flex-col items-center gap-2 p-4 border border-red-600 bg-red-600/10 rounded-lg">
<span class="font-bold mx-4">{{ languages.DOWNLOADER_ACCESS_INFO }}</span>
<span class="text-left">
{{ !models.hfTokenIsValid ? languages.DOWNLOADER_GATED_TOKEN : ""}}
{{ downloadList.some((i) => i.gated) ? languages.DOWNLOADER_GATED_ACCEPT : ""}}
{{ downloadList.some((i) => !i.accessGranted) ? languages.DOWNLOADER_ACCESS_ACCEPT : ""}}
</span>
</div>
<div class="flex items-center gap-2">
<button class="v-checkbox-control flex-none w-5 h-5"
Expand All @@ -60,7 +69,7 @@
<button @click="cancelConfirm" class="bg-color-control-bg py-1 px-4 rounded">{{
i18nState.COM_CANCEL
}}</button>
<button @click="confirmDownload" :disabled="sizeRequesting || !readTerms"
<button @click="confirmDownload" :disabled="sizeRequesting || !readTerms || downloadList.every((i) => !i.accessGranted)"
class="bg-color-active py-1 px-4 rounded">{{
i18nState.COM_CONFIRM
}}</button>
Expand Down Expand Up @@ -198,18 +207,26 @@ async function showConfirm(downList: DownloadModelParam[], success?: () => void,
"Content-Type": "application/json"
}
});
const accessResponse = await fetch(`${globalSetup.apiHost}/api/isAccessGranted`, {
method: "POST",
body: JSON.stringify([downList,models.hfToken]),
headers: {
"Content-Type": "application/json"
}
});
const sizeData = (await sizeResponse.json()) as ApiResponse & { sizeList: StringKV };
const gatedData = (await gatedResponse.json()) as ApiResponse & { gatedList: Record<string, boolean> };
const accessData = (await accessResponse.json()) as ApiResponse & { accessList: Record<string, boolean> };
for (const item of downloadList.value) {
item.size = sizeData.sizeList[`${item.repo_id}_${item.type}`] || "";
item.gated = gatedData.gatedList[item.repo_id] || false;
item.accessGranted = accessData.accessList[item.repo_id] || false;
}
downloadList.value = downloadList.value;
sizeRequesting.value = false;
} catch (ex) {
fail && fail({ type: "error", error: ex });
}
check_model_access()
}
function getInfoUrl(repoId: string, type: number) {
Expand Down Expand Up @@ -259,7 +276,7 @@ function getFunctionTip(type: number) {
function download() {
downloding = true;
allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.length}`;
allDownloadTip.value = `${i18nState.DOWNLOADER_DONWLOAD_TASK_PROGRESS} 0/${downloadList.value.filter(item => item.accessGranted === true).length}`;
percent.value = 0;
completeCount.value = 0;
abortController = new AbortController();
Expand All @@ -281,22 +298,6 @@ function download() {
})
}
async function check_model_access() {
const response = await fetch(`${globalSetup.apiHost}/api/checkModelAccess`, {
method: "POST",
body: JSON.stringify([downloadList.value[0].repo_id, models.hfToken]),
headers: {
"Content-Type": "application/json"
}
})
const data = await response.json()
console.log("Is URL valid:")
console.log(data.valid)
console.log(data.url)
console.log(data.status)
return data.valid
}
function cancelConfirm() {
downloadReject && downloadReject({ type: "cancelConfrim" });
emits("close");
Expand Down
2 changes: 1 addition & 1 deletion WebUI/src/env.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ type CheckModelExistParam = {

type DownloadModelParam = CheckModelExistParam

type DownloadModelRender = { size: string, gated?: boolean } & CheckModelExistParam
type DownloadModelRender = { size: string, gated?: boolean, accessGranted?: boolean } & CheckModelExistParam

type CheckModelExistResult = {
exist: boolean
Expand Down
21 changes: 11 additions & 10 deletions service/model_download_adpater.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,17 @@ def __start_download(self, list: list):
break
if self.has_error:
break
if item["type"] == 4:
self.file_downloader.download_file(
realesrgan.ESRGAN_MODEL_URL,
os.path.join(
utils.get_model_path(item["type"]),
os.path.basename(realesrgan.ESRGAN_MODEL_URL),
),
)
else:
self.hf_downloader.download(item["repo_id"], item["type"])
if item["accessGranted"]:
if item["type"] == 4:
self.file_downloader.download_file(
realesrgan.ESRGAN_MODEL_URL,
os.path.join(
utils.get_model_path(item["type"]),
os.path.basename(realesrgan.ESRGAN_MODEL_URL)
),
)
else:
self.hf_downloader.download(item["repo_id"], item["type"])
self.put_msg({"type": "allComplete"})
self.finish = True
except Exception as ex:
Expand Down
21 changes: 16 additions & 5 deletions service/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,27 @@ def init_download(self, file: HFDonloadItem):

return response, fw

def is_token_valid(self, repo_id: str):
def is_access_granted(self, repo_id: str, model_type):

headers={}
if (self.hf_token is not None):
headers["Authorization"] = f"Bearer {self.hf_token}"

name = self.fs.ls(repo_id, detail=True)[0].get("name")
url = hf_hub_url(repo_id=repo_id, filename = path.basename(path.relpath(name, repo_id)))
response = requests.get(url, stream=True, verify=False, headers=headers)
self.file_queue = queue.Queue()
self.repo_id = repo_id
self.save_path = path.join(utils.get_model_path(model_type))
self.save_path_tmp = path.abspath(
path.join(self.save_path, repo_id.replace("/", "---") + "_tmp")
)

file_list = list()
self.enum_file_list(file_list, repo_id, model_type)
self.build_queue(file_list)
file = self.file_queue.get_nowait()

response = requests.get(file.url, stream=True, verify=False, headers=headers)

Check failure

Code scanning / Bandit

Call to requests with verify=False disabling SSL certificate checks, security issue. Error

Call to requests with verify=False disabling SSL certificate checks, security issue.

Check warning

Code scanning / Bandit

Call to requests without timeout Warning

Call to requests without timeout

return response.status_code == 200, url, response.status_code
return response.status_code == 200


def download_model_file(self):
Expand Down
23 changes: 10 additions & 13 deletions service/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,6 @@ def check_model_exist():
result_list.append({"repo_id": repo_id, "type": type, "exist": exist})
return jsonify({"code": 0, "message": "success", "exists": result_list})

@app.route("/api/checkModelAccess", methods=["POST"])
def checkModelAccess():
repo_id, hf_token = request.get_json()
downloader = HFPlaygroundDownloader(hf_token)
valid, url, status = downloader.is_token_valid(repo_id)
return jsonify(
{
"valid": valid,
"url": url,
"status": status
}
)

size_cache = dict()
lock = threading.Lock()

Expand All @@ -226,6 +213,16 @@ def is_model_gated():
}
)

@app.route("/api/isAccessGranted", methods=["POST"])
def is_access_granted():
list, hf_token = request.get_json()
downloader = HFPlaygroundDownloader(hf_token)
accessGranted = { item["repo_id"] : downloader.is_access_granted(item["repo_id"], item["type"]) for item in list }
return jsonify(
{
"accessList": accessGranted
}
)

@app.route("/api/getModelSize", methods=["POST"])
def get_model_size():
Expand Down

0 comments on commit ef4c739

Please sign in to comment.