Skip to content

Commit

Permalink
langgraph: add names for tasks (#3202)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Jan 24, 2025
1 parent cf7f669 commit 6926c4b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
9 changes: 8 additions & 1 deletion libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

@overload
def task(
*, retry: Optional[RetryPolicy] = None
*, name: Optional[str] = None, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, T]], Callable[P, SyncAsyncFuture[T]]]: ...


Expand All @@ -50,6 +50,7 @@ def task(
def task(
__func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
name: Optional[str] = None,
retry: Optional[RetryPolicy] = None,
) -> Union[
Callable[[Callable[P, T]], Callable[P, SyncAsyncFuture[T]]],
Expand Down Expand Up @@ -120,6 +121,12 @@ def decorator(
) -> Union[
Callable[P, concurrent.futures.Future[T]], Callable[P, asyncio.Future[T]]
]:
if name is not None:
if hasattr(func, "__func__"):
func.__func__.__name__ = name
else:
func.__name__ = name

call_func = functools.partial(call, func, retry=retry)
object.__setattr__(call_func, "_is_pregel_task", True)
return functools.update_wrapper(call_func, func)
Expand Down
30 changes: 28 additions & 2 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5570,10 +5570,10 @@ def graph(state: dict) -> dict:


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_interrupts_imperative(
def test_multiple_interrupts_functional(
request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion
):
"""Test multiple interrupts with an imperative API."""
"""Test multiple interrupts with functional API."""
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

counter = 0
Expand Down Expand Up @@ -6265,3 +6265,29 @@ async def workflow(inputs: dict, *, previous: Any):
# Test with another thread
assert await workflow.ainvoke({}, {"configurable": {"thread_id": "2"}}) == "!"
assert previous_ is None


def test_named_tasks_functional() -> None:

class Foo:
def foo(self, state: dict) -> dict:
return "foo"

f = Foo()
foo = task(f.foo, name="custom_foo")

@task(name="custom_bar")
def bar(state: dict) -> dict:
return "bar"

@entrypoint()
def workflow(inputs: dict) -> dict:
fut_foo = foo(inputs)
fut_bar = bar(fut_foo.result())
return fut_bar.result()

assert list(workflow.stream({}, stream_mode="updates")) == [
{"custom_foo": "foo"},
{"custom_bar": "bar"},
{"workflow": "bar"},
]
31 changes: 29 additions & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -6812,8 +6812,8 @@ async def graph(state: dict) -> dict:

@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_interrupts_imperative(checkpointer_name: str) -> None:
"""Test multiple interrupts with an imperative API."""
async def test_multiple_interrupts_functional(checkpointer_name: str) -> None:
"""Test multiple interrupts with functional API."""
from langgraph.func import entrypoint, task

counter = 0
Expand Down Expand Up @@ -7325,3 +7325,30 @@ async def foo(inputs, previous=None) -> Any:

assert list(await foo.ainvoke({"a": "1"}, config)) == ["a", "b"]
assert previous_return_values == [None]


@NEEDS_CONTEXTVARS
async def test_named_tasks_functional() -> None:

class Foo:
async def foo(self, state: dict) -> dict:
return "foo"

f = Foo()
foo = task(f.foo, name="custom_foo")

@task(name="custom_bar")
async def bar(state: dict) -> dict:
return "bar"

@entrypoint()
async def workflow(inputs: dict) -> dict:
foo_result = await foo(inputs)
bar_result = await bar(foo_result)
return bar_result

assert [c async for c in workflow.astream({}, stream_mode="updates")] == [
{"custom_foo": "foo"},
{"custom_bar": "bar"},
{"workflow": "bar"},
]

0 comments on commit 6926c4b

Please sign in to comment.