From 6926c4bcc050467c11235b744c4b37273373cd06 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Fri, 24 Jan 2025 16:42:54 -0500 Subject: [PATCH] langgraph: add names for tasks (#3202) --- libs/langgraph/langgraph/func/__init__.py | 9 ++++++- libs/langgraph/tests/test_pregel.py | 30 ++++++++++++++++++++-- libs/langgraph/tests/test_pregel_async.py | 31 +++++++++++++++++++++-- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index d0f429776..9daec2333 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -37,7 +37,7 @@ @overload def task( - *, retry: Optional[RetryPolicy] = None + *, name: Optional[str] = None, retry: Optional[RetryPolicy] = None ) -> Callable[[Callable[P, T]], Callable[P, SyncAsyncFuture[T]]]: ... @@ -50,6 +50,7 @@ def task( def task( __func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None, *, + name: Optional[str] = None, retry: Optional[RetryPolicy] = None, ) -> Union[ Callable[[Callable[P, T]], Callable[P, SyncAsyncFuture[T]]], @@ -120,6 +121,12 @@ def decorator( ) -> Union[ Callable[P, concurrent.futures.Future[T]], Callable[P, asyncio.Future[T]] ]: + if name is not None: + if hasattr(func, "__func__"): + func.__func__.__name__ = name + else: + func.__name__ = name + call_func = functools.partial(call, func, retry=retry) object.__setattr__(call_func, "_is_pregel_task", True) return functools.update_wrapper(call_func, func) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index e7d2a8164..bd58f0b55 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5570,10 +5570,10 @@ def graph(state: dict) -> dict: @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) -def test_multiple_interrupts_imperative( +def test_multiple_interrupts_functional( request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion ): - """Test multiple interrupts with an imperative API.""" + """Test multiple interrupts with functional API.""" checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") counter = 0 @@ -6265,3 +6265,29 @@ async def workflow(inputs: dict, *, previous: Any): # Test with another thread assert await workflow.ainvoke({}, {"configurable": {"thread_id": "2"}}) == "!" assert previous_ is None + + +def test_named_tasks_functional() -> None: + + class Foo: + def foo(self, state: dict) -> dict: + return "foo" + + f = Foo() + foo = task(f.foo, name="custom_foo") + + @task(name="custom_bar") + def bar(state: dict) -> dict: + return "bar" + + @entrypoint() + def workflow(inputs: dict) -> dict: + fut_foo = foo(inputs) + fut_bar = bar(fut_foo.result()) + return fut_bar.result() + + assert list(workflow.stream({}, stream_mode="updates")) == [ + {"custom_foo": "foo"}, + {"custom_bar": "bar"}, + {"workflow": "bar"}, + ] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 35d6143dd..0462454f2 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -6812,8 +6812,8 @@ async def graph(state: dict) -> dict: @NEEDS_CONTEXTVARS @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) -async def test_multiple_interrupts_imperative(checkpointer_name: str) -> None: - """Test multiple interrupts with an imperative API.""" +async def test_multiple_interrupts_functional(checkpointer_name: str) -> None: + """Test multiple interrupts with functional API.""" from langgraph.func import entrypoint, task counter = 0 @@ -7325,3 +7325,30 @@ async def foo(inputs, previous=None) -> Any: assert list(await foo.ainvoke({"a": "1"}, config)) == ["a", "b"] assert previous_return_values == [None] + + +@NEEDS_CONTEXTVARS +async def test_named_tasks_functional() -> None: + + class Foo: + async def foo(self, state: dict) -> dict: + return "foo" + + f = Foo() + foo = task(f.foo, name="custom_foo") + + @task(name="custom_bar") + async def bar(state: dict) -> dict: + return "bar" + + @entrypoint() + async def workflow(inputs: dict) -> dict: + foo_result = await foo(inputs) + bar_result = await bar(foo_result) + return bar_result + + assert [c async for c in workflow.astream({}, stream_mode="updates")] == [ + {"custom_foo": "foo"}, + {"custom_bar": "bar"}, + {"workflow": "bar"}, + ]