Skip to content

Commit

Permalink
fix: allow setting timeout for cloud client
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Sep 24, 2024
1 parent cfe8a73 commit b206950
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
22 changes: 16 additions & 6 deletions src/_bentoml_sdk/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from .base import Model

if t.TYPE_CHECKING:
from huggingface_hub import HfApi

CONFIG_FILE = "config.json"
DEFAULT_HF_ENDPOINT = "https://huggingface.co"

Expand All @@ -39,10 +42,17 @@ class HuggingFaceModel(Model[str]):
endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))

@cached_property
def commit_hash(self) -> str:
from huggingface_hub import model_info
def _hf_api(self) -> HfApi:
from huggingface_hub import HfApi

return model_info(self.model_id, revision=self.revision).sha or self.revision
return HfApi(endpoint=self.endpoint)

@cached_property
def commit_hash(self) -> str:
return (
self._hf_api.model_info(self.model_id, revision=self.revision).sha
or self.revision
)

def resolve(self, base_path: t.Union[PathType, FS, None] = None) -> str:
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -91,9 +101,9 @@ def from_info(cls, info: BentoModelInfo) -> HuggingFaceModel:
)

def _get_model_size(self, revision: str) -> int:
from huggingface_hub import model_info

info = model_info(self.model_id, revision=revision, files_metadata=True)
info = self._hf_api.model_info(
self.model_id, revision=revision, files_metadata=True
)
return sum((file.size or 0) for file in (info.siblings or []))

def to_create_schema(self) -> CreateModelSchema:
Expand Down
11 changes: 7 additions & 4 deletions src/bentoml/_internal/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
from bentoml._internal.cloud.client import RestApiClient

from .bento import BentoAPI
from .config import DEFAULT_ENDPOINT
from .config import CloudClientConfig
from .deployment import DeploymentAPI
from .model import ModelAPI
from .secret import SecretAPI
from .yatai import YataiClient as YataiClient

DEFAULT_ENDPOINT = "https://cloud.bentoml.com"


@attrs.frozen
class BentoCloudClient:
Expand All @@ -22,6 +21,7 @@ class BentoCloudClient:
Args:
api_key: The API key to use for the client. env: BENTO_CLOUD_API_KEY
endpoint: The endpoint to use for the client. env: BENTO_CLOUD_ENDPOINT
timeout: The timeout to use for the client. Defaults to 60 seconds.
Attributes:
bento: Bento API
Expand All @@ -36,7 +36,10 @@ class BentoCloudClient:
secret: SecretAPI

def __init__(
self, api_key: str | None = None, endpoint: str = DEFAULT_ENDPOINT
self,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
timeout: int = 60,
) -> None:
if api_key is None:
from ..configuration.containers import BentoMLContainer
Expand All @@ -46,7 +49,7 @@ def __init__(
api_key = ctx.api_token
endpoint = ctx.endpoint

client = RestApiClient(endpoint, api_key)
client = RestApiClient(endpoint, api_key, timeout)
bento = BentoAPI(client)
model = ModelAPI(client)
deployment = DeploymentAPI(client)
Expand Down
7 changes: 5 additions & 2 deletions src/bentoml/_internal/cloud/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def save(self, *, ignore_warning: bool = False) -> None:
config.to_yaml_file()


DEFAULT_ENDPOINT = "https://cloud.bentoml.com"


@attr.define
class CloudClientConfig:
contexts: t.List[CloudClientContext] = attr.field(factory=list)
Expand All @@ -67,10 +70,10 @@ class CloudClientConfig:
def get_context(self, context: t.Optional[str] = None) -> CloudClientContext:
from os import environ

if "BENTO_CLOUD_API_KEY" in environ and "BENTO_CLOUD_API_ENDPOINT" in environ:
if "BENTO_CLOUD_API_KEY" in environ:
return CloudClientContext(
name="__env__",
endpoint=environ["BENTO_CLOUD_API_ENDPOINT"],
endpoint=environ.get("BENTO_CLOUD_API_ENDPOINT", DEFAULT_ENDPOINT),
api_token=environ["BENTO_CLOUD_API_KEY"],
)
if context is None:
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml_cli/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import rich

from bentoml._internal.cloud.client import RestApiClient
from bentoml._internal.cloud.config import DEFAULT_ENDPOINT
from bentoml._internal.cloud.config import CloudClientConfig
from bentoml._internal.cloud.config import CloudClientContext
from bentoml._internal.configuration.containers import BentoMLContainer
Expand All @@ -30,7 +31,7 @@ def cloud_command():
"--endpoint",
type=click.STRING,
help="BentoCloud endpoint",
default="https://cloud.bentoml.com",
default=DEFAULT_ENDPOINT,
envvar="BENTO_CLOUD_API_ENDPOINT",
show_default=True,
show_envvar=True,
Expand Down

0 comments on commit b206950

Please sign in to comment.