From 3abf67a7caef5c9ca4c82756d30ea3ed8fea9173 Mon Sep 17 00:00:00 2001
From: Gangmuk
Date: Mon, 24 Feb 2025 17:57:21 -0800
Subject: [PATCH] Fix wrong token count in streaming client
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The streaming client never recorded the usage stats that arrive on a
streamed chunk, so reported token counts were wrong. Record
prompt/completion/total tokens from chunk.usage; these are complete
values, so overwrite rather than accumulate. Also keep whatever text was
received when a stream is interrupted, and set the routing-strategy
header via with_options() instead of mutating default_headers.

Signed-off-by: Gangmuk
---
 benchmarks/client/client.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/benchmarks/client/client.py b/benchmarks/client/client.py
index 028ef010..7eb3a3d7 100755
--- a/benchmarks/client/client.py
+++ b/benchmarks/client/client.py
@@ -47,15 +47,18 @@ async def send_request_streaming(client: openai.AsyncOpenAI,
                             first_response_time = asyncio.get_event_loop().time()
                         output_text = chunk.choices[0].delta.content
                         text_chunks.append(output_text)
+                if hasattr(chunk, 'usage') and chunk.usage is not None:
+                    # For OpenAI, we expect to get complete usage stats, not partial ones to accumulate
+                    # So we can safely overwrite previous values if they exist
+                    if chunk.usage.prompt_tokens is not None: prompt_tokens = chunk.usage.prompt_tokens
+                    if chunk.usage.completion_tokens is not None: output_tokens = chunk.usage.completion_tokens
+                    if chunk.usage.total_tokens is not None: total_tokens = chunk.usage.total_tokens
         except Exception as stream_error:
             # Handle errors during streaming
             logging.error(f"Request {request_id}: Stream interrupted: {type(stream_error).__name__}: {str(stream_error)}")
-            # Still try to use what we've received so far
-            if not text_chunks:
-                raise  # Re-raise if we got nothing at all
 
         response_text = "".join(text_chunks)
         response_time = asyncio.get_event_loop().time()
@@ -216,7 +219,9 @@ def main(args):
         base_url=args.endpoint + "/v1",
     )
     if args.routing_strategy is not None:
-        client.default_headers["routing-strategy"] = args.routing_strategy
+        client = client.with_options(
+            default_headers={"routing-strategy": args.routing_strategy}
+        )
     if not args.streaming:
         logging.info("Using batch client")
         start_time = time.time()
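-- 
A minimal, self-contained sketch of the fixed flow, for reviewers. It
assumes the benchmark requests stream_options={"include_usage": True} so
the server emits one usage-bearing chunk (the diff above only shows the
consuming side); the model name, base URL, API key, and routing-strategy
value are placeholders, not values from this patch.

    import asyncio
    import logging

    import openai


    async def demo() -> None:
        client = openai.AsyncOpenAI(api_key="placeholder",
                                    base_url="http://localhost:8000/v1")
        # Same pattern as the patch: with_options() returns a copy of the
        # client carrying the extra default header, instead of mutating
        # the original client's default_headers.
        client = client.with_options(
            default_headers={"routing-strategy": "random"})

        stream = await client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
            # Assumption: with this flag the server sends a final chunk
            # whose usage field holds complete token counts.
            stream_options={"include_usage": True},
        )

        text_chunks = []
        prompt_tokens = output_tokens = total_tokens = 0
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                text_chunks.append(chunk.choices[0].delta.content)
            # Usage arrives complete on one chunk, so overwrite previous
            # values rather than accumulating partial counts.
            if getattr(chunk, "usage", None) is not None:
                if chunk.usage.prompt_tokens is not None:
                    prompt_tokens = chunk.usage.prompt_tokens
                if chunk.usage.completion_tokens is not None:
                    output_tokens = chunk.usage.completion_tokens
                if chunk.usage.total_tokens is not None:
                    total_tokens = chunk.usage.total_tokens

        logging.info("text=%r prompt=%d output=%d total=%d",
                     "".join(text_chunks),
                     prompt_tokens, output_tokens, total_tokens)


    if __name__ == "__main__":
        asyncio.run(demo())

Overwriting instead of summing is the core of the fix: summing would
double-count if a server repeated usage across chunks, while a single
complete snapshot is safe to take as-is.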