vertexai: add retry tracing with metadata in ChatAnthropicVertex (#728)
lgesuellip authored Feb 12, 2025
1 parent ef5932d commit b341070
Showing 3 changed files with 245 additions and 7 deletions.
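From the caller's side, the change surfaces as a new max_retries field on ChatAnthropicVertex. A minimal usage sketch (the model name, project, and region below are placeholders, not values taken from this commit):

from langchain_google_vertexai.model_garden import ChatAnthropicVertex

# Placeholder project, region, and model; substitute your own values.
llm = ChatAnthropicVertex(
    model_name="claude-3-5-sonnet-v2@20241022",
    project="my-gcp-project",
    location="us-east5",
    max_retries=3,  # new field from this commit; transient API errors are retried
)
print(llm.invoke("Hello").content)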
68 changes: 62 additions & 6 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -60,6 +60,7 @@
convert_to_anthropic_tool,
)
from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
from langchain_google_vertexai.utils import create_base_retry_decorator


class CacheUsageMetadata(UsageMetadata):
@@ -69,6 +70,31 @@ class CacheUsageMetadata(UsageMetadata):
"""The number of input tokens read from the cache."""


def _create_retry_decorator(
llm: ChatAnthropicVertex,
*,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Creates a retry decorator for Anthropic Vertex LLMs with proper tracing."""
from anthropic import ( # type: ignore[unused-ignore, import-not-found]
APIError,
APITimeoutError,
RateLimitError,
)

errors = [
APIError,
APITimeoutError,
RateLimitError,
]

return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""

@@ -143,6 +169,9 @@ class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
access_token: Optional[str] = None
stream_usage: bool = True # Whether to include usage metadata in streaming output
credentials: Optional[Credentials] = None
max_retries: int = Field(
default=3, description="Number of retries for error handling."
)

model_config = ConfigDict(
populate_by_name=True,
@@ -164,17 +193,18 @@ def validate_environment(self) -> Self:

project_id: str = self.project

# Always disable the Anthropic client's built-in retries; they are handled by our retry decorator instead
self.client = AnthropicVertex(
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
max_retries=0,
access_token=self.access_token,
credentials=self.credentials,
)
self.async_client = AsyncAnthropicVertex(
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
max_retries=0,
access_token=self.access_token,
credentials=self.credentials,
)
@@ -250,13 +280,20 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Run the LLM on the given prompt and input."""
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
data = self.client.messages.create(**params)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

@retry_decorator
def _completion_with_retry_inner(**params: Any) -> Any:
return self.client.messages.create(**params)

data = _completion_with_retry_inner(**params)
return self._format_output(data, **kwargs)

async def _agenerate(
@@ -266,13 +303,20 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Run the LLM on the given prompt and input."""
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
data = await self.async_client.messages.create(**params)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

@retry_decorator
async def _acompletion_with_retry_inner(**params: Any) -> Any:
return await self.async_client.messages.create(**params)

data = await _acompletion_with_retry_inner(**params)
return self._format_output(data, **kwargs)

@property
@@ -292,7 +336,13 @@ def _stream(
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
stream = self.client.messages.create(**params, stream=True)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

@retry_decorator
def _stream_with_retry(**params: Any) -> Any:
return self.client.messages.create(**params, stream=True)

stream = _stream_with_retry(**params)
coerce_content_to_string = not _tools_in_params(params)
for event in stream:
msg = _make_message_chunk_from_anthropic_event(
@@ -318,7 +368,13 @@ async def _astream(
if stream_usage is None:
stream_usage = self.stream_usage
params = self._format_params(messages=messages, stop=stop, **kwargs)
stream = await self.async_client.messages.create(**params, stream=True)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

@retry_decorator
async def _astream_with_retry(**params: Any) -> Any:
return await self.async_client.messages.create(**params, stream=True)

stream = await _astream_with_retry(**params)
coerce_content_to_string = not _tools_in_params(params)
async for event in stream:
msg = _make_message_chunk_from_anthropic_event(
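Because the retry path now reports each attempt through the run manager, retries can also be observed from a callback handler. A hedged sketch, assuming the standard on_retry hook on BaseCallbackHandler; the handler and the way it is attached are illustrative, not part of this diff:

from langchain_core.callbacks import BaseCallbackHandler
from tenacity import RetryCallState

class RetryLogger(BaseCallbackHandler):
    """Illustrative handler that logs every retry reported by the decorator."""

    def on_retry(self, retry_state: RetryCallState, **kwargs) -> None:
        print(
            f"retry attempt {retry_state.attempt_number}, "
            f"slept {retry_state.idle_for:.1f}s"
        )

# Hypothetical wiring: pass the handler with the request so it joins the run.
# llm.invoke("Hello", config={"callbacks": [RetryLogger()]})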
84 changes: 83 additions & 1 deletion libs/vertexai/langchain_google_vertexai/utils.py
@@ -1,7 +1,22 @@
import asyncio
import logging
from datetime import datetime, timedelta
from typing import List, Optional
from typing import Any, Callable, List, Optional, Union

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.messages import BaseMessage
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from vertexai.preview import caching # type: ignore

from langchain_google_vertexai._image_utils import ImageBytesLoader
@@ -76,3 +91,70 @@ def create_context_cache(
)

return cached_content.name


def create_base_retry_decorator(
error_types: list[type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided a list of error types.
Args:
error_types: List of error types to retry on.
max_retries: Number of retries. Default is 1.
run_manager: Callback manager for the run. Default is None.
Returns:
A retry decorator.
"""
logger = logging.getLogger(__name__)
_logging = before_sleep_log(logger, logging.WARNING)

def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
retry_d: dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
if retry_state.outcome is None:
retry_d["outcome"] = "N/A"
elif retry_state.outcome.failed:
retry_d["outcome"] = "failed"
exception = retry_state.outcome.exception()
retry_d["exception"] = str(exception)
retry_d["exception_type"] = exception.__class__.__name__
else:
retry_d["outcome"] = "success"
retry_d["result"] = str(retry_state.outcome.result())
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
logger.error(f"Error in on_retry: {e}")
else:
run_manager.metadata.update({"retry_state": retry_d})
run_manager.on_retry(retry_state)

min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: retry_base = retry_if_exception_type(error_types[0])
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(error)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=retry_instance,
before_sleep=_before_sleep,
)
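
Taken on its own, create_base_retry_decorator wraps any callable, which is what the new unit tests exercise. A small sketch, using ValueError as a stand-in for a transient API error:

from langchain_google_vertexai.utils import create_base_retry_decorator

calls = {"n": 0}

@create_base_retry_decorator(error_types=[ValueError], max_retries=3)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("transient failure")  # retried with exponential backoff
    return "ok"

print(flaky())  # succeeds on the third attempt, after two backoff waits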
100 changes: 100 additions & 0 deletions libs/vertexai/tests/unit_tests/test_model_garden_retry.py
@@ -0,0 +1,100 @@
from unittest.mock import MagicMock

import pytest
from anthropic import APIError

from langchain_google_vertexai.model_garden import (
_create_retry_decorator,
)


def create_api_error():
"""Helper function to create an APIError with required arguments."""
mock_request = MagicMock()
mock_request.method = "POST"
mock_request.url = "test-url"
mock_request.headers = {}
mock_request.body = None
return APIError(
message="Test error",
request=mock_request,
body={"error": {"message": "Test error"}},
)


def test_retry_on_errors():
"""Test that the retry decorator works with sync functions."""
mock_llm = MagicMock()
mock_llm.max_retries = 2
mock_function = MagicMock(side_effect=[create_api_error(), "success"])

decorator = _create_retry_decorator(mock_llm)
wrapped_func = decorator(mock_function)

result = wrapped_func()
assert result == "success"
assert mock_function.call_count == 2


def test_max_retries_exceeded():
"""Test that the retry decorator fails after max retries."""
mock_llm = MagicMock()
mock_llm.max_retries = 2
mock_function = MagicMock(side_effect=[create_api_error(), create_api_error()])

decorator = _create_retry_decorator(mock_llm)
wrapped_func = decorator(mock_function)

with pytest.raises(APIError):
wrapped_func()
assert mock_function.call_count == 2


@pytest.mark.asyncio
async def test_async_retry_on_errors():
"""Test that the retry decorator works with async functions."""
mock_llm = MagicMock()
mock_llm.max_retries = 2

class AsyncMock:
def __init__(self):
self.call_count = 0

async def __call__(self):
self.call_count += 1
if self.call_count == 1:
raise create_api_error()
return "success"

mock_async = AsyncMock()

decorator = _create_retry_decorator(mock_llm)
wrapped_func = decorator(mock_async)

result = await wrapped_func()
assert result == "success"
assert mock_async.call_count == 2


@pytest.mark.asyncio
async def test_async_max_retries_exceeded():
"""Test that the async retry decorator fails after max retries."""
mock_llm = MagicMock()
mock_llm.max_retries = 2

class AsyncMock:
def __init__(self):
self.call_count = 0

async def __call__(self):
self.call_count += 1
raise create_api_error()

mock_async = AsyncMock()

decorator = _create_retry_decorator(mock_llm)
wrapped_func = decorator(mock_async)

with pytest.raises(APIError):
await wrapped_func()
assert mock_async.call_count == 2
