diff --git a/core/.pdm-python b/core/.pdm-python
new file mode 100644
index 00000000000..d3f452d25e2
--- /dev/null
+++ b/core/.pdm-python
@@ -0,0 +1 @@
+/usr/bin/python3.8
diff --git a/docs/source/concepts/runner.rst b/docs/source/concepts/runner.rst
index b4480a6f011..ce9ffc80950 100644
--- a/docs/source/concepts/runner.rst
+++ b/docs/source/concepts/runner.rst
@@ -322,15 +322,17 @@ Runner Definition
 
         # below are also configurable via config file:
 
-        # default configs:
-        max_batch_size=..  # default max batch size will be applied to all run methods, unless override in the runnable_method_configs
-        max_latency_ms=..  # default max latency will be applied to all run methods, unless override in the runnable_method_configs
+        # default configs, which will be applied to all run methods, unless overridden for a specific method:
+        max_batch_size=..,
+        max_latency_ms=..,
+        batching_strategy=..,
 
         runnable_method_configs=[
             {
                 method_name="predict",
                 max_batch_size=..,
                 max_latency_ms=..,
+                batching_strategy=..,
             }
         ],
     )
@@ -363,6 +365,10 @@ To explicitly disable or control adaptive batching behaviors at runtime, configu
             enabled: true
             max_batch_size: 100
             max_latency_ms: 500
+            strategy:
+              name: adaptive
+              options:
+                decay: 0.95
 
     .. tab-item:: Individual Runner
         :sync: individual_runner
 
@@ -376,6 +382,10 @@ To explicitly disable or control adaptive batching behaviors at runtime, configu
                 enabled: true
                 max_batch_size: 100
                 max_latency_ms: 500
+                strategy:
+                  name: adaptive
+                  options:
+                    decay: 0.95
 
 Resource Allocation
 ^^^^^^^^^^^^^^^^^^^
diff --git a/docs/source/guides/batching.rst b/docs/source/guides/batching.rst
index 8aca5d42f52..10822090c74 100644
--- a/docs/source/guides/batching.rst
+++ b/docs/source/guides/batching.rst
@@ -52,28 +52,83 @@ In addition to declaring model as batchable, batch dimensions can also be config
 Configuring Batching
 --------------------
 
-If a model supports batching, adaptive batching is enabled by default. To explicitly disable or control adaptive batching behaviors at runtime, configuration can be specified under the ``batching`` key.
-Additionally, there are two configurations for customizing batching behaviors, `max_batch_size` and `max_latency_ms`.
+If a model supports batching, adaptive batching is enabled by default. To explicitly disable or
+control adaptive batching behaviors at runtime, configuration can be specified under the
+``batching`` key. Additionally, there are three configuration keys for customizing batching
+behaviors: ``max_batch_size``, ``max_latency_ms``, and ``strategy``.
 
 Max Batch Size
 ^^^^^^^^^^^^^^
 
-Configured through the ``max_batch_size`` key, max batch size represents the maximum size a batch can reach before releasing for inferencing. Max batch size should be set based on the capacity of the available system resources, e.g. memory or GPU memory.
+Configured through the ``max_batch_size`` key, max batch size represents the maximum size a batch
+can reach before being released for inferencing. Max batch size should be set based on the capacity
+of the available system resources, e.g. memory or GPU memory.
 
 Max Latency
 ^^^^^^^^^^^
 
-Configured through the ``max_latency_ms`` key, max latency represents the maximum latency in milliseconds that a batch should wait before releasing for inferencing. Max latency should be set based on the service level objective (SLO) of the inference requests.
+Configured through the ``max_latency_ms`` key, max latency represents the maximum latency in
+milliseconds that the scheduler will attempt to uphold by cancelling requests when it believes the
+runner server cannot serve them in time. Max latency should be set based on the service level
+objective (SLO) of the inference requests.
+
+Batching Strategy
+^^^^^^^^^^^^^^^^^
+
+Configured through the ``strategy`` key (a strategy name, optionally with an ``options`` mapping),
+the batching strategy determines how the scheduler chooses a batching window, i.e. the time it
+waits for requests to combine them into a batch before dispatching it to begin execution. There are
+three options:
+
+ - target_latency: this strategy waits until it expects that serving the first request received
+                   will take around the target latency, then begins execution. Choose this method
+                   if you think that your service workload will be very bursty and so the adaptive
+                   strategy will do a poor job of estimating average wait times.
+
+                   It takes one option, ``latency_ms`` (default 1000), which is the latency target
+                   to use for dispatch.
+
+ - fixed_wait: this strategy will wait a fixed amount of time after the first request has been
+               received. It differs from the target_latency strategy in that it does not consider
+               the amount of time that it expects a batch will take to execute.
+
+               It takes one option, ``wait_ms`` (default 1000), the amount of time to wait after
+               receiving the first request.
+
+ - adaptive: this strategy waits intelligently in an effort to optimize average latency across all
+             requests. It tracks the average time requests spend in the queue, then estimates how
+             long waiting for, and then executing, a batch that also includes the next request
+             would take. If that time, multiplied by the number of requests in the queue, is less
+             than the average wait time, it continues waiting for the next request to arrive. This
+             is the default, and the other options should only be chosen if undesirable latency
+             behavior is observed.
+
+             It has one option, ``decay`` (default 0.95), which is the rate at which the dispatcher
+             decays the wait time per dispatched job. Note that this does not decay the actual
+             expected wait time, but instead reduces the batching window, which indirectly reduces
+             the average waiting time.
+
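+Strategies and optimizers can also be supplied programmatically through ``to_runner``. The snippet
+below is an illustrative sketch rather than a recommended pattern: it assumes an ``iris_clf`` model
+exists in the local model store, and it imports the strategy and optimizer classes from BentoML's
+internal dispatcher module, whose path is not part of the public API.
+
+.. code-block:: python
+
+    import bentoml
+
+    from bentoml._internal.marshal.dispatcher import AdaptiveStrategy
+    from bentoml._internal.marshal.dispatcher import LinearOptimizer
+
+    # both classes accept an options dict mirroring the ``options`` key in configuration
+    runner = bentoml.models.get("iris_clf:latest").to_runner(
+        batching_strategy=AdaptiveStrategy({"decay": 0.9}),
+        optimizer=LinearOptimizer({"initial_slope_ms": 2.0, "initial_intercept_ms": 1.0}),
+    )
+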
 .. code-block:: yaml
     :caption: ⚙️ `configuration.yml`
 
     runners:
+      # batching options for all runners
+      batching:
+        enabled: true
+        max_batch_size: 100
+        max_latency_ms: 500
+        strategy: adaptive
       iris_clf:
+        # batching options for the iris_clf runner specifically
+        # these options override the ones above
         batching:
           enabled: true
           max_batch_size: 100
           max_latency_ms: 500
+          strategy:
+            name: target_latency
+            options:
+              latency_ms: 200
 
 Monitoring
 ----------
diff --git a/src/bentoml/_internal/configuration/containers.py b/src/bentoml/_internal/configuration/containers.py
index 068a783ca4e..93bc4973c70 100644
--- a/src/bentoml/_internal/configuration/containers.py
+++ b/src/bentoml/_internal/configuration/containers.py
@@ -143,11 +143,13 @@ def __init__(
 
     def _finalize(self):
         RUNNER_CFG_KEYS = [
+            "optimizer",
             "batching",
             "resources",
             "logging",
             "metrics",
             "traffic",
             "workers_per_resource",
         ]
         global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
diff --git a/src/bentoml/_internal/configuration/v1/__init__.py b/src/bentoml/_internal/configuration/v1/__init__.py
index 55ecf41ce9a..9ff05dadf49 100644
--- a/src/bentoml/_internal/configuration/v1/__init__.py
+++ b/src/bentoml/_internal/configuration/v1/__init__.py
@@ -142,10 +142,24 @@
     },
 }
 _RUNNER_CONFIG = {
+    s.Optional("optimizer"): s.Or(
+        str,
+        {
+            s.Optional("name"): str,
+            s.Optional("options"): dict,
+        },
+    ),
     s.Optional("batching"): {
         s.Optional("enabled"): bool,
         s.Optional("max_batch_size"): s.And(int, ensure_larger_than_zero),
         s.Optional("max_latency_ms"): s.And(int, ensure_larger_than_zero),
+        s.Optional("strategy"): s.Or(
+            str,
+            {
+                s.Optional("name"): str,
+                s.Optional("options"): dict,
+            },
+        ),
     },
     # NOTE: there is a distinction between being unset and None here; if set to 'None'
     # in configuration for a specific runner, it will override the global configuration.
diff --git a/src/bentoml/_internal/configuration/v1/default_configuration.yaml b/src/bentoml/_internal/configuration/v1/default_configuration.yaml
index 9edb9b22360..82cee747bbb 100644
--- a/src/bentoml/_internal/configuration/v1/default_configuration.yaml
+++ b/src/bentoml/_internal/configuration/v1/default_configuration.yaml
@@ -72,9 +72,22 @@ runners:
   traffic:
     timeout: 900
     max_concurrency: ~
+  optimizer:
+    name: linear
+    options:
+      initial_slope_ms: 2.
+      initial_intercept_ms: 1.
   batching:
     enabled: true
     max_batch_size: 100
+    # which strategy to use when batching requests
+    # there are currently two available options:
+    #   - target_latency: attempt to ensure requests are served within a certain amount of time
+    #   - adaptive: wait a variable amount of time in order to optimize for minimal average latency
+    strategy:
+      name: adaptive
+      options:
+        decay: 0.95
     max_latency_ms: 60000
   logging:
     access:
diff --git a/src/bentoml/_internal/marshal/dispatcher.py b/src/bentoml/_internal/marshal/dispatcher.py
index 22d929cea51..649032728c9 100644
--- a/src/bentoml/_internal/marshal/dispatcher.py
+++ b/src/bentoml/_internal/marshal/dispatcher.py
@@ -7,11 +7,14 @@
 import time
 import traceback
 import typing as t
+from abc import ABC
+from abc import abstractmethod
 from functools import cached_property
 
 import attr
 import numpy as np
 
+from ...exceptions import BadInput
 from ..utils.alg import TokenBucket
 
 logger = logging.getLogger(__name__)
@@ -47,64 +50,123 @@ class Job:
     dispatch_time: float = 0
 
 
-class Optimizer:
+OPTIMIZER_REGISTRY = {}
+
+
+class Optimizer(ABC):
+    optimizer_id: str
+    n_skipped_sample: int = 0
+
+    @abstractmethod
+    def __init__(self, options: dict[str, t.Any]):
+        pass
+
+    @abstractmethod
+    def log_outbound(self, batch_size: int, duration: float):
+        pass
+
+    @abstractmethod
+    def predict(self, batch_size: int) -> float:
+        pass
+
+    def predict_diff(self, first_batch_size: int, second_batch_size: int) -> float:
+        """
+        Predict the difference in expected execution time between two batch sizes.
+        """
+        return self.predict(second_batch_size) - self.predict(first_batch_size)
+
+    def trigger_refresh(self):
+        pass
+
+    def __init_subclass__(cls, optimizer_id: str):
+        OPTIMIZER_REGISTRY[optimizer_id] = cls
+        cls.optimizer_id = optimizer_id
+
+
+class FixedOptimizer(Optimizer, optimizer_id="fixed"):
+    time: float
+
+    def __init__(self, options: dict[str, t.Any]):
+        if "time_ms" not in options:
+            raise BadInput(
+                "Attempted to initialize the 'fixed' optimizer without its required option 'time_ms'."
+            )
+        # options are in milliseconds; predictions are in seconds
+        self.time = options["time_ms"] / 1000.0
+
+    def log_outbound(self, batch_size: int, duration: float):
+        # a fixed estimate never updates from observed durations
+        pass
+
+    def predict(self, batch_size: int):
+        # explicitly unused parameters
+        del batch_size
+
+        return self.time
+
+
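+# Illustrative sketch only: any subclass declared with an ``optimizer_id`` keyword is added to
+# OPTIMIZER_REGISTRY by ``__init_subclass__`` and becomes selectable by that name from the
+# ``runners.optimizer`` configuration. The class below is hypothetical and not shipped here; it
+# assumes a constant execution time regardless of batch size:
+#
+#   class ConstantOptimizer(Optimizer, optimizer_id="constant"):
+#       def __init__(self, options: dict[str, t.Any]):
+#           self.duration = options.get("duration_ms", 10.0) / 1000.0
+#
+#       def log_outbound(self, batch_size: int, duration: float):
+#           pass  # a constant estimate never changes with observations
+#
+#       def predict(self, batch_size: int) -> float:
+#           return self.duration
+
+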
""" - N_KEPT_SAMPLE = 50 # amount of outbound info kept for inferring params - N_SKIPPED_SAMPLE = 2 # amount of outbound info skipped after init - INTERVAL_REFRESH_PARAMS = 5 # seconds between each params refreshing + o_a: float = 2.0 + o_b: float = 1.0 + + n_kept_sample = 50 # amount of outbound info kept for inferring params + n_skipped_sample = 2 # amount of outbound info skipped after init + param_refresh_interval = 5 # seconds between each params refreshing - def __init__(self, max_latency: float): + def __init__(self, options: dict[str, t.Any]): """ assume the outbound duration follows duration = o_a * n + o_b (all in seconds) """ - self.o_stat: collections.deque[tuple[int, float, float]] = collections.deque( - maxlen=self.N_KEPT_SAMPLE - ) # to store outbound stat data - self.o_a = min(2, max_latency * 2.0 / 30) - self.o_b = min(1, max_latency * 1.0 / 30) + for key in options: + if key == "initial_slope_ms": + self.o_a = options[key] / 1000.0 + elif key == "initial_intercept_ms": + self.o_b = options[key] / 1000.0 + elif key == "n_kept_sample": + self.n_kept_sample = options[key] + elif key == "n_skipped_sample": + self.n_skipped_sample = options[key] + elif key == "param_refresh_interval": + self.param_refresh_interval = options[key] + else: + logger.warning( + f"Optimizer 'linear' ignoring unknown configuration key '{key}'." + ) - self.wait = 0 # the avg wait time before outbound called + self.o_stat: collections.deque[tuple[int, float]] = collections.deque( + maxlen=self.n_kept_sample + ) # to store outbound stat data self._refresh_tb = TokenBucket(2) # to limit params refresh interval self.outbound_counter = 0 - def log_outbound(self, n: int, wait: float, duration: float): - if self.outbound_counter <= self.N_SKIPPED_SAMPLE + 4: - self.outbound_counter += 1 + def log_outbound(self, batch_size: int, duration: float): + if self.outbound_counter <= self.n_skipped_sample: # skip inaccurate info at beginning - if self.outbound_counter <= self.N_SKIPPED_SAMPLE: - return + self.outbound_counter += 1 + return - self.o_stat.append((n, duration, wait)) + self.o_stat.append((batch_size, duration)) - if self._refresh_tb.consume(1, 1.0 / self.INTERVAL_REFRESH_PARAMS, 1): + if self._refresh_tb.consume(1, 1.0 / self.param_refresh_interval, 1): self.trigger_refresh() - def trigger_refresh(self): - if not self.o_stat: - logger.debug( - "o_stat is empty, skip dynamic batching optimizer params update" - ) - return + def predict(self, batch_size: int): + return self.o_a * batch_size + self.o_b + + def predict_diff(self, first_batch_size: int, second_batch_size: int): + return self.o_a * (second_batch_size - first_batch_size) - x = tuple((i, 1) for i, _, _ in self.o_stat) - y = tuple(i for _, i, _ in self.o_stat) + def trigger_refresh(self): + x = tuple((i, 1) for i, _ in self.o_stat) + y = tuple(i for _, i in self.o_stat) - _factors: tuple[float, float] = np.linalg.lstsq(x, y, rcond=None)[0] # type: ignore + _factors = t.cast(tuple[float, float], np.linalg.lstsq(x, y, rcond=None)[0]) _o_a, _o_b = _factors - _o_w = sum(w for _, _, w in self.o_stat) * 1.0 / len(self.o_stat) self.o_a, self.o_b = max(0.000001, _o_a), max(0, _o_b) - self.wait = max(0, _o_w) logger.debug( - "Dynamic batching optimizer params updated: o_a: %.6f, o_b: %.6f, wait: %.6f", + "Dynamic batching optimizer params updated: o_a: %.6f, o_b: %.6f", _o_a, _o_b, - _o_w, ) @@ -112,7 +174,156 @@ def trigger_refresh(self): T_OUT = t.TypeVar("T_OUT") -class CorkDispatcher: +BATCHING_STRATEGY_REGISTRY = {} + + +class BatchingStrategy(ABC): 
+    strategy_id: str
+
+    @abstractmethod
+    def __init__(self, options: dict[str, t.Any]):
+        pass
+
+    @abstractmethod
+    async def batch(
+        self,
+        optimizer: Optimizer,
+        queue: t.Deque[Job],
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+        dispatch: t.Callable[[t.Sequence[Job], int], None],
+    ):
+        pass
+
+    def __init_subclass__(cls, strategy_id: str):
+        BATCHING_STRATEGY_REGISTRY[strategy_id] = cls
+        cls.strategy_id = strategy_id
+
+
+class TargetLatencyStrategy(BatchingStrategy, strategy_id="target_latency"):
+    latency: float = 1.0
+
+    def __init__(self, options: dict[str, t.Any]):
+        for key in options:
+            if key == "latency_ms":
+                self.latency = options[key] / 1000.0
+            else:
+                logger.warning(
+                    f"Strategy 'target_latency' ignoring unknown configuration key '{key}'."
+                )
+
+    async def batch(
+        self,
+        optimizer: Optimizer,
+        queue: t.Deque[Job],
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+        dispatch: t.Callable[[t.Sequence[Job], int], None],
+    ):
+        # explicitly unused parameters
+        del max_latency
+
+        n = len(queue)
+        now = time.time()
+        w0 = now - queue[0].enqueue_time
+        latency_0 = w0 + optimizer.predict(n)
+
+        while latency_0 < self.latency and n < max_batch_size:
+            n = len(queue)
+            now = time.time()
+            w0 = now - queue[0].enqueue_time
+            latency_0 = w0 + optimizer.predict(n)
+
+            await asyncio.sleep(tick_interval)
+
+        # call
+        n_call_out = 0
+        batch_size = 0
+        for job in queue:
+            if batch_size + job.data.sample.batch_size <= max_batch_size:
+                n_call_out += 1
+                batch_size += job.data.sample.batch_size
+        inputs_info = tuple(queue.pop() for _ in range(n_call_out))
+        dispatch(inputs_info, batch_size)
+
+
+class AdaptiveStrategy(BatchingStrategy, strategy_id="adaptive"):
+    decay: float = 0.95
+
+    n_kept_samples = 50
+    avg_wait_times: collections.deque[float]
+    avg_req_wait: float = 0
+
+    def __init__(self, options: dict[str, t.Any]):
+        for key in options:
+            if key == "decay":
+                self.decay = options[key]
+            elif key == "n_kept_samples":
+                self.n_kept_samples = options[key]
+            else:
+                logger.warning(
+                    f"Strategy 'adaptive' ignoring unknown configuration key '{key}'."
+                )
+
+        self.avg_wait_times = collections.deque(maxlen=self.n_kept_samples)
+
+    async def batch(
+        self,
+        optimizer: Optimizer,
+        queue: t.Deque[Job],
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+        dispatch: t.Callable[[t.Sequence[Job], int], None],
+    ):
+        n = len(queue)
+        now = time.time()
+        w0 = now - queue[0].enqueue_time
+        wn = now - queue[-1].enqueue_time
+        latency_0 = w0 + optimizer.predict(n)
+        while (
+            # if we don't already have enough requests,
+            n < max_batch_size
+            # we are not about to cancel the first request,
+            and latency_0 + tick_interval <= max_latency * 0.95
+            # and waiting will cause average latency to decrease
+            and n * (wn + tick_interval + optimizer.predict_diff(n, n + 1))
+            <= self.avg_req_wait * self.decay
+        ):
+            n = len(queue)
+            now = time.time()
+            w0 = now - queue[0].enqueue_time
+            wn = now - queue[-1].enqueue_time
+            latency_0 = w0 + optimizer.predict(n)
+
+            # wait for additional requests to arrive
+            await asyncio.sleep(tick_interval)
+
+        # dispatch the batch
+        inputs_info: list[Job] = []
+        n_call_out = 0
+        batch_size = 0
+        for job in queue:
+            if batch_size + job.data.sample.batch_size <= max_batch_size:
+                n_call_out += 1
+                batch_size += job.data.sample.batch_size
+
+        batch_size = 0
+        for _ in range(n_call_out):
+            job = queue.pop()
+            batch_size += job.data.sample.batch_size
+            new_wait = (now - job.enqueue_time) / self.n_kept_samples
+            if len(self.avg_wait_times) == self.n_kept_samples:
+                oldest_wait = self.avg_wait_times.popleft()
+                self.avg_req_wait = self.avg_req_wait - oldest_wait + new_wait
+            else:
+                # avg deliberately undercounts until we hit n_kept_samples for simplicity
+                self.avg_req_wait += new_wait
+            self.avg_wait_times.append(new_wait)
+            inputs_info.append(job)
+
+        dispatch(inputs_info, batch_size)
+
+
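+# Illustrative sketch only (mirrors how the runner wires configuration into a dispatcher; the
+# values shown are hypothetical): a ``strategy`` entry such as ``{name: adaptive, options:
+# {decay: 0.9}}`` in the YAML configuration roughly corresponds to
+#
+#   strategy = BATCHING_STRATEGY_REGISTRY["adaptive"]({"decay": 0.9})
+#   dispatcher = Dispatcher(
+#       max_latency_in_ms=500,
+#       max_batch_size=100,
+#       optimizer=OPTIMIZER_REGISTRY["linear"]({}),
+#       strategy=strategy,
+#   )
+
+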
+class Dispatcher:
     """
     A decorator that:
         * wrap batch function
@@ -120,10 +331,14 @@ class CorkDispatcher:
     The wrapped function should be an async function.
     """
 
+    background_tasks: set[asyncio.Task[None]]
+
     def __init__(
         self,
         max_latency_in_ms: int,
         max_batch_size: int,
+        optimizer: Optimizer,
+        strategy: BatchingStrategy,
         shared_sema: t.Optional[NonBlockSema] = None,
         fallback: t.Callable[[], t.Any] | type[t.Any] | None = None,
     ):
@@ -138,7 +353,8 @@ def __init__(
         """
         self.max_latency = max_latency_in_ms / 1000.0
         self.fallback = fallback
-        self.optimizer = Optimizer(self.max_latency)
+        self.optimizer = optimizer
+        self.strategy = strategy
+        self.background_tasks = set()
         self.max_batch_size = int(max_batch_size)
         self.tick_interval = 0.001
 
@@ -151,6 +367,8 @@ def __init__(
     def shutdown(self):
         if self._controller is not None:
             self._controller.cancel()
+        for task in self.background_tasks:
+            task.cancel()
         for job in self._queue:
             job.future.cancel()
 
@@ -193,10 +411,11 @@ async def train_optimizer(
         if self.max_batch_size < batch_size:
             batch_size = self.max_batch_size
 
+        wait = 0
         if batch_size > 1:
             wait = min(
                 self.max_latency * 0.95,
-                (batch_size * 2 + 1) * (self.optimizer.o_a + self.optimizer.o_b),
+                self.optimizer.predict(batch_size * 2 + 1),
             )
 
         req_count = 0
@@ -215,10 +434,10 @@ async def train_optimizer(
                     self._queue.popleft().future.cancel()
                     continue
                 if batch_size > 1:  # only wait if batch_size
-                    a = self.optimizer.o_a
-                    b = self.optimizer.o_b
-
-                    if n < batch_size and (batch_size * a + b) + w0 <= wait:
+                    if (
+                        n < batch_size
+                        and self.optimizer.predict(batch_size) + w0 <= wait
+                    ):
                         await asyncio.sleep(self.tick_interval)
                         continue
                 if self._sema.is_locked():
@@ -231,31 +450,18 @@ async def train_optimizer(
                 else:
                     n_call_out = 0
                     batch_size = 0
-                    try:
-                        for input_info in self._queue:
-                            if (
-                                batch_size + input_info.data.sample.batch_size
-                                < self.max_batch_size
-                            ):
-                                n_call_out += 1
-                                batch_size += input_info.data.sample.batch_size
-                            else:
-                                break
-                    except Exception as e:
-                        n_call_out = min(n, self.max_batch_size)
-                        logger.error(
-                            "error in batch-size aware batching, falling back to regular batching method",
-                            exc_info=e,
-                        )
+                    for job in self._queue:
+                        if (
+                            batch_size + job.data.sample.batch_size
+                            <= self.max_batch_size
+                        ):
+                            n_call_out += 1
+                            batch_size += job.data.sample.batch_size
 
                 req_count += 1
                 # call
-                self._sema.acquire()
                 inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
-                for info in inputs_info:
-                    # fake wait as 0 for training requests
-                    info.enqueue_time = now
-                self._loop.create_task(self.outbound_call(inputs_info))
+                self._dispatch(inputs_info, batch_size)
             except Exception as e:  # pylint: disable=broad-except
                 logger.error(traceback.format_exc(), exc_info=e)
 
@@ -267,7 +473,7 @@ async def controller(self):
         logger.debug("Starting dispatcher optimizer training...")
         # warm up the model
         await self.train_optimizer(
-            self.optimizer.N_SKIPPED_SAMPLE, self.optimizer.N_SKIPPED_SAMPLE + 6, 1
+            self.optimizer.n_skipped_sample, self.optimizer.n_skipped_sample + 6, 1
         )
         logger.debug("Dispatcher finished warming up model.")
 
@@ -284,7 +490,7 @@ async def controller(self):
         self.optimizer.trigger_refresh()
         logger.debug("Dispatcher finished optimizer training request 3.")
 
-        if self.optimizer.o_a + self.optimizer.o_b >= self.max_latency:
+        if self.optimizer.predict(1) >= self.max_latency:
             logger.warning(
                 "BentoML has detected that a service has a max latency that is likely too low for serving. If many 503 errors are encountered, try raising the 'runner.max_latency' in your BentoML configuration YAML file."
             )
@@ -298,16 +504,11 @@ async def controller(self):
                     await self._wake_event.wait_for(self._queue.__len__)
 
                 n = len(self._queue)
-                dt = self.tick_interval
-                decay = 0.95  # the decay rate of wait time
                 now = time.time()
                 w0 = now - self._queue[0].enqueue_time
-                wn = now - self._queue[-1].enqueue_time
-                a = self.optimizer.o_a
-                b = self.optimizer.o_b
 
                 # the estimated latency of the first request if we began processing now
-                latency_0 = w0 + a * n + b
+                latency_0 = w0 + self.optimizer.predict(n)
 
                 if n > 1 and latency_0 >= self.max_latency:
                     self._queue.popleft().future.cancel()
@@ -318,48 +519,34 @@ async def controller(self):
                     continue
                     await asyncio.sleep(self.tick_interval)
                     continue
-                if (
-                    n < self.max_batch_size
-                    and n * (wn + dt + (a or 0)) <= self.optimizer.wait * decay
-                ):
-                    await asyncio.sleep(self.tick_interval)
-                    continue
+                # we are now free to dispatch whenever we like
                 if self.max_batch_size == -1:  # batching is disabled
-                    n_call_out = 1
-                    batch_size = self._queue[0].data.sample.batch_size
+                    # dispatch a single request immediately
+                    job = self._queue.pop()
+                    self._dispatch((job,), job.data.sample.batch_size)
                 else:
-                    n_call_out = 0
-                    batch_size = 0
-                    try:
-                        for input_info in self._queue:
-                            if (
-                                batch_size + input_info.data.sample.batch_size
-                                < self.max_batch_size
-                            ):
-                                n_call_out += 1
-                                batch_size += input_info.data.sample.batch_size
-                            else:
-                                break
-                    except Exception as e:
-                        n_call_out = min(n, self.max_batch_size)
-                        logger.error(
-                            "error in batch-size aware batching, falling back to regular batching method",
-                            exc_info=e,
-                        )
-
-                # call
-                self._sema.acquire()
-                inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
-                self._loop.create_task(self.outbound_call(inputs_info))
+                    await self.strategy.batch(
+                        self.optimizer,
+                        self._queue,
+                        self.max_latency,
+                        self.max_batch_size,
+                        self.tick_interval,
+                        self._dispatch,
+                    )
+
             except Exception as e:  # pylint: disable=broad-except
                 logger.error(traceback.format_exc(), exc_info=e)
 
+    def _dispatch(self, inputs_info: t.Sequence[Job], batch_size: int):
+        self._sema.acquire()
+        task = self._loop.create_task(self.outbound_call(inputs_info, batch_size))
+        self.background_tasks.add(task)
+        task.add_done_callback(self.background_tasks.discard)
+
     async def inbound_call(self, data: Params[Payload]):
         if self.max_batch_size > 0 and data.sample.batch_size > self.max_batch_size:
             raise RuntimeError(
                 f"batch of size {data.sample.batch_size} exceeds configured max batch size of {self.max_batch_size}."
) - now = time.time() future = self._loop.create_future() input_info = Job(now, data, future) @@ -368,11 +555,14 @@ async def inbound_call(self, data: Params[Payload]): self._wake_event.notify_all() return await future - async def outbound_call(self, inputs_info: tuple[Job, ...]): + async def outbound_call(self, inputs_info: t.Sequence[Job], batch_size: int): _time_start = time.time() _done = False - batch_size = len(inputs_info) - logger.debug("Dynamic batching cork released, batch size: %d", batch_size) + logger.debug( + "Dynamic batching cork released, batch size: %d (%d requests)", + batch_size, + len(inputs_info), + ) try: outputs = await self.callback( tuple(t.cast(t.Any, input_info.data) for input_info in inputs_info) @@ -384,8 +574,7 @@ async def outbound_call(self, inputs_info: tuple[Job, ...]): fut.set_result(out) _done = True self.optimizer.log_outbound( - n=len(inputs_info), - wait=_time_start - inputs_info[-1].enqueue_time, + batch_size=len(inputs_info), duration=time.time() - _time_start, ) except Exception as e: # pylint: disable=broad-except diff --git a/src/bentoml/_internal/models/model.py b/src/bentoml/_internal/models/model.py index 98cd80f353f..443339cf25e 100644 --- a/src/bentoml/_internal/models/model.py +++ b/src/bentoml/_internal/models/model.py @@ -40,6 +40,8 @@ from ..utils import normalize_labels_value if t.TYPE_CHECKING: + from ..marshal.dispatcher import BatchingStrategy + from ..marshal.dispatcher import Optimizer from ..runner import Runnable from ..runner import Runner from ..runner.strategy import Strategy @@ -319,6 +321,8 @@ def to_runner( name: str = "", max_batch_size: int | None = None, max_latency_ms: int | None = None, + optimizer: Optimizer | None = None, + batching_strategy: BatchingStrategy | None = None, method_configs: dict[str, dict[str, int]] | None = None, embedded: bool = False, scheduling_strategy: type[Strategy] | None = None, @@ -355,6 +359,8 @@ def to_runner( models=[self], max_batch_size=max_batch_size, max_latency_ms=max_latency_ms, + optimizer=optimizer, + batching_strategy=batching_strategy, method_configs=method_configs, embedded=embedded, scheduling_strategy=scheduling_strategy, diff --git a/src/bentoml/_internal/runner/runner.py b/src/bentoml/_internal/runner/runner.py index 59a6cd993f6..1901b0bd1ee 100644 --- a/src/bentoml/_internal/runner/runner.py +++ b/src/bentoml/_internal/runner/runner.py @@ -4,13 +4,17 @@ import typing as t from abc import ABC from abc import abstractmethod +from pprint import pprint import attr from simple_di import Provide from simple_di import inject +from ...exceptions import BentoMLConfigException from ...exceptions import StateException from ..configuration.containers import BentoMLContainer +from ..marshal.dispatcher import BATCHING_STRATEGY_REGISTRY +from ..marshal.dispatcher import OPTIMIZER_REGISTRY from ..models.model import Model from ..tag import validate_tag_str from ..utils import first_not_none @@ -22,6 +26,8 @@ if t.TYPE_CHECKING: from ...triton import Runner as TritonRunner + from ..marshal.dispatcher import BatchingStrategy + from ..marshal.dispatcher import Optimizer from .runnable import RunnableMethodConfig # only use ParamSpec in type checking, as it's only in 3.10 @@ -47,6 +53,8 @@ class RunnerMethod(t.Generic[T, P, R]): config: RunnableMethodConfig max_batch_size: int max_latency_ms: int + optimizer: Optimizer + batching_strategy: BatchingStrategy def run(self, *args: P.args, **kwargs: P.kwargs) -> R: return self.runner._runner_handle.run_method(self, *args, **kwargs) @@ -135,7 
+143,7 @@ def __getattr__(self, item: str) -> t.Any:
 
     runner_methods: list[RunnerMethod[t.Any, t.Any, t.Any]]
     scheduling_strategy: type[Strategy]
-    workers_per_resource: int | float = 1
+    workers_per_resource: float = 1
     runnable_init_params: dict[str, t.Any] = attr.field(
         default=None, converter=attr.converters.default_if_none(factory=dict)
     )
@@ -170,6 +178,8 @@ def __init__(
         models: list[Model] | None = None,
         max_batch_size: int | None = None,
         max_latency_ms: int | None = None,
+        optimizer: Optimizer | None = None,
+        batching_strategy: BatchingStrategy | None = None,
        method_configs: dict[str, dict[str, int]] | None = None,
        embedded: bool = False,
    ) -> None:
@@ -189,8 +199,11 @@ def __init__(
            models: An optional list composed of ``bentoml.Model`` instances.
            max_batch_size: Max batch size config for dynamic batching. If not provided, use the default value from
                configuration.
-            max_latency_ms: Max latency config for dynamic batching. If not provided, use the default value from
-                configuration.
+            max_latency_ms: Max latency config. If not provided, uses the default value from configuration.
+            optimizer: Optimizer used to predict batch execution time. If not provided, uses the default value
+                from the configuration.
+            batching_strategy: Batching strategy for dynamic batching. If not provided, uses the default value
+                from the configuration.
            method_configs: A dictionary per method config for this given Runner signatures.

        Returns:
@@ -227,6 +240,8 @@ def __init__(
            method_max_batch_size = None
            method_max_latency_ms = None
+            method_optimizer = None
+            method_batching_strategy = None
            if method_name in method_configs:
                method_max_batch_size = method_configs[method_name].get(
                    "max_batch_size"
                )
                method_max_latency_ms = method_configs[method_name].get(
                    "max_latency_ms"
                )
+                method_optimizer = method_configs[method_name].get("optimizer")
+                method_batching_strategy = method_configs[method_name].get(
+                    "batching_strategy"
+                )
+
+            optimizer_conf = config["optimizer"]
+            if isinstance(optimizer_conf, str):
+                optimizer_name = optimizer_conf
+                optimizer_opts = {}
+            else:
+                optimizer_name = optimizer_conf["name"]
+                optimizer_opts = optimizer_conf.get("options", {})
+
+            if optimizer_name not in OPTIMIZER_REGISTRY:
+                raise BentoMLConfigException(
+                    f"Unknown optimizer '{optimizer_name}'. Available optimizers are: {','.join(OPTIMIZER_REGISTRY.keys())}."
+                )
+
+            try:
+                default_optimizer = OPTIMIZER_REGISTRY[optimizer_name](optimizer_opts)
+            except Exception as e:
+                raise BentoMLConfigException(
+                    f"Initializing optimizer '{optimizer_name}' with configured options ({optimizer_opts!r}) failed."
+                ) from e
+
+            strategy_conf = config["batching"]["strategy"]
+            if isinstance(strategy_conf, str):
+                strategy_name = strategy_conf
+                strategy_opts = {}
+            else:
+                strategy_name = strategy_conf["name"]
+                strategy_opts = strategy_conf.get("options", {})
+
+            if strategy_name not in BATCHING_STRATEGY_REGISTRY:
+                raise BentoMLConfigException(
+                    f"Unknown batching strategy '{strategy_name}'. Available strategies are: {','.join(BATCHING_STRATEGY_REGISTRY.keys())}."
+                )
+
+            try:
+                default_batching_strategy = BATCHING_STRATEGY_REGISTRY[strategy_name](
+                    strategy_opts
+                )
+            except Exception as e:
+                raise BentoMLConfigException(
+                    f"Initializing batching strategy '{strategy_name}' with configured options ({strategy_opts!r}) failed."
+ ) from e runner_method_map[method_name] = RunnerMethod( runner=self, @@ -249,6 +314,16 @@ def __init__( max_latency_ms, default=config["batching"]["max_latency_ms"], ), + optimizer=first_not_none( + method_optimizer, + optimizer, + default=default_optimizer, + ), + batching_strategy=first_not_none( + method_batching_strategy, + batching_strategy, + default=default_batching_strategy, + ), ) self.__attrs_init__( diff --git a/src/bentoml/_internal/server/runner_app.py b/src/bentoml/_internal/server/runner_app.py index bc337569ae1..3a86a3ad03a 100644 --- a/src/bentoml/_internal/server/runner_app.py +++ b/src/bentoml/_internal/server/runner_app.py @@ -17,7 +17,7 @@ from ..configuration.containers import BentoMLContainer from ..context import component_context from ..context import trace_context -from ..marshal.dispatcher import CorkDispatcher +from ..marshal.dispatcher import Dispatcher from ..runner.container import AutoContainer from ..runner.container import Payload from ..runner.utils import PAYLOAD_META_HEADER @@ -54,7 +54,7 @@ def __init__( self.worker_index = worker_index self.enable_metrics = enable_metrics - self.dispatchers: dict[str, CorkDispatcher] = {} + self.dispatchers: dict[str, Dispatcher] = {} runners_config = BentoMLContainer.runners_config.get() traffic = runners_config.get("traffic", {}).copy() @@ -69,9 +69,11 @@ def fallback(): for method in runner.runner_methods: max_batch_size = method.max_batch_size if method.config.batchable else -1 - self.dispatchers[method.name] = CorkDispatcher( + self.dispatchers[method.name] = Dispatcher( max_latency_in_ms=method.max_latency_ms, max_batch_size=max_batch_size, + optimizer=method.optimizer, + strategy=method.batching_strategy, fallback=fallback, )