feat: implement batching strategies #3630
@@ -7,6 +7,8 @@
 import functools
 import traceback
 import collections
+from abc import ABC
+from abc import abstractmethod

 import numpy as np
@@ -41,7 +43,7 @@ class Job:

 class Optimizer:
     """
-    Analyse historical data to optimize CorkDispatcher.
+    Analyze historical data to predict execution time using a simple linear regression on batch size.
     """

     N_KEPT_SAMPLE = 50  # amount of outbound info kept for inferring params
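For context on the new docstring: the optimizer's o_a and o_b are the slope and intercept of that regression, so the predicted execution time for a batch of n requests is o_a * n + o_b. Below is a minimal sketch of the idea, assuming numpy is available. The sampling and fitting details (log_outbound, predict, the seed values) are illustrative, not the file's actual implementation; only o_a, o_b, N_KEPT_SAMPLE, and trigger_refresh appear in the diff.

import collections

import numpy as np


class Optimizer:
    N_KEPT_SAMPLE = 50  # amount of outbound info kept for inferring params

    def __init__(self):
        # recent (batch_size, duration) pairs from outbound calls
        self.o_stat = collections.deque(maxlen=self.N_KEPT_SAMPLE)
        self.o_a = 2.0  # slope: estimated cost per item in a batch
        self.o_b = 1.0  # intercept: estimated fixed overhead per call

    def log_outbound(self, batch_size, duration):
        self.o_stat.append((batch_size, duration))

    def trigger_refresh(self):
        # least-squares fit of duration ~ o_a * batch_size + o_b
        if len(self.o_stat) >= 2:
            x, y = zip(*self.o_stat)
            self.o_a, self.o_b = np.polyfit(x, y, 1)

    def predict(self, n):
        return self.o_a * n + self.o_b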
@@ -98,14 +100,97 @@ def trigger_refresh(self):
 T_OUT = t.TypeVar("T_OUT")


-class CorkDispatcher:
-    """
-    A decorator that:
-    * wrap batch function
-    * implement CORK algorithm to cork & release calling of wrapped function
-    The wrapped function should be an async function.
-    """
+BATCHING_STRATEGY_REGISTRY = {}
+
+
+class BatchingStrategy(ABC):
+    strategy_id: str
+
+    @abstractmethod
+    async def wait(
+        self,
+        queue: t.Sequence[Job],
+        optimizer: Optimizer,
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+    ):
+        ...
+
+    def __init_subclass__(cls, strategy_id: str):
+        super().__init_subclass__()
+        BATCHING_STRATEGY_REGISTRY[strategy_id] = cls
+        cls.strategy_id = strategy_id
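Before the concrete strategies, a note on the registration pattern: __init_subclass__ runs once per subclass definition, so declaring a class with a strategy_id keyword is all it takes to register it. A self-contained sketch of the mechanism — the helper names here (REGISTRY, Strategy, FixedWait) are illustrative stand-ins for the diff's BATCHING_STRATEGY_REGISTRY machinery:

import typing as t
from abc import ABC

REGISTRY: dict[str, type] = {}


class Strategy(ABC):
    strategy_id: str

    def __init_subclass__(cls, strategy_id: str):
        # called when a subclass is defined, with the class keyword argument
        super().__init_subclass__()
        REGISTRY[strategy_id] = cls
        cls.strategy_id = strategy_id


class FixedWait(Strategy, strategy_id="fixed_wait"):
    pass


assert REGISTRY["fixed_wait"] is FixedWait  # subclass auto-registered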
+class TargetLatencyStrategy(BatchingStrategy, strategy_id="target_latency"):
+    latency: float = 1
+
+    def __init__(self, options: dict[t.Any, t.Any]):
+        for key in options:
+            if key == "latency":
+                # config value is in milliseconds; store seconds
+                self.latency = options[key] / 1000.0
+            else:
+                logger.warning(
+                    f"Strategy 'target_latency' ignoring unknown configuration key '{key}'."
+                )

Review comment on __init__: TODO: typed dict for init.

+    async def wait(
+        self,
+        queue: t.Sequence[Job],
+        optimizer: Optimizer,
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+    ):
+        n = len(queue)
+        now = time.time()
+        w0 = now - queue[0].enqueue_time
+        latency_0 = w0 + optimizer.o_a * n + optimizer.o_b
+
+        # keep corking until the predicted latency of the first request reaches the target
+        while latency_0 < self.latency:
+            await asyncio.sleep(tick_interval)
+
+            n = len(queue)
+            now = time.time()
+            w0 = now - queue[0].enqueue_time
+            latency_0 = w0 + optimizer.o_a * n + optimizer.o_b
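As a usage sketch, assuming options arrive as a plain dict (per the diff) and the strategy is looked up through the registry — the numbers are invented for illustration:

strategy_cls = BATCHING_STRATEGY_REGISTRY["target_latency"]
strategy = strategy_cls({"latency": 20})  # configured in milliseconds...
assert strategy.latency == 0.020          # ...stored as seconds internally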
+class FixedWaitStrategy(BatchingStrategy, strategy_id="fixed_wait"):
+    wait_time: float = 1  # stored apart from the wait() method so it is not shadowed
+
+    def __init__(self, options: dict[t.Any, t.Any]):
+        for key in options:
+            if key == "wait":
+                # config value is in milliseconds; store seconds
+                self.wait_time = options[key] / 1000.0
+            else:
+                logger.warning(
+                    f"Strategy 'fixed_wait' ignoring unknown configuration key '{key}'."
+                )
+
+    async def wait(
+        self,
+        queue: t.Sequence[Job],
+        optimizer: Optimizer,
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+    ):
+        now = time.time()
+        w0 = now - queue[0].enqueue_time
+
+        if w0 < self.wait_time:
+            await asyncio.sleep(self.wait_time - w0)

Review comment on the w0 check: Add loop checking for …
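Usage follows the same shape as target_latency; for example (invented numbers), a 10 ms fixed cork:

strategy = BATCHING_STRATEGY_REGISTRY["fixed_wait"]({"wait": 10})
assert strategy.wait_time == 0.010  # sleeps until the oldest job has waited 10 ms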
+class IntelligentWaitStrategy(BatchingStrategy, strategy_id="intelligent_wait"):
+    decay: float = 0.95
+
+    def __init__(self, options: dict[t.Any, t.Any]):
+        for key in options:
+            if key == "decay":
+                self.decay = options[key]
+            else:
+                logger.warning(
+                    f"Strategy 'intelligent_wait' ignoring unknown configuration key '{key}'."
+                )
+
+    async def wait(
+        self,
+        queue: t.Sequence[Job],
+        optimizer: Optimizer,
+        max_latency: float,
+        max_batch_size: int,
+        tick_interval: float,
+    ):
+        n = len(queue)
+        now = time.time()
+        dt = tick_interval
+        w0 = now - queue[0].enqueue_time
+        wn = now - queue[-1].enqueue_time
+        latency_0 = w0 + optimizer.o_a * n + optimizer.o_b
+        while (
+            # we don't already have enough requests,
+            n < max_batch_size
+            # we are not about to cancel the first request,
+            and latency_0 + dt <= max_latency * 0.95
+            # and waiting will cause average latency to decrease
+            and n * (wn + dt + optimizer.o_a) <= optimizer.wait * self.decay
+        ):
+            n = len(queue)
+            now = time.time()
+            w0 = now - queue[0].enqueue_time
+            wn = now - queue[-1].enqueue_time
+            latency_0 = w0 + optimizer.o_a * n + optimizer.o_b
+
+            # wait for additional requests to arrive
+            await asyncio.sleep(tick_interval)

Review comment on the decay condition: n is the number of requests in the queue. The left-hand side, n * (wn + dt + optimizer.o_a), measures how much latency is added to every request in the batch if we wait for one more request and add it to the batch. Waiting is worthwhile only while that stays below optimizer.wait — the average amount of time a request sits in the queue — scaled by the decay factor.
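To make the decay condition concrete, a small worked example — all numbers are invented for illustration:

# Hypothetical numbers, purely illustrative:
n = 4             # requests currently queued
wn = 0.002        # newest request has waited 2 ms
dt = 0.001        # one tick of 1 ms
o_a = 0.003       # fitted per-item execution cost: 3 ms
avg_wait = 0.040  # optimizer.wait: average time in queue, 40 ms
decay = 0.95

added_latency = n * (wn + dt + o_a)       # 4 * 6 ms = 24 ms across the batch
print(added_latency <= avg_wait * decay)  # True (24 ms <= 38 ms): keep waiting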
+
+
+class Dispatcher:
     def __init__(
         self,
         max_latency_in_ms: int,
@@ -283,19 +368,11 @@ async def controller(self):
                 continue
             await asyncio.sleep(self.tick_interval)
             continue
-        if (
-            n < self.max_batch_size
-            and n * (wn + dt + (a or 0)) <= self.optimizer.wait * decay
-        ):
-            n = len(self._queue)
-            now = time.time()
-            wn = now - self._queue[-1].enqueue_time
-            latency_0 += dt
-
-            # wait for additional requests to arrive
-            await asyncio.sleep(self.tick_interval)
-            continue

         # we are now free to dispatch whenever we like
+        await self.strategy.wait(
+            self._queue,
+            self.optimizer,
+            self.max_latency,
+            self.max_batch_size,
+            self.tick_interval,
+        )

         n = len(self._queue)
         n_call_out = min(self.max_batch_size, n)

Review comment on n_call_out: Move this (and above) logic into strategy.

         # call
         self._sema.acquire()
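Putting the pieces together, the controller now delegates its wait policy to the configured strategy and then dispatches up to max_batch_size jobs. A condensed sketch of that flow, assuming _queue is a deque and that the strategy was built from BATCHING_STRATEGY_REGISTRY — outbound_call here is a simplified stand-in, not the PR's full controller:

async def controller_sketch(dispatcher):
    while True:
        # block until the configured strategy decides the batch is ready
        await dispatcher.strategy.wait(
            dispatcher._queue,
            dispatcher.optimizer,
            dispatcher.max_latency,
            dispatcher.max_batch_size,
            dispatcher.tick_interval,
        )
        # dispatch at most max_batch_size queued jobs
        n_call_out = min(dispatcher.max_batch_size, len(dispatcher._queue))
        batch = [dispatcher._queue.popleft() for _ in range(n_call_out)]
        await dispatcher.outbound_call(batch)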
@@ -306,6 +383,7 @@ async def controller(self):
         except Exception as e:  # pylint: disable=broad-except
             logger.error(traceback.format_exc(), exc_info=e)

+
     async def inbound_call(self, data: t.Any):
         now = time.time()
         future = self._loop.create_future()