feat: always make an HTTP call when calling self API methods #5017

Merged 2 commits on Oct 11, 2024
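What this boils down to (an illustrative sketch, not part of the diff): per the method.py change below, accessing one of a service's own API methods on a running instance now returns a proxied callable that performs an HTTP request instead of calling the Python function directly. Locally the request never leaves the process, because the client is wired to the service's own ASGI app (http.py below); on BentoCloud it targets BENTOCLOUD_DEPLOYMENT_URL (app.py below). The new `.local` attribute keeps the direct call reachable. The service definition here is a made-up example:

    import bentoml

    @bentoml.service
    class Summarizer:
        @bentoml.api
        def summarize(self, text: str) -> str:
            return text.upper()[:100]

        @bentoml.api
        def summarize_all(self, texts: list[str]) -> list[str]:
            # With __self_proxy__ attached by the server (app.py below), each of
            # these calls is an HTTP request served in-process by the same app:
            proxied = [self.summarize(t) for t in texts]
            # .local (added in method.py below) bypasses the proxy entirely:
            direct = [self.summarize.local(t) for t in texts]
            return proxied + direct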
Changes from all commits
16 changes: 15 additions & 1 deletion pdm.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ dependencies = [
"aiosqlite>=0.20.0",
"uv",
"questionary>=2.0.1",
"a2wsgi>=1.10.7",
]
dynamic = ["version"]
[project.urls]
19 changes: 15 additions & 4 deletions src/_bentoml_impl/client/http.py
@@ -39,12 +39,12 @@
from httpx._types import RequestFiles

from _bentoml_sdk import Service
from bentoml._internal.external_typing import ASGIApp

from ..serde import Serde

T = t.TypeVar("T", bound="HTTPClient[t.Any]")
A = t.TypeVar("A")

C = t.TypeVar("C", httpx.Client, httpx.AsyncClient)
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
logger = logging.getLogger("bentoml.io")
@@ -72,6 +72,7 @@ class HTTPClient(AbstractClient, t.Generic[C]):
media_type: str = "application/json"
timeout: float = 30
default_headers: dict[str, str] = attr.field(factory=dict)
app: ASGIApp | None = None

_opened_files: list[io.BufferedReader] = attr.field(init=False, factory=list)
_temp_dir: tempfile.TemporaryDirectory[str] = attr.field(init=False)
@@ -82,6 +83,7 @@ def _make_client(
url: str,
headers: t.Mapping[str, str],
timeout: float,
app: ASGIApp | None = None,
) -> AnyClient:
parsed = urlparse(url)
transport = None
@@ -94,6 +96,13 @@
url = "http://127.0.0.1:3000"
elif parsed.scheme == "tcp":
url = f"http://{parsed.netloc}"
elif app is not None:
if client_cls is httpx.Client:
from a2wsgi import ASGIMiddleware

transport = httpx.WSGITransport(app=ASGIMiddleware(app))
else:
transport = httpx.ASGITransport(app=app)
return client_cls(
base_url=url,
transport=transport, # type: ignore
@@ -115,6 +124,7 @@ def __init__(
server_ready_timeout: float | None = None,
token: str | None = None,
timeout: float = 30,
app: ASGIApp | None = None,
) -> None:
"""Create a client instance from a URL.

@@ -165,14 +175,15 @@ def __init__(
media_type=media_type,
default_headers=default_headers,
timeout=timeout,
app=app,
)
if server_ready_timeout is None or server_ready_timeout > 0:
if app is None and (server_ready_timeout is None or server_ready_timeout > 0):
self.wait_until_server_ready(server_ready_timeout)
if service is None:
schema_url = urljoin(url, "/schema.json")

with self._make_client(
httpx.Client, url, default_headers, timeout
httpx.Client, url, default_headers, timeout, app=app
) as client:
resp = client.get("/schema.json")

@@ -193,7 +204,7 @@ def __init__(
@cached_property
def client(self) -> C:
return self._make_client(
self.client_cls, self.url, self.default_headers, self.timeout
self.client_cls, self.url, self.default_headers, self.timeout, self.app
)

@cached_property
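The `_make_client` branch above is the core mechanism: when an `app` is passed in, httpx gets an in-process transport instead of a real network connection. A standalone sketch of that mechanism (the toy app and names are illustrative, not from the PR):

    import httpx
    from a2wsgi import ASGIMiddleware

    async def toy_app(scope, receive, send):
        # Minimal ASGI app standing in for the BentoML service app.
        assert scope["type"] == "http"
        await send({"type": "http.response.start", "status": 200,
                    "headers": [(b"content-type", b"text/plain")]})
        await send({"type": "http.response.body", "body": b"pong"})

    # Async client: ASGITransport hands requests straight to the ASGI app.
    async_client = httpx.AsyncClient(
        transport=httpx.ASGITransport(app=toy_app), base_url="http://testserver"
    )

    # Sync client: httpx only ships a WSGI transport, so the ASGI app is first
    # adapted with a2wsgi.ASGIMiddleware, the same trick used in the diff.
    sync_client = httpx.Client(
        transport=httpx.WSGITransport(app=ASGIMiddleware(toy_app)),
        base_url="http://testserver",
    )

    print(sync_client.get("/ping").text)  # "pong", without opening a socket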
5 changes: 4 additions & 1 deletion src/_bentoml_impl/client/proxy.py
@@ -16,7 +16,7 @@

if t.TYPE_CHECKING:
from _bentoml_sdk.service import ServiceConfig

from bentoml._internal.external_typing import ASGIApp
T = t.TypeVar("T")
logger = logging.getLogger("bentoml.impl")

@@ -30,6 +30,7 @@ def __init__(
*,
service: Service[T] | None = None,
media_type: str = "application/vnd.bentoml+pickle",
app: ASGIApp | None = None,
) -> None:
from bentoml.container import BentoMLContainer

@@ -48,13 +49,15 @@ def __init__(
service=service,
timeout=timeout,
server_ready_timeout=0,
app=app,
)
self._async = AsyncHTTPClient(
url,
media_type=media_type,
service=service,
timeout=timeout,
server_ready_timeout=0,
app=app,
)
if service is not None:
self._inner = service.inner
32 changes: 21 additions & 11 deletions src/_bentoml_impl/server/app.py
@@ -5,8 +5,8 @@
import inspect
import logging
import math
import os
import typing as t
from http import HTTPStatus
from pathlib import Path

import anyio
@@ -271,9 +271,20 @@ def server_request_hook(span: Span | None, _scope: dict[str, t.Any]) -> None:
middlewares.append(Middleware(ContextMiddleware, context=self.service.context))
return middlewares

async def create_instance(self) -> None:
async def create_instance(self, app: Starlette) -> None:
from ..client import RemoteProxy

self._service_instance = self.service()
logger.info("Service %s initialized", self.service.name)
if deployment_url := os.getenv("BENTOCLOUD_DEPLOYMENT_URL"):
proxy = RemoteProxy(
deployment_url, service=self.service, media_type="application/json"
)
else:
proxy = RemoteProxy("http://localhost:3000", service=self.service, app=app)
self._service_instance.__self_proxy__ = proxy # type: ignore[attr-defined]
self._service_instance.to_async = proxy.to_async # type: ignore[attr-defined]
self._service_instance.to_sync = proxy.to_sync # type: ignore[attr-defined]
set_current_service(self._service_instance)
store_path = BentoMLContainer.result_store_file.get()
self._result_store = Sqlite3Store(store_path)
@@ -300,9 +311,11 @@ def _add_response_headers(
trace_context.trace_id, logging_format["trace_id"]
)

async def destroy_instance(self) -> None:
async def destroy_instance(self, _: Starlette) -> None:
from _bentoml_sdk.service.dependency import cleanup

from ..client import RemoteProxy

# Call on_shutdown hook with optional ctx or context parameter
for name, member in vars(self.service.inner).items():
if callable(member) and getattr(member, "__bentoml_shutdown_hook__", False):
@@ -313,6 +326,9 @@ async def destroy_instance(self) -> None:
await result

await cleanup()
own_proxy = getattr(self._service_instance, "__self_proxy__", None)
if isinstance(own_proxy, RemoteProxy):
await own_proxy.close()
self._service_instance = None
set_current_service(None)
await self._result_store.__aexit__(None, None, None)
@@ -512,7 +528,7 @@ async def batch_infer(
self, name: str, input_args: tuple[t.Any, ...], input_kwargs: dict[str, t.Any]
) -> t.Any:
method = self.service.apis[name]
func = getattr(self._service_instance, name)
func = getattr(self._service_instance, name).local

async def inner_infer(
batches: t.Sequence[t.Any], **kwargs: t.Any
@@ -608,15 +624,9 @@ async def api_endpoint(self, name: str, request: Request) -> Response:

media_type = request.headers.get("Content-Type", "application/json")
media_type = media_type.split(";")[0].strip()
if self.is_main and media_type == "application/vnd.bentoml+pickle":
# Disallow pickle media type for main service for security reasons
raise BentoMLException(
"Pickle media type is not allowed for main service",
error_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
)

method = self.service.apis[name]
func = getattr(self._service_instance, name)
func = getattr(self._service_instance, name).local
ctx = self.service.context
serde = ALL_SERDE[media_type]()
input_data = await method.input_spec.from_http_request(request, serde)
4 changes: 0 additions & 4 deletions src/_bentoml_sdk/decorators.py
@@ -90,8 +90,6 @@ def wrapper(func: t.Callable[t.Concatenate[t.Any, P], R]) -> APIMethod[P, R]:
}
if route is not None:
params["route"] = route
if name is not None:
params["name"] = name
if input_spec is not None:
params["input_spec"] = input_spec
if output_spec is not None:
@@ -192,8 +190,6 @@ def wrapper(func: t.Callable[t.Concatenate[t.Any, P], R]) -> APIMethod[P, R]:
}
if route is not None:
params["route"] = route
if name is not None:
params["name"] = name
if input_spec is not None:
params["input_spec"] = input_spec
if output_spec is not None:
18 changes: 14 additions & 4 deletions src/_bentoml_sdk/method.py
@@ -45,7 +45,7 @@ def _io_descriptor_converter(it: t.Any) -> type[IODescriptor]:
class APIMethod(t.Generic[P, R]):
func: t.Callable[t.Concatenate[t.Any, P], R]
route: str = attrs.field()
name: str = attrs.field()
name: str = attrs.field(init=False)
input_spec: type[IODescriptor] = attrs.field(converter=_io_descriptor_converter)
output_spec: type[IODescriptor] = attrs.field(converter=_io_descriptor_converter)
batchable: bool = False
@@ -107,12 +107,22 @@ def __get__(self: T, instance: None, owner: type) -> T: ...
def __get__(self, instance: object, owner: type) -> t.Callable[P, R]: ...

def __get__(self: T, instance: t.Any, owner: type) -> t.Callable[P, R] | T:
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

if instance is None:
return self

local_caller = self._local_call(instance)

if proxy := getattr(instance, "__self_proxy__", None):
func = getattr(proxy, self.name)
else:
func = local_caller
func.local = local_caller # type: ignore[attr-defined]
return func

def _local_call(self, instance: t.Any) -> t.Callable[P, R]:
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

func_sig = inspect.signature(self.func)
# skip the `self` parameter
params = list(func_sig.parameters.values())[1:]
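Together with the `.local` lookups switched in app.py earlier, this `__get__` is the dispatch point: when the instance carries a `__self_proxy__`, attribute access returns the proxy's callable, and the original in-process call stays reachable as `.local`. A stripped-down illustration of that descriptor pattern (not the PR's actual class):

    import typing as t

    class ProxiedMethod:
        # Toy version of APIMethod's dispatch logic.
        def __init__(self, func: t.Callable[..., t.Any]) -> None:
            self.func = func
            self.name = func.__name__

        def __get__(self, instance: t.Any, owner: type) -> t.Any:
            if instance is None:
                return self

            def local(*args: t.Any, **kwargs: t.Any) -> t.Any:
                # Direct, in-process invocation of the original function.
                return self.func(instance, *args, **kwargs)

            proxy = getattr(instance, "__self_proxy__", None)
            if proxy is None:
                call = local
            else:
                def call(*args: t.Any, **kwargs: t.Any) -> t.Any:
                    # Route through the proxy (an HTTP round-trip in the PR).
                    return getattr(proxy, self.name)(*args, **kwargs)

            call.local = local  # type: ignore[attr-defined]
            return call

Keeping `.local` on the returned callable is what lets the server-side handlers in app.py invoke the implementation directly while every other caller goes through the HTTP path.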
60 changes: 51 additions & 9 deletions src/_bentoml_sdk/service/factory.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import math
@@ -10,6 +11,7 @@
from functools import lru_cache
from functools import partial

import anyio.to_thread
import attrs
from simple_di import Provide
from simple_di import inject
@@ -256,8 +258,7 @@ def mount_asgi_app(
def mount_wsgi_app(
self, app: ext.WSGIApp, path: str = "/", name: str | None = None
) -> None:
# TODO: Migrate to a2wsgi
from starlette.middleware.wsgi import WSGIMiddleware
from a2wsgi import WSGIMiddleware

self.mount_apps.append((WSGIMiddleware(app), path, name)) # type: ignore
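`mount_wsgi_app` now uses a2wsgi, which is what the removed TODO comment asked for. For reference, a minimal sketch of what that adapter does (the Flask app is illustrative):

    from a2wsgi import WSGIMiddleware
    from flask import Flask

    flask_app = Flask(__name__)

    @flask_app.get("/ping")
    def ping() -> str:
        return "pong"

    # WSGIMiddleware wraps a WSGI callable so an ASGI stack (such as the
    # Starlette mount used by mount_wsgi_app above) can serve it.
    asgi_app = WSGIMiddleware(flask_app)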

@@ -270,6 +271,7 @@ def __call__(self) -> T:
try:
instance = self.inner()
instance.to_async = _AsyncWrapper(instance, self.apis.keys())
instance.to_sync = _SyncWrapper(instance, self.apis.keys())
return instance
except Exception:
logger.exception("Initializing service error")
@@ -469,26 +471,27 @@ def __init__(self) -> None:
)


class _AsyncWrapper:
class _Wrapper:
def __init__(self, wrapped: t.Any, apis: t.Iterable[str]) -> None:
self.__call = None
for name in apis:
if name == "__call__":
self.__call = self.__make_method(wrapped, name)
self.__call = self._make_method(wrapped, name)
else:
setattr(self, name, self.__make_method(wrapped, name))
setattr(self, name, self._make_method(wrapped, name))

def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if self.__call is None:
raise TypeError("This service is not callable.")
return self.__call(*args, **kwargs)

def __make_method(self, inner: t.Any, name: str) -> t.Any:
import asyncio
def _make_method(self, instance: t.Any, name: str) -> t.Any:
raise NotImplementedError

import anyio.to_thread

original_func = func = getattr(inner, name)
class _AsyncWrapper(_Wrapper):
def _make_method(self, instance: t.Any, name: str) -> t.Any:
original_func = func = getattr(instance, name).local
while hasattr(original_func, "func"):
original_func = original_func.func
is_async_func = (
@@ -526,3 +529,42 @@ async def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
return await anyio.to_thread.run_sync(partial(func, **kwargs), *args)

return wrapped


class _SyncWrapper(_Wrapper):
def _make_method(self, instance: t.Any, name: str) -> t.Any:
original_func = func = getattr(instance, name).local
while hasattr(original_func, "func"):
original_func = original_func.func
is_async_func = (
asyncio.iscoroutinefunction(original_func)
or (
callable(original_func)
and asyncio.iscoroutinefunction(original_func.__call__) # type: ignore
)
or inspect.isasyncgenfunction(original_func)
)
if not is_async_func:
return func

if inspect.isasyncgenfunction(original_func):

def wrapped_gen(
*args: t.Any, **kwargs: t.Any
) -> t.Generator[t.Any, None, None]:
agen = func(*args, **kwargs)
loop = asyncio.get_event_loop()
while True:
try:
yield loop.run_until_complete(agen.__anext__())
except StopAsyncIteration:
break

return wrapped_gen
else:

def wrapped(*args: P.args, **kwargs: P.kwargs) -> t.Any:
loop = asyncio.get_event_loop()
return loop.run_until_complete(func(*args, **kwargs))

return wrapped