Multimodal #66

Open. Wants to merge 30 commits into base: main.

Commits (30)
4312cb4
Add mean and percentile info as computed_field properties such that t…
anmarques Sep 4, 2024
46e1076
quality fixes
anmarques Sep 5, 2024
65fafde
quality fix
anmarques Sep 5, 2024
cc8d2c6
Quality fixes
anmarques Sep 6, 2024
bb9bc0c
Add class to describe image samples and loading logic for images from…
anmarques Nov 4, 2024
59002b5
Add class to describe image samples and loading logic for images from…
anmarques Nov 4, 2024
cb1f244
Add url used to download images from for emulated requests
anmarques Nov 4, 2024
24e6527
Add support to images in requests
anmarques Nov 4, 2024
3946709
quality fixes
anmarques Nov 4, 2024
7d93b02
Quality fixes
anmarques Nov 4, 2024
a441dad
Quality fixes
anmarques Nov 4, 2024
570670b
Quality fixes
anmarques Nov 5, 2024
984da28
Add new dependencies
anmarques Nov 5, 2024
355f368
Allow images to be resized to specific resolution
anmarques Dec 3, 2024
43f14d4
Ignore EOS
anmarques Dec 4, 2024
d9819e9
Ignore EOS
anmarques Dec 4, 2024
89e8c6b
Merge branch 'output_summary' into multimodal
anmarques Dec 5, 2024
503a56c
Add image processing dependencies
anmarques Dec 5, 2024
ffcb28d
Fix support to images
anmarques Dec 5, 2024
6106a71
Fix serialization
anmarques Dec 5, 2024
8171820
Fix image registration
anmarques Dec 5, 2024
e845510
Fix pydantic format
anmarques Dec 5, 2024
d1ad0f8
Use resized image
anmarques Dec 5, 2024
40e8e92
Update pyproject.toml
anmarques Dec 6, 2024
0d8eb2f
Update pyproject.toml
anmarques Dec 6, 2024
511d3cb
Update .pre-commit-config.yaml
anmarques Dec 6, 2024
bca2614
Adds aiohttp backend
anmarques Dec 6, 2024
1f7a638
Merge branch 'output_summary' into http_backend
anmarques Dec 6, 2024
b9751d0
Merge branch 'http_backend' into multimodal
anmarques Dec 11, 2024
72db6a4
Add support for aiohttp backend
anmarques Dec 11, 2024
1 change: 1 addition & 0 deletions src/guidellm/__init__.py
@@ -6,6 +6,7 @@
# flake8: noqa

import os

import transformers # type: ignore

os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence warnings for tokenizers
2 changes: 2 additions & 0 deletions src/guidellm/backend/__init__.py
@@ -1,10 +1,12 @@
from .aiohttp import AiohttpBackend
from .base import Backend, BackendEngine, BackendEnginePublic, GenerativeResponse
from .openai import OpenAIBackend

__all__ = [
    "AiohttpBackend",
    "Backend",
    "BackendEngine",
    "BackendEnginePublic",
    "GenerativeResponse",
    "OpenAIBackend",
]
180 changes: 180 additions & 0 deletions src/guidellm/backend/aiohttp.py
@@ -0,0 +1,180 @@
import base64
import io
import json
from typing import AsyncGenerator, Dict, List, Optional

import aiohttp
from loguru import logger

from guidellm.backend.base import Backend, GenerativeResponse
from guidellm.config import settings
from guidellm.core import TextGenerationRequest

__all__ = ["AiohttpBackend"]

@Backend.register("aiohttp_server")
class AiohttpBackend(Backend):
"""
An aiohttp-based backend implementation for LLM requests.

This class provides an interface to communicate with a server hosting
an LLM API using aiohttp for asynchronous requests.
"""

def __init__(
self,
openai_api_key: Optional[str] = None,
target: Optional[str] = None,
model: Optional[str] = None,
timeout: Optional[float] = None,
**request_args,
):
self._request_args: Dict = request_args
self._api_key: str = openai_api_key or settings.aiohttp.api_key

if not self._api_key:
err = ValueError(
"`GUIDELLM__AIOHTTP__API_KEY` environment variable or "
"--openai-api-key CLI parameter must be specified for the "
"aiohttp backend."
)
logger.error("{}", err)
raise err

        base_url = target or settings.aiohttp.base_url

        if not base_url:
            err = ValueError(
                "`GUIDELLM__AIOHTTP__BASE_URL` environment variable or "
                "target parameter must be specified for the aiohttp backend."
            )
            logger.error("{}", err)
            raise err

        self._api_url = f"{base_url}/chat/completions"

self._timeout = aiohttp.ClientTimeout(total=timeout or settings.request_timeout)
self._model = model

super().__init__(type_="aiohttp_backend", target=base_url, model=self._model)
logger.info("aiohttp {} Backend listening on {}", self._model, base_url)

async def make_request(
self,
request: TextGenerationRequest,
) -> AsyncGenerator[GenerativeResponse, None]:
"""
Make a request to the aiohttp backend.

Sends a prompt to the LLM server and streams the response tokens.

:param request: The text generation request to submit.
:type request: TextGenerationRequest
:yield: A stream of GenerativeResponse objects.
:rtype: AsyncGenerator[GenerativeResponse, None]
"""

async with aiohttp.ClientSession(timeout=self._timeout) as session:
logger.debug("Making request to aiohttp backend with prompt: {}", request.prompt)

request_args = {}
if request.output_token_count is not None:
request_args.update(
{
"max_completion_tokens": request.output_token_count,
"stop": None,
"ignore_eos": True,
}
)
elif settings.aiohttp.max_gen_tokens and settings.aiohttp.max_gen_tokens > 0:
request_args.update(
{
"max_tokens": settings.aiohttp.max_gen_tokens,
}
)

request_args.update(self._request_args)

messages = self._build_messages(request)

payload = {
"model": self._model,
"messages": messages,
"stream": True,
**request_args,
}

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self._api_key}",
}

try:
async with session.post(url=self._api_url, json=payload, headers=headers) as response:
if response.status != 200:
error_message = await response.text()
logger.error("Request failed: {} - {}", response.status, error_message)
raise Exception(f"Failed to generate response: {error_message}")

token_count = 0
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue

chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk == "[DONE]":
# Final response
yield GenerativeResponse(
type_="final",
prompt=request.prompt,
output_token_count=token_count,
prompt_token_count=request.prompt_token_count,
)
else:
# Intermediate token response
token_count += 1
                            data = json.loads(chunk)
                            delta = data["choices"][0]["delta"]
                            # Role-only or empty deltas may carry no "content" key.
                            token = delta.get("content", "")
yield GenerativeResponse(
type_="token_iter",
add_token=token,
prompt=request.prompt,
output_token_count=token_count,
prompt_token_count=request.prompt_token_count,
)
except Exception as e:
logger.error("Error while making request: {}", e)
raise

def available_models(self) -> List[str]:
"""
Retrieve a list of available models from the server.
"""
        # Could call the server's /models endpoint (relative to the base URL)
        # if the server supports it.
        logger.warning("Fetching available models is not implemented for aiohttp backend.")
        return []

def validate_connection(self):
"""
Validate the connection to the backend server.
"""
logger.info("Connection validation is not explicitly implemented for aiohttp backend.")

    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
if request.number_images == 0:
messages = [{"role": "user", "content": request.prompt}]
else:
content = []
for image in request.images:
stream = io.BytesIO()
im_format = image.image.format or "PNG"
image.image.save(stream, format=im_format)
im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
content.append({"type": "image_url", "image_url": image_url})

content.append({"type": "text", "text": request.prompt})
messages = [{"role": "user", "content": content}]

return messages
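
As a usage sketch for review (not part of the diff): `make_request` streams `GenerativeResponse` objects until the `[DONE]` sentinel. The target URL, model name, and API key below are placeholders for any OpenAI-compatible server.

```python
import asyncio

from guidellm.backend.aiohttp import AiohttpBackend
from guidellm.core import TextGenerationRequest


async def main() -> None:
    # Placeholder target/model/key; any server exposing an OpenAI-style
    # /chat/completions streaming endpoint should work with this backend.
    backend = AiohttpBackend(
        openai_api_key="dummy-key",
        target="http://localhost:8000/v1",
        model="my-model",
    )
    request = TextGenerationRequest(prompt="Hello, world!", output_token_count=16)

    async for response in backend.make_request(request):
        if response.type_ == "token_iter":
            print(response.add_token, end="", flush=True)
        elif response.type_ == "final":
            print(f"\n[{response.output_token_count} tokens]")


asyncio.run(main())
```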
2 changes: 1 addition & 1 deletion src/guidellm/backend/base.py
@@ -15,7 +15,7 @@
__all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"]


BackendEnginePublic = Literal["openai_server"]
BackendEnginePublic = Literal["openai_server", "aiohttp_server"]
BackendEngine = Union[BackendEnginePublic, Literal["test"]]


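The `"aiohttp_server"` literal added here is the key under which `@Backend.register("aiohttp_server")` files the new class. A self-contained sketch of that registry pattern; guidellm's actual `Backend` base class lives in `base.py` and may differ in names and details:

```python
from typing import Callable, Dict, Type

# Toy registry mirroring the decorator-based registration used above.
_REGISTRY: Dict[str, Type["MiniBackend"]] = {}


class MiniBackend:
    """Illustrative stand-in for guidellm's Backend registry mechanics."""

    @classmethod
    def register(cls, backend_type: str) -> Callable:
        def decorator(subclass: Type["MiniBackend"]) -> Type["MiniBackend"]:
            _REGISTRY[backend_type] = subclass  # map literal -> class
            return subclass

        return decorator

    @classmethod
    def create(cls, backend_type: str, **kwargs) -> "MiniBackend":
        if backend_type not in _REGISTRY:
            raise ValueError(f"Unknown backend type: {backend_type}")
        return _REGISTRY[backend_type](**kwargs)


@MiniBackend.register("aiohttp_server")
class MiniAiohttpBackend(MiniBackend):
    def __init__(self, **kwargs) -> None:
        self.request_args = kwargs
```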
29 changes: 26 additions & 3 deletions src/guidellm/backend/openai.py
@@ -1,3 +1,5 @@
import base64
import io
from typing import AsyncGenerator, Dict, List, Optional

from loguru import logger
@@ -92,6 +94,9 @@ async def make_request(
{
"max_tokens": request.output_token_count,
"stop": None,
"extra_body": {
"ignore_eos": True,
}
}
)
elif settings.openai.max_gen_tokens and settings.openai.max_gen_tokens > 0:
@@ -103,11 +108,11 @@

request_args.update(self._request_args)

messages = self._build_messages(request)

stream = await self._async_client.chat.completions.create(
model=self.model,
messages=[
{"role": "user", "content": request.prompt},
],
messages=messages,
stream=True,
**request_args,
)
@@ -167,3 +172,21 @@ def validate_connection(self):
except Exception as error:
logger.error("Failed to validate OpenAI connection: {}", error)
raise error

    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
if request.number_images == 0:
messages = [{"role": "user", "content": request.prompt}]
else:
content = []
for image in request.images:
stream = io.BytesIO()
im_format = image.image.format or "PNG"
image.image.save(stream, format=im_format)
im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
content.append({"type": "image_url", "image_url": image_url})

content.append({"type": "text", "text": request.prompt})
messages = [{"role": "user", "content": content}]

return messages
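
For review, this is the shape `_build_messages` produces for a request carrying one PNG image plus a prompt, following the OpenAI vision message format (base64 payload truncated here for readability):

```python
# Sketch of the messages list returned when request.images holds one PNG.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgoAAAANS..."},
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
```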
5 changes: 5 additions & 0 deletions src/guidellm/config.py
@@ -90,6 +90,7 @@ class EmulatedDataSettings(BaseModel):
"force_new_line_punctuation": True,
}
)
    image_source: List[str] = [
        "https://www.gutenberg.org/cache/epub/1342/pg1342-images.html"
    ]


class OpenAISettings(BaseModel):
@@ -108,6 +109,9 @@ class OpenAISettings(BaseModel):
max_gen_tokens: int = 4096


class AiohttpSettings(OpenAISettings):
    """
    Settings for the aiohttp backend; inherits all fields from OpenAISettings.
    """

class ReportGenerationSettings(BaseModel):
"""
Report generation settings for the application
@@ -152,6 +156,7 @@ class Settings(BaseSettings):

# Request settings
openai: OpenAISettings = OpenAISettings()
aiohttp: AiohttpSettings = AiohttpSettings()

# Report settings
report_generation: ReportGenerationSettings = ReportGenerationSettings()
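Since `AiohttpSettings` subclasses `OpenAISettings` and mounts on `Settings` as `aiohttp`, its fields should resolve through nested environment variables. A sketch assuming pydantic-settings' double-underscore nesting delimiter, consistent with the `GUIDELLM__AIOHTTP__API_KEY` name cited in the aiohttp backend's error messages:

```python
import os

# Assumed mapping (matches the error messages in aiohttp.py):
#   GUIDELLM__AIOHTTP__API_KEY  -> settings.aiohttp.api_key
#   GUIDELLM__AIOHTTP__BASE_URL -> settings.aiohttp.base_url
os.environ["GUIDELLM__AIOHTTP__API_KEY"] = "dummy-key"
os.environ["GUIDELLM__AIOHTTP__BASE_URL"] = "http://localhost:8000/v1"

from guidellm.config import Settings

settings = Settings()
print(settings.aiohttp.api_key)   # dummy-key
print(settings.aiohttp.base_url)  # http://localhost:8000/v1
```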
26 changes: 8 additions & 18 deletions src/guidellm/core/report.py
@@ -147,19 +147,15 @@ def _create_benchmark_report_data_tokens_summary(
for benchmark in report.benchmarks_sorted:
table.add_row(
_benchmark_rate_id(benchmark),
f"{benchmark.prompt_token_distribution.mean:.2f}",
f"{benchmark.prompt_token:.2f}",
", ".join(
f"{percentile:.1f}"
for percentile in benchmark.prompt_token_distribution.percentiles(
[1, 5, 50, 95, 99]
)
for percentile in benchmark.prompt_token_percentiles
),
f"{benchmark.output_token_distribution.mean:.2f}",
f"{benchmark.output_token:.2f}",
", ".join(
f"{percentile:.1f}"
for percentile in benchmark.output_token_distribution.percentiles(
[1, 5, 50, 95, 99]
)
for percentile in benchmark.output_token_percentiles
),
)
logger.debug("Created data tokens summary table for the report.")
@@ -181,7 +177,7 @@ def _create_benchmark_report_dist_perf_summary(
"Benchmark",
"Request Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (sec)",
"Time to First Token [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)",
"Inter Token Latency [1%, 5%, 10%, 50%, 90% 95%, 99%] (ms)",
"Inter Token Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)",
title="[magenta]Performance Stats by Benchmark[/magenta]",
title_style="bold",
title_justify="left",
@@ -193,21 +189,15 @@
_benchmark_rate_id(benchmark),
", ".join(
f"{percentile:.2f}"
for percentile in benchmark.request_latency_distribution.percentiles(
[1, 5, 10, 50, 90, 95, 99]
)
for percentile in benchmark.request_latency_percentiles
),
", ".join(
f"{percentile * 1000:.1f}"
for percentile in benchmark.ttft_distribution.percentiles(
[1, 5, 10, 50, 90, 95, 99]
)
for percentile in benchmark.time_to_first_token_percentiles
),
", ".join(
f"{percentile * 1000:.1f}"
for percentile in benchmark.itl_distribution.percentiles(
[1, 5, 10, 50, 90, 95, 99]
)
for percentile in benchmark.inter_token_latency_percentiles
),
)
logger.debug("Created distribution performance summary table for the report.")
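The report now reads precomputed properties such as `benchmark.prompt_token_percentiles` instead of calling `*_distribution.percentiles([...])` inline, in line with the first commit's move to `computed_field` properties. A sketch of how such a property might be declared; the class and field names here are illustrative, not guidellm's actual models:

```python
from typing import List

from pydantic import BaseModel, computed_field


class DistributionSketch(BaseModel):
    """Illustrative stand-in for guidellm's distribution object."""

    values: List[float] = []

    def percentiles(self, points: List[float]) -> List[float]:
        # Nearest-rank percentiles; guidellm's real implementation may differ.
        if not self.values:
            return [0.0 for _ in points]
        ordered = sorted(self.values)
        return [
            ordered[min(int(p / 100 * len(ordered)), len(ordered) - 1)]
            for p in points
        ]


class BenchmarkSketch(BaseModel):
    prompt_token_distribution: DistributionSketch

    @computed_field  # serialized alongside regular fields
    @property
    def prompt_token(self) -> float:
        vals = self.prompt_token_distribution.values
        return sum(vals) / len(vals) if vals else 0.0

    @computed_field
    @property
    def prompt_token_percentiles(self) -> List[float]:
        return self.prompt_token_distribution.percentiles([1, 5, 50, 95, 99])
```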
23 changes: 22 additions & 1 deletion src/guidellm/core/request.py
@@ -1,9 +1,10 @@
import uuid
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

from guidellm.core.serializable import Serializable
from guidellm.utils import ImageDescriptor


class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
description="The unique identifier for the request.",
)
prompt: str = Field(description="The input prompt for the text generation.")
images: Optional[List[ImageDescriptor]] = Field(
default=None,
description="Input images.",
)
prompt_token_count: Optional[int] = Field(
default=None,
description="The number of tokens in the input prompt.",
@@ -29,6 +34,21 @@
description="The parameters for the text generation request.",
)

@property
def number_images(self) -> int:
if self.images is None:
return 0
else:
return len(self.images)

@property
    def image_resolution(self) -> Optional[List[Tuple[int, int]]]:
        if self.images is None:
            return None
        return [im.size for im in self.images]


def __str__(self) -> str:
prompt_short = (
self.prompt[:32] + "..."
@@ -41,4 +61,5 @@ def __str__(self) -> str:
f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
f"output_token_count={self.output_token_count}, "
f"params={self.params})"
f"image_resolution={self.image_resolution}"
)
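
A quick sketch exercising the new fields; the `ImageDescriptor` constructor signature is assumed, since the diff only shows that the backends read `descriptor.image` and this model reads `descriptor.size`:

```python
from PIL import Image

from guidellm.core import TextGenerationRequest
from guidellm.utils import ImageDescriptor

# Build a tiny in-memory image; ImageDescriptor is assumed to wrap a
# PIL image and expose its size.
pil_image = Image.new("RGB", (64, 64), color="white")
descriptor = ImageDescriptor(image=pil_image)  # constructor signature assumed

request = TextGenerationRequest(
    prompt="What is in this image?",
    images=[descriptor],
)
print(request.number_images)     # 1
print(request.image_resolution)  # [(64, 64)]
```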