
Add support for images in requests
anmarques committed Nov 4, 2024
1 parent cb1f244 commit 24e6527
Showing 3 changed files with 51 additions and 6 deletions.
25 changes: 22 additions & 3 deletions src/guidellm/backend/openai.py
@@ -1,4 +1,5 @@
from typing import AsyncGenerator, Dict, List, Optional
import base64
import io

from loguru import logger
from openai import AsyncOpenAI, OpenAI
@@ -103,11 +104,11 @@ async def make_request(

        request_args.update(self._request_args)

        messages = self._build_messages(request)

        stream = await self._async_client.chat.completions.create(
            model=self.model,
-            messages=[
-                {"role": "user", "content": request.prompt},
-            ],
+            messages=messages,
            stream=True,
            **request_args,
        )
@@ -167,3 +168,21 @@ def validate_connection(self):
        except Exception as error:
            logger.error("Failed to validate OpenAI connection: {}", error)
            raise error

    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
        if request.number_images == 0:
            messages = [{"role": "user", "content": request.prompt}]
        else:
            content = []
            for image in request.images:
                # Re-encode each image as a base64 data URL, keeping its
                # native format and falling back to PNG when unknown.
                stream = io.BytesIO()
                im_format = image.image.format or "PNG"
                image.image.save(stream, format=im_format)
                im_b64 = base64.b64encode(stream.getvalue()).decode("ascii")
                image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
                content.append({"type": "image_url", "image_url": image_url})

            content.append({"type": "text", "text": request.prompt})
            messages = [{"role": "user", "content": content}]

        return messages
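
Note: the content list built above follows the OpenAI chat completions vision format, with each image inlined as a base64 data URL. As a minimal standalone sketch of the resulting payload for a one-image request (assuming Pillow provides the image object, as the ImageDescriptor usage above implies):

import base64
import io

from PIL import Image

# Encode a small throwaway PNG the same way _build_messages does.
image = Image.new("RGB", (8, 8))
stream = io.BytesIO()
image.save(stream, format="PNG")
im_b64 = base64.b64encode(stream.getvalue()).decode("ascii")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{im_b64}"},
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]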
15 changes: 14 additions & 1 deletion src/guidellm/core/request.py
@@ -1,9 +1,10 @@
import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List

from pydantic import Field

from guidellm.core.serializable import Serializable
from guidellm.utils import ImageDescriptor


class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
description="The unique identifier for the request.",
)
prompt: str = Field(description="The input prompt for the text generation.")
images: Optional[List[ImageDescriptor]] = Field(
default=None,
description="Input images.",
)
prompt_token_count: Optional[int] = Field(
default=None,
description="The number of tokens in the input prompt.",
@@ -29,6 +34,13 @@
description="The parameters for the text generation request.",
)

@property
def number_images(self) -> int:
if self.images is None:
return 0
else:
return len(self.images)

def __str__(self) -> str:
prompt_short = (
self.prompt[:32] + "..."
@@ -41,4 +53,5 @@ def __str__(self) -> str:
f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
f"output_token_count={self.output_token_count}, "
f"params={self.params})"
f"images={self.number_images}"
)
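
For illustration, a minimal sketch of building a request with one image. The ImageDescriptor constructor is not shown in this diff (the keyword below is hypothetical); the backend code only shows that descriptors expose a PIL image via the .image attribute:

from PIL import Image

from guidellm.core.request import TextGenerationRequest
from guidellm.utils import ImageDescriptor

# ImageDescriptor(image=...) is an assumed signature; only the .image
# attribute is visible in this diff.
request = TextGenerationRequest(
    prompt="Describe this image.",
    images=[ImageDescriptor(image=Image.new("RGB", (64, 64)))],
)
assert request.number_images == 1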
17 changes: 15 additions & 2 deletions src/guidellm/request/emulated.py
@@ -11,7 +11,7 @@
from guidellm.config import settings
from guidellm.core.request import TextGenerationRequest
from guidellm.request.base import GenerationMode, RequestGenerator
-from guidellm.utils import clean_text, filter_text, load_text, split_text
+from guidellm.utils import clean_text, filter_text, load_text, split_text, load_images

__all__ = ["EmulatedConfig", "EmulatedRequestGenerator", "EndlessTokens"]

@@ -30,6 +30,7 @@ class EmulatedConfig:
        generated_tokens_variance (Optional[int]): Variance for generated tokens.
        generated_tokens_min (Optional[int]): Minimum number of generated tokens.
        generated_tokens_max (Optional[int]): Maximum number of generated tokens.
        images (int): Number of input images per request.
    """

    @staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
"""
if not config:
logger.debug("Creating default configuration")
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256)
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256, images=0)

if isinstance(config, dict):
logger.debug("Loading configuration from dict: {}", config)
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
    generated_tokens_min: Optional[int] = None
    generated_tokens_max: Optional[int] = None

    images: int = 0

    @property
    def prompt_tokens_range(self) -> Tuple[int, int]:
        """
@@ -327,6 +330,8 @@ def __init__(
            settings.emulated_data.filter_start,
            settings.emulated_data.filter_end,
        )
        if self._config.images > 0:
            self._images = load_images(settings.emulated_data.image_source)
        self._rng = np.random.default_rng(random_seed)

        # NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
logger.debug("Creating new text generation request")
target_prompt_token_count = self._config.sample_prompt_tokens(self._rng)
prompt = self.sample_prompt(target_prompt_token_count)
images = self.sample_images()
prompt_token_count = len(self.tokenizer.tokenize(prompt))
output_token_count = self._config.sample_output_tokens(self._rng)
logger.debug("Generated prompt: {}", prompt)
@@ -363,6 +369,7 @@
            prompt=prompt,
            prompt_token_count=prompt_token_count,
            output_token_count=output_token_count,
            images=images,
        )

    def sample_prompt(self, tokens: int) -> str:
@@ -395,3 +402,9 @@ def sample_prompt(self, tokens: int) -> str:
                right = mid

        return self._tokens.create_text(start_line_index, left)


    def sample_images(self):
        # self._images is only loaded when the config requests images.
        if self._config.images == 0:
            return None
        image_indices = self._rng.choice(
            len(self._images), size=self._config.images, replace=False
        )
        return [self._images[i] for i in image_indices]
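
Usage sketch: the new images field rides along with the existing emulated-config keys, so a multimodal run can be configured from a dict as create_config accepts above (values below are illustrative):

from guidellm.request.emulated import EmulatedConfig

# Each generated request carries one image sampled without replacement
# from the pool loaded via settings.emulated_data.image_source.
config = EmulatedConfig.create_config(
    {"prompt_tokens": 512, "generated_tokens": 128, "images": 1}
)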
