Revert "fix: get model info concurrently" (#6)
This reverts commit 5abbce1.
bojiang authored Dec 12, 2024
1 parent 938be88 commit 45cb9e3
Showing 4 changed files with 164 additions and 159 deletions.
71 changes: 45 additions & 26 deletions nodes/api.py
@@ -11,16 +11,15 @@
import time
import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Union

import folder_paths
from aiohttp import web
from server import PromptServer

from comfy_pack.hash import ModelHashes
from comfy_pack.model_helper import ModelEntry, get_model_entry
from comfy_pack.hash import async_batch_get_sha256
from comfy_pack.model_helper import alookup_model_source

ZPath = Union[Path, zipfile.Path]
TEMP_FOLDER = Path(__file__).parent.parent / "temp"
@@ -96,8 +95,9 @@ async def _get_models(
store_models: bool = False,
workflow_api: dict | None = None,
model_filter: set[str] | None = None,
ensure_source: bool = True,
) -> list[ModelEntry]:
ensure_sha=True,
ensure_source=True,
) -> list:
proc = await asyncio.subprocess.create_subprocess_exec(
"git",
"ls-files",
@@ -107,33 +107,51 @@
)
stdout, _ = await proc.communicate()

models = []
model_filenames = [
os.path.abspath(line)
for line in stdout.decode().splitlines()
if not os.path.basename(line).startswith(".")
]
model_hashes = ModelHashes()
await model_hashes.load()
with ThreadPoolExecutor() as executor:
models = await asyncio.gather(
*(
get_model_entry(
f,
executor,
model_hashes,
store_models=store_models,
ensure_source=ensure_source,
)
for f in model_filenames
)
)
# save the hashes to cache
await model_hashes.save()
for model in models:
model["disabled"] = (
model_filter is not None and model["filename"] not in model_filter
model_hashes = await async_batch_get_sha256(
model_filenames,
cache_only=not (ensure_sha or store_models),
)

for filename in model_filenames:
relpath = os.path.relpath(filename, folder_paths.base_path)

model_data = {
"filename": relpath,
"size": os.path.getsize(filename),
"atime": os.path.getatime(filename),
"ctime": os.path.getctime(filename),
"disabled": relpath not in model_filter
if model_filter is not None
else False,
"sha256": model_hashes.get(filename),
}

model_data["source"] = await alookup_model_source(
model_data["sha256"],
cache_only=not ensure_source,
)
if workflow_api:

if store_models:
import bentoml

model_tag = f'cpack-model:{model_data["sha256"][:16]}'
try:
model = bentoml.models.get(model_tag)
except bentoml.exceptions.NotFound:
with bentoml.models.create(
model_tag, labels={"filename": relpath}
) as model:
shutil.copy(filename, model.path_of("model.bin"))
model_data["model_tag"] = model_tag
models.append(model_data)
if workflow_api:
for model in models:
model["refered"] = _is_file_refered(Path(model["filename"]), workflow_api)
return models

@@ -459,6 +477,7 @@ async def get_models(request):
data = await request.json()
models = await _get_models(
workflow_api=data.get("workflow_api"),
ensure_sha=False,
ensure_source=False,
)
return web.json_response({"models": models})
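For reference, each entry in the models list returned above is the model_data dict assembled in _get_models. A minimal sketch of one serialized entry follows; the path and numeric values are illustrative placeholders, not taken from a real run:

    # Illustrative shape of one _get_models entry (all values are made up).
    example_entry = {
        "filename": "models/checkpoints/example.safetensors",  # relative to folder_paths.base_path
        "size": 2132856686,        # os.path.getsize(filename)
        "atime": 1733961600.0,     # os.path.getatime(filename)
        "ctime": 1733961600.0,     # os.path.getctime(filename)
        "disabled": False,         # True only when model_filter is given and excludes this file
        "sha256": "",              # empty when the hash is uncached and ensure_sha/store_models are False
        "source": None,            # alookup_model_source result; cache-only lookup when ensure_source=False
        # "model_tag": "cpack-model:<first 16 sha256 chars>",  # present only when store_models=True
        "refered": True,           # added only when workflow_api is provided
    }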
1 change: 0 additions & 1 deletion requirements.txt
@@ -2,4 +2,3 @@ bentoml
fastapi
comfy-cli
googlesearch-python
anyio
173 changes: 116 additions & 57 deletions src/comfy_pack/hash.py
@@ -1,63 +1,122 @@
from __future__ import annotations

import contextlib
import hashlib
import asyncio
import json
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import TypedDict

import anyio
from functools import partial
from typing import Dict, List

from .const import SHA_CACHE_FILE

CALC_CMD = """
import hashlib
import sys
filepath = sys.argv[1]
chunk_size = int(sys.argv[2])
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
sha256.update(chunk)
print(sha256.hexdigest())
"""


def calculate_sha256_worker(filepath: str, chunk_size: int = 4 * 1024 * 1024) -> str:
"""Calculate SHA-256 in a separate process"""
result = subprocess.run(
[sys.executable, "-c", CALC_CMD, filepath, str(chunk_size)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
assert result.returncode == 0, result.stderr
return result.stdout.strip()


def get_sha256(filepath: str) -> str:
return batch_get_sha256([filepath])[filepath]


def async_get_sha256(filepath: str) -> str:
return asyncio.run(async_batch_get_sha256([filepath]))[filepath]


def batch_get_sha256(filepaths: List[str], cache_only: bool = False) -> Dict[str, str]:
return asyncio.run(async_batch_get_sha256(filepaths, cache_only=cache_only))


async def async_batch_get_sha256(
filepaths: List[str],
cache_only: bool = False,
) -> Dict[str, str]:
# Load cache
cache = {}
if SHA_CACHE_FILE.exists():
try:
with SHA_CACHE_FILE.open("r") as f:
cache = json.load(f)
except (json.JSONDecodeError, IOError):
pass

# Initialize process pool
max_workers = max(1, (os.cpu_count() or 1))

# Process files
results = {}
new_cache = {}
async with asyncio.Lock():
with ThreadPoolExecutor(max_workers=max_workers) as pool:
loop = asyncio.get_event_loop()

for filepath in filepaths:
if not os.path.exists(filepath):
results[filepath] = None
continue

# Get file info
stat = os.stat(filepath)
current_size = stat.st_size
current_time = stat.st_ctime

# Check cache
cache_entry = cache.get(filepath)
if cache_entry:
if (
cache_entry["size"] == current_size
and cache_entry["birthtime"] == current_time
):
results[filepath] = cache_entry["sha256"]
continue

if cache_only:
results[filepath] = ""
continue

# Calculate new SHA
calc_func = partial(calculate_sha256_worker, filepath)
sha256 = await loop.run_in_executor(pool, calc_func)

# Update cache and results
new_cache[filepath] = {
"sha256": sha256,
"size": current_size,
"birthtime": current_time,
"last_verified": datetime.now().isoformat(),
}
results[filepath] = sha256

# Save cache
try:
with SHA_CACHE_FILE.open("r") as f:
cache = json.load(f)
cache.update(new_cache)
with SHA_CACHE_FILE.open("w") as f:
json.dump(cache, f, indent=2)
except (IOError, OSError):
pass

class ModelCache(TypedDict):
sha256: str
size: int
birthtime: float
last_verified: str


class ModelHashes:
def __init__(self) -> None:
self._data: dict[str, ModelCache] = {}

async def load(self) -> None:
path = anyio.Path(SHA_CACHE_FILE)
if not await path.exists():
return
async with await path.open("r") as f:
self._data = json.loads(await f.read())

async def save(self) -> None:
with contextlib.suppress(OSError):
async with await anyio.open_file(SHA_CACHE_FILE, "w") as f:
await f.write(json.dumps(self._data, indent=2))

async def get(self, filepath: str, cache_only: bool = False) -> str:
afile = anyio.Path(filepath)
stat = await afile.stat()
entry = self._data.get(filepath)
if (
entry is not None
and entry["size"] == stat.st_size
and entry["birthtime"] == stat.st_ctime
):
return entry["sha256"]
if cache_only:
return ""
sha256 = await self.calculate_sha256(filepath)
self._data[filepath] = {
"sha256": sha256,
"size": stat.st_size,
"birthtime": stat.st_ctime,
"last_verified": datetime.now().isoformat(),
}
return sha256

async def calculate_sha256(self, filepath: str, chunk_size: int = 8192) -> str:
async with await anyio.open_file(filepath, "rb") as f:
sha256 = hashlib.sha256()
while chunk := await f.read(chunk_size):
sha256.update(chunk)
return sha256.hexdigest()
return results
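For context, a minimal usage sketch of the restored comfy_pack.hash helpers; the file paths below are placeholders, and with cache_only=True any file whose hash is not yet cached maps to an empty string:

    # Usage sketch of the restored hashing helpers (paths are placeholders).
    import asyncio

    from comfy_pack.hash import async_batch_get_sha256, get_sha256

    # Synchronous single-file helper; it runs its own event loop internally.
    digest = get_sha256("/workspace/models/checkpoints/example.safetensors")

    # Async batch helper; cache_only=True skips hashing and returns "" for uncached files.
    hashes = asyncio.run(
        async_batch_get_sha256(
            [
                "/workspace/models/checkpoints/example.safetensors",
                "/workspace/models/vae/example_vae.pt",
            ],
            cache_only=True,
        )
    )
    print(hashes)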