Skip to content

Commit

Permalink
run thread for event loop
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Mar 20, 2023
1 parent 9cbb69f commit 9f2c608
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
37 changes: 24 additions & 13 deletions mii/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.constants import GRPC_MAX_MSG_SIZE
from mii.method_table import GRPC_METHOD_TABLE
from mii.event_loop import get_event_loop


def _get_deployment_info(deployment_name):
Expand Down Expand Up @@ -56,7 +57,7 @@ class MIIClient():
Client to send queries to a single endpoint.
"""
def __init__(self, task_name, host, port):
self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()
channel = create_channel(host, port)
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
self.task = get_task(task_name)
Expand All @@ -73,17 +74,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
proto_response
) if "unpack_response_from_proto" in conversions else proto_response

def query(self, request_dict, **query_kwargs):
return self.asyncio_loop.run_until_complete(
def query_async(self, request_dict, **query_kwargs):
return asyncio.run_coroutine_threadsafe(
self._request_async_response(request_dict,
**query_kwargs))
**query_kwargs),
get_event_loop())

def query(self, request_dict, **query_kwargs):
return self.query_async(request_dict, **query_kwargs).result()

async def terminate_async(self):
await self.stub.Terminate(
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())

def terminate(self):
self.asyncio_loop.run_until_complete(self.terminate_async())
asyncio.run_coroutine_threadsafe(self.terminate_async(),
get_event_loop()).result()


class MIITensorParallelClient():
Expand All @@ -94,7 +100,7 @@ class MIITensorParallelClient():
def __init__(self, task_name, host, ports):
self.task = get_task(task_name)
self.clients = [MIIClient(task_name, host, port) for port in ports]
self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()

# runs task in parallel and return the result from the first task
async def _query_in_tensor_parallel(self, request_string, query_kwargs):
Expand All @@ -106,7 +112,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
**query_kwargs)))

await responses[0]
return responses[0]
return responses[0].result()

def query_async(self, request_dict, **query_kwargs):
"""Asynchronously auery a local deployment.
See `query` for the arguments and the return value.
"""
return asyncio.run_coroutine_threadsafe(
self._query_in_tensor_parallel(request_dict,
query_kwargs),
self.asyncio_loop)

def query(self, request_dict, **query_kwargs):
"""Query a local deployment:
Expand All @@ -121,11 +136,7 @@ def query(self, request_dict, **query_kwargs):
Returns:
response: Response of the model
"""
response = self.asyncio_loop.run_until_complete(
self._query_in_tensor_parallel(request_dict,
query_kwargs))
ret = response.result()
return ret
return self.query_async(request_dict, **query_kwargs).result()

def terminate(self):
"""Terminates the deployment"""
Expand All @@ -135,5 +146,5 @@ def terminate(self):

def terminate_restful_gateway(deployment_name):
_, mii_configs = _get_deployment_info(deployment_name)
if mii_configs.restful_api_port > 0:
if mii_configs.enable_restful_api:
requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
10 changes: 10 additions & 0 deletions mii/event_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import asyncio
import threading

# Shared background event loop for all MII clients/servers. Created lazily on
# first use instead of at import time: asyncio.get_event_loop() is deprecated
# when no loop is running, and eager creation made importing this module have
# a thread-spawning side effect.
event_loop = None
_event_loop_lock = threading.Lock()


def get_event_loop():
    """Return the shared asyncio event loop, starting it on first call.

    The loop runs forever on a daemon thread, so coroutines can be submitted
    from any thread via ``asyncio.run_coroutine_threadsafe``. The lock makes
    first-time initialization safe under concurrent callers.

    Returns:
        asyncio.AbstractEventLoop: the running background loop.
    """
    global event_loop
    with _event_loop_lock:
        if event_loop is None:
            # new_event_loop() creates a dedicated loop rather than adopting
            # whatever loop the importing thread may already own.
            event_loop = asyncio.new_event_loop()
            threading.Thread(target=event_loop.run_forever, daemon=True).start()
    return event_loop
16 changes: 6 additions & 10 deletions mii/grpc_related/modelresponse_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mii.method_table import GRPC_METHOD_TABLE
from mii.client import create_channel
from mii.utils import get_task
from mii.event_loop import get_event_loop


class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
Expand All @@ -41,6 +42,7 @@ def __init__(self, inference_pipeline):
super().__init__()
self.inference_pipeline = inference_pipeline
self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
self.lock = threading.Lock()

def _get_model_time(self, model, sum_times=False):
model_times = []
Expand Down Expand Up @@ -71,7 +73,8 @@ def _run_inference(self, method_name, request_proto):
args, kwargs = conversions["unpack_request_from_proto"](request_proto)

start = time.time()
response = self.inference_pipeline(*args, **kwargs)
with self.lock:
response = self.inference_pipeline(*args, **kwargs)
end = time.time()

model_time = self._get_model_time(self.inference_pipeline.model,
Expand Down Expand Up @@ -133,7 +136,7 @@ def __init__(self, host, ports):
stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
self.stubs.append(stub)

self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()

async def _invoke_async(self, method_name, proto_request):
responses = []
Expand All @@ -153,7 +156,7 @@ def invoke(self, method_name, proto_request):
class LoadBalancingInterceptor(grpc.ServerInterceptor):
def __init__(self, task_name, replica_configs):
super().__init__()
self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()

self.stubs = [
ParallelStubInvoker(replica.hostname,
Expand All @@ -163,13 +166,6 @@ def __init__(self, task_name, replica_configs):
self.counter = AtomicCounter()
self.task = get_task(task_name)

# Start the asyncio loop in a separate thread
def run_asyncio_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()

threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start()

def choose_stub(self, call_count):
return self.stubs[call_count % len(self.stubs)]

Expand Down

0 comments on commit 9f2c608

Please sign in to comment.