Ensure a single instance of ImageBytesLoader throughout chat_models (#…
SauravP97 authored Feb 7, 2025
1 parent a0027c8 commit ef5932d
Showing 4 changed files with 39 additions and 11 deletions.
17 changes: 13 additions & 4 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations  # noqa
 import ast
+from functools import cached_property
 import json
 import logging
 from dataclasses import dataclass, field
@@ -198,7 +199,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
 
 def _parse_chat_history_gemini(
     history: List[BaseMessage],
-    project: Optional[str] = None,
+    imageBytesLoader: ImageBytesLoader,
     convert_system_message_to_human: Optional[bool] = False,
 ) -> tuple[Content | None, list[Content]]:
     def _convert_to_prompt(part: Union[str, Dict]) -> Optional[Part]:
@@ -218,7 +219,7 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Optional[Part]:
             return None
         if part["type"] == "image_url":
             path = part["image_url"]["url"]
-            return ImageBytesLoader(project=project).load_gapic_part(path)
+            return imageBytesLoader.load_gapic_part(path)
 
         # Handle media type like LangChain.js
         # https://github.com/langchain-ai/langchainjs/blob/e536593e2585f1dd7b0afc187de4d07cb40689ba/libs/langchain-google-common/src/utils/gemini.ts#L93-L106
@@ -1107,6 +1108,10 @@ def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "chat_models", "vertexai"]
 
+    @cached_property
+    def _image_bytes_loader_client(self):
+        return ImageBytesLoader(project=self.project)
+
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that the python package exists in environment."""
@@ -1324,7 +1329,9 @@ def _prepare_request_gemini(
         logprobs: Optional[Union[int, bool]] = None,
         **kwargs,
     ) -> GenerateContentRequest:
-        system_instruction, contents = _parse_chat_history_gemini(messages)
+        system_instruction, contents = _parse_chat_history_gemini(
+            messages, self._image_bytes_loader_client
+        )
         formatted_tools = self._tools_gemini(tools=tools, functions=functions)
         if tool_config:
             tool_config = self._tool_config_gemini(tool_config=tool_config)
@@ -1445,7 +1452,9 @@ def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
         if self._is_gemini_model:
             # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#counttokensrequest
-            _, contents = _parse_chat_history_gemini([HumanMessage(content=text)])
+            _, contents = _parse_chat_history_gemini(
+                [HumanMessage(content=text)], self._image_bytes_loader_client
+            )
             response = self.prediction_client.count_tokens(
                 {
                     "endpoint": self.full_model_name,
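Note: the `functools.cached_property` added above is what enforces the "single instance" in the commit title. The first access to `_image_bytes_loader_client` constructs the loader and memoizes it on the model instance, where previously `_convert_to_prompt` built a fresh `ImageBytesLoader` for every `image_url` part. A minimal sketch of the mechanics, with `_DemoLoader` and `_DemoModel` as hypothetical stand-ins for `ImageBytesLoader` and `ChatVertexAI`:

from functools import cached_property


class _DemoLoader:
    """Hypothetical stand-in for ImageBytesLoader."""

    def __init__(self, project=None):
        print(f"building loader (project={project})")
        self.project = project


class _DemoModel:
    """Hypothetical stand-in for ChatVertexAI."""

    def __init__(self, project=None):
        self.project = project

    @cached_property
    def _image_bytes_loader_client(self):
        # Constructed once on first access, then cached on the instance.
        return _DemoLoader(project=self.project)


model = _DemoModel(project="demo-project")
a = model._image_bytes_loader_client  # prints: building loader (project=demo-project)
b = model._image_bytes_loader_client  # no output: the cached object is returned
assert a is b
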
5 changes: 4 additions & 1 deletion libs/vertexai/langchain_google_vertexai/utils.py
@@ -4,6 +4,7 @@
 from langchain_core.messages import BaseMessage
 from vertexai.preview import caching  # type: ignore
 
+from langchain_google_vertexai._image_utils import ImageBytesLoader
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history_gemini,
@@ -54,7 +55,9 @@ def create_context_cache(
         error_msg = f"Model {model.full_model_name} doesn't support context catching"
         raise ValueError(error_msg)
 
-    system_instruction, contents = _parse_chat_history_gemini(messages, model.project)
+    system_instruction, contents = _parse_chat_history_gemini(
+        messages, ImageBytesLoader(project=model.project)
+    )
 
     if tool_config:
         tool_config = _format_tool_config(tool_config)
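With `imageBytesLoader` now a required parameter of `_parse_chat_history_gemini`, callers outside `ChatVertexAI` (such as `create_context_cache` above) construct the loader once and pass it in. A short usage sketch, assuming a working langchain-google-vertexai install; the project id is illustrative:

from langchain_core.messages import HumanMessage

from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai.chat_models import _parse_chat_history_gemini

# One loader serves every image_url part in the parsed history.
loader = ImageBytesLoader(project="my-gcp-project")  # illustrative project id
system_instruction, contents = _parse_chat_history_gemini(
    [HumanMessage(content="Describe this image.")], loader
)
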
6 changes: 5 additions & 1 deletion libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -39,6 +39,7 @@
     HarmCategory,
     create_context_cache,
 )
+from langchain_google_vertexai._image_utils import ImageBytesLoader
 from langchain_google_vertexai.chat_models import _parse_chat_history_gemini
 from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME
 
@@ -341,7 +342,10 @@ def test_parse_history_gemini_multimodal_FC():
         Part(text=instruction),
     ]
     expected = [Content(role="user", parts=parts)]
-    _, response = _parse_chat_history_gemini(history=history)
+    imageBytesLoader = ImageBytesLoader()
+    _, response = _parse_chat_history_gemini(
+        history=history, imageBytesLoader=imageBytesLoader
+    )
     assert expected == response
 
 
22 changes: 17 additions & 5 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -37,6 +37,7 @@
     InputOutputTextPair,
 )
 
+from langchain_google_vertexai._image_utils import ImageBytesLoader
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
@@ -238,7 +239,10 @@ def test_parse_history_gemini() -> None:
     message2 = AIMessage(content=text_answer1)
     message3 = HumanMessage(content=text_question2)
     messages = [system_message, message1, message2, message3]
-    system_instructions, history = _parse_chat_history_gemini(messages)
+    image_bytes_loader = ImageBytesLoader()
+    system_instructions, history = _parse_chat_history_gemini(
+        messages, image_bytes_loader
+    )
     assert len(history) == 3
     assert history[0].role == "user"
     assert history[0].parts[0].text == text_question1
@@ -256,8 +260,9 @@ def test_parse_history_gemini_converted_message() -> None:
     message2 = AIMessage(content=text_answer1)
     message3 = HumanMessage(content=text_question2)
    messages = [system_message, message1, message2, message3]
+    image_bytes_loader = ImageBytesLoader()
     _, history = _parse_chat_history_gemini(
-        messages, convert_system_message_to_human=True
+        messages, image_bytes_loader, convert_system_message_to_human=True
     )
     assert len(history) == 3
     assert history[0].role == "user"
@@ -323,7 +328,10 @@ def test_parse_history_gemini_function() -> None:
         message6,
         message7,
     ]
-    system_instructions, history = _parse_chat_history_gemini(messages)
+    image_bytes_loader = ImageBytesLoader()
+    system_instructions, history = _parse_chat_history_gemini(
+        messages, image_bytes_loader
+    )
     assert len(history) == 6
     assert system_instructions and system_instructions.parts[0].text == system_input
     assert history[0].role == "user"
@@ -529,7 +537,10 @@ def test_parse_history_gemini_function() -> None:
 def test_parse_history_gemini_multi(
     source_history, expected_sm, expected_history
 ) -> None:
-    sm, result_history = _parse_chat_history_gemini(history=source_history)
+    image_bytes_loader = ImageBytesLoader()
+    sm, result_history = _parse_chat_history_gemini(
+        history=source_history, imageBytesLoader=image_bytes_loader
+    )
 
     for result, expected in zip(result_history, expected_history):
         assert result == expected
@@ -1000,7 +1011,8 @@ def test_multiple_fc() -> None:
             content='{"condition": "rainy", "temp_c": 25.2}',
         ),
     ]
-    _, history = _parse_chat_history_gemini(raw_history)
+    image_bytes_loader = ImageBytesLoader()
+    _, history = _parse_chat_history_gemini(raw_history, image_bytes_loader)
     expected = [
         Content(
             parts=[Part(text=prompt)],