Skip to content

Commit

Permalink
langgraph: bring back tool content stringify (#1114)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Jul 24, 2024
1 parent 82cbe25 commit ca6aef4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
25 changes: 23 additions & 2 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from copy import copy
from typing import (
Any,
Expand Down Expand Up @@ -28,6 +29,16 @@
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


def str_output(output: Any) -> str:
if isinstance(output, str):
return output
else:
try:
return json.dumps(output)
except Exception:
return str(output)


class ToolNode(RunnableCallable):
"""A node that runs the tools called in the last AIMessage.
Expand Down Expand Up @@ -94,7 +105,12 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:

try:
input = {**call, **{"type": "tool_call"}}
return self.tools_by_name[call["name"]].invoke(input, config)
tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke(
input, config
)
# TODO: handle this properly in core
tool_message.content = str_output(tool_message.content)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
raise e
Expand All @@ -106,7 +122,12 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage
return invalid_tool_message
try:
input = {**call, **{"type": "tool_call"}}
return await self.tools_by_name[call["name"]].ainvoke(input, config)
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
input, config
)
# TODO: handle this properly in core
tool_message.content = str_output(tool_message.content)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
raise e
Expand Down
32 changes: 32 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,13 @@ async def tool2(some_val: int, some_other_val: str) -> str:
raise ValueError("Test error")
return f"tool2: {some_val} - {some_other_val}"

async def tool3(some_val: int, some_other_val: str) -> str:
"""Tool 3 docstring."""
return [
{"key_1": some_val, "key_2": "foo"},
{"key_1": some_other_val, "key_2": "baz"},
]

result = ToolNode([tool1]).invoke(
{
"messages": [
Expand Down Expand Up @@ -377,6 +384,31 @@ async def tool2(some_val: int, some_other_val: str) -> str:
)
assert tool_message.tool_call_id == "some 0"

# list of dicts tool content
result3 = await ToolNode([tool3]).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool3",
"args": {"some_val": 2, "some_other_val": "bar"},
"id": "some 0",
}
],
)
]
}
)
tool_message: ToolMessage = result3["messages"][-1]
assert tool_message.type == "tool"
assert (
tool_message.content
== '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]'
)
assert tool_message.tool_call_id == "some 0"


def my_function(some_val: int, some_other_val: str) -> str:
return f"{some_val} - {some_other_val}"
Expand Down

0 comments on commit ca6aef4

Please sign in to comment.