[FIX] fixing bug in using AsyncOpenAI client (header setting, token counting, etc) #738

Open · wants to merge 1 commit into base: main
benchmarks/client/client.py (140 changes: 81 additions & 59 deletions)
@@ -13,15 +13,18 @@

 logging.basicConfig(level=logging.INFO)
 
 
 async def send_request_streaming(client: openai.AsyncOpenAI,
-                                 model: str,
-                                 endpoint: str,
-                                 prompt: str,
-                                 output_file: str,
-                                 ):
+                                 model: str,
+                                 endpoint: str,
+                                 prompt: str,
+                                 output_file: str,
+                                 request_id: int):
     start_time = asyncio.get_event_loop().time()
     first_response_time = None
 
     try:
+        logging.info(f"Request {request_id}: Starting streaming request to {endpoint}")
         stream = await client.chat.completions.create(
             model=model,
             messages=prompt,
@@ -30,74 +33,89 @@ async def send_request_streaming(client: openai.AsyncOpenAI,
             stream=True,
             stream_options={"include_usage": True},
         )
 
         text_chunks = []
         prompt_tokens = 0
         output_tokens = 0
         total_tokens = 0
 
-        async for chunk in stream:
-            if chunk.choices:
-                if chunk.choices[0].delta.content is not None:
-                    if not first_response_time:
-                        first_response_time = asyncio.get_event_loop().time()
-                    output_text = chunk.choices[0].delta.content
-                    text_chunks.append(output_text)
-            prompt_tokens = chunk.usage.prompt_tokens
-            output_tokens = chunk.usage.completion_tokens
-            total_tokens = chunk.usage.total_tokens
-        response = "".join(text_chunks)
-        logging.info(result)
-        if response.status_code == 200:
-            response_time = asyncio.get_event_loop().time()
-            latency = response_time - start_time
-            throughput = output_tokens / latency
-            ttft = first_response_time - start_time
-            tpot = (response_time - first_response_time) / output_tokens
-            result = {
-                "status_code": response.status_code,
-                "input": prompt,
-                "output": response,
-                "prompt_tokens": prompt_tokens,
-                "output_tokens": output_tokens,
-                "total_tokens": total_tokens,
-                "latency": latency,
-                "throughput": throughput,
-                "start_time": start_time,
-                "current_time": asyncio.get_event_loop().time(),
-                "ttft": ttft,
-                "tpot": tpot,
-            }
-        else:
-            logging.error(f"Request failed status-code: {response.status_code}, raw response: {response.text}")
-            result = {
-                "status_code": response.status_code,
-                "input": prompt,
-                "output": response,
-                "prompt_tokens": prompt_tokens,
-                "output_tokens": None,
-                "total_tokens": total_tokens,
-                "latency": latency,
-                "throughput": None,
-                "start_time": start_time,
-                "current_time": asyncio.get_event_loop().time(),
-                "ttft": None,
-                "tpot": None,
-            }
+        try:
+            async for chunk in stream:
+                if chunk.choices:
+                    if chunk.choices[0].delta.content is not None:
+                        if not first_response_time:
+                            first_response_time = asyncio.get_event_loop().time()
+                        output_text = chunk.choices[0].delta.content
+                        text_chunks.append(output_text)
+                if hasattr(chunk, 'usage') and chunk.usage is not None:
+                    # For OpenAI, we expect to get complete usage stats, not partial ones to accumulate
+                    # So we can safely overwrite previous values if they exist
+                    if chunk.usage.prompt_tokens is not None:
+                        prompt_tokens = chunk.usage.prompt_tokens
+                    if chunk.usage.completion_tokens is not None:
+                        output_tokens = chunk.usage.completion_tokens
+                    if chunk.usage.total_tokens is not None:
+                        total_tokens = chunk.usage.total_tokens
+        except Exception as stream_error:
+            # Handle errors during streaming
+            logging.error(f"Request {request_id}: Stream interrupted: {type(stream_error).__name__}: {str(stream_error)}")
+
+        response_text = "".join(text_chunks)
+        response_time = asyncio.get_event_loop().time()
+        latency = response_time - start_time
+        throughput = output_tokens / latency if output_tokens > 0 else 0
+        ttft = first_response_time - start_time if first_response_time else None
+        tpot = (response_time - first_response_time) / output_tokens if first_response_time and output_tokens > 0 else None
+
+        result = {
+            "request_id": request_id,
+            "status": "success",
+            "input": prompt,
+            "output": response_text,
+            "prompt_tokens": prompt_tokens,
+            "output_tokens": output_tokens,
+            "total_tokens": total_tokens,
+            "latency": latency,
+            "throughput": throughput,
+            "start_time": start_time,
+            "end_time": response_time,
+            "ttft": ttft,
+            "tpot": tpot,
+        }
+
+        # Write result to JSONL file
+        logging.info(f"Request {request_id}: Completed successfully. Tokens: {total_tokens}, Latency: {latency:.2f}s")
         output_file.write(json.dumps(result) + "\n")
         output_file.flush()  # Ensure data is written immediately to the file
         return result
 
     except Exception as e:
-        logging.error(f"Error sending request to at {endpoint}: {str(e)}")
-        traceback.print_exc()
-        return None
+        error_time = asyncio.get_event_loop().time()
+        # Determine error type based on exception class
+        error_type = type(e).__name__
+
+        error_result = {
+            "request_id": request_id,
+            "status": "error",
+            "error_type": error_type,
+            "error_message": str(e),
+            "error_traceback": traceback.format_exc(),
+            "input": prompt,
+            "latency": error_time - start_time,
+            "start_time": start_time,
+            "end_time": error_time
+        }
+        logging.error(f"Request {request_id}: Error ({error_type}): {str(e)}")
+        output_file.write(json.dumps(error_result) + "\n")
+        output_file.flush()
+        return error_result

 async def benchmark_streaming(client: openai.AsyncOpenAI,
                               endpoint: str,
                               model: str,
                               load_struct: List,
                               output_file: io.TextIOWrapper):
 
+    request_id = 0
     batch_tasks = []
     base_time = time.time()
     num_requests = 0
@@ -116,8 +134,10 @@ async def benchmark_streaming(client: openai.AsyncOpenAI,
                     model = model,
                     endpoint = endpoint,
                     prompt = formatted_prompt,
-                    output_file = output_file)
+                    output_file = output_file,
+                    request_id = request_id)
             )
+            request_id += 1
             batch_tasks.append(task)
     num_requests += len(requests)
     await asyncio.gather(*batch_tasks)
@@ -199,7 +219,9 @@ def main(args):
         base_url=args.endpoint + "/v1",
     )
     if args.routing_strategy is not None:
-        client.default_headers["routing-strategy"] = args.routing_strategy
+        client = client.with_options(
+            default_headers={"routing-strategy": args.routing_strategy}
+        )
     if not args.streaming:
         logging.info("Using batch client")
         start_time = time.time()
 
Collaborator commented on the client.with_options change:

> em. that's something we didn't review clearly. good fix
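A note on why the header fix works: in the openai Python SDK, client.default_headers is a property that assembles a fresh dict on every access, so assigning into it mutates a throwaway object and the header never reaches outgoing requests, while with_options() returns a copy of the client with the extra options merged in. A minimal before/after sketch; the base URL and the "least-request" strategy value are illustrative assumptions, not taken from this PR:

    import openai

    # Hypothetical endpoint and routing value, for illustration only.
    client = openai.AsyncOpenAI(api_key="sk-...", base_url="http://localhost:8000/v1")

    # Buggy pattern: default_headers builds a new dict on each access,
    # so this assignment is silently discarded.
    client.default_headers["routing-strategy"] = "least-request"

    # Fixed pattern: with_options() returns a client copy whose default
    # headers now include the routing header on every request.
    client = client.with_options(default_headers={"routing-strategy": "least-request"})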
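On the token-counting side, the behavior the new code guards against (per the OpenAI streaming API with stream_options={"include_usage": True}) is that chunk.usage is None on ordinary content chunks and is populated once, on a final chunk whose choices list is empty, so reading chunk.usage.prompt_tokens on every chunk fails as soon as a chunk arrives without usage. A self-contained sketch of the guarded pattern together with the TTFT/TPOT metrics the benchmark derives; the function name and message are placeholders:

    import asyncio
    import openai

    async def stream_with_metrics(client: openai.AsyncOpenAI, model: str, text: str) -> dict:
        loop = asyncio.get_event_loop()
        start = loop.time()
        first_token_time = None
        chunks = []
        prompt_tokens = output_tokens = 0

        stream = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": text}],
            stream=True,
            stream_options={"include_usage": True},  # request the final usage chunk
        )
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                if first_token_time is None:
                    first_token_time = loop.time()  # first generated token arrives
                chunks.append(chunk.choices[0].delta.content)
            # usage stays None until the last, choice-less chunk of the stream
            if getattr(chunk, "usage", None) is not None:
                prompt_tokens = chunk.usage.prompt_tokens
                output_tokens = chunk.usage.completion_tokens

        end = loop.time()
        return {
            "output": "".join(chunks),
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            # TTFT: delay from request start to the first output token.
            "ttft": first_token_time - start if first_token_time else None,
            # TPOT: average time per output token after the first one.
            "tpot": ((end - first_token_time) / output_tokens
                     if first_token_time and output_tokens else None),
        }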