diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 9daec2333..8ac8f6c7b 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -123,8 +123,10 @@ def decorator( ]: if name is not None: if hasattr(func, "__func__"): + # handle class methods func.__func__.__name__ = name else: + # handle regular functions / partials / callable classes, etc. func.__name__ = name call_func = functools.partial(call, func, retry=retry) diff --git a/libs/langgraph/langgraph/pregel/call.py b/libs/langgraph/langgraph/pregel/call.py index 1ddad1965..61a451335 100644 --- a/libs/langgraph/langgraph/pregel/call.py +++ b/libs/langgraph/langgraph/pregel/call.py @@ -168,12 +168,21 @@ def get_runnable_for_task(func: Callable[..., Any]) -> RunnableSeq: if key in CACHE: return CACHE[key] else: + if hasattr(func, "__name__"): + name = func.__name__ + elif hasattr(func, "func"): + name = func.func.__name__ + elif hasattr(func, "__class__"): + name = func.__class__.__name__ + else: + name = str(func) + if is_async_callable(func): run = RunnableCallable( None, func, explode_args=True, - name=func.__name__, + name=name, trace=False, recurse=False, ) @@ -182,14 +191,14 @@ def get_runnable_for_task(func: Callable[..., Any]) -> RunnableSeq: func, functools.wraps(func)(functools.partial(run_in_executor, None, func)), explode_args=True, - name=func.__name__, + name=name, trace=False, recurse=False, ) seq = RunnableSeq( run, ChannelWrite([ChannelWriteEntry(RETURN)], tags=[TAG_HIDDEN]), - name=func.__name__, + name=name, trace_inputs=functools.partial( _explode_args_trace_inputs, inspect.signature(func) ), diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index bd58f0b55..53a67ddc6 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1,4 +1,5 @@ import enum +import functools import json import logging import operator @@ -6268,26 +6269,48 @@ async def workflow(inputs: dict, *, previous: Any): def test_named_tasks_functional() -> None: - class Foo: - def foo(self, state: dict) -> dict: - return "foo" + def foo(self, value: str) -> dict: + return value + "foo" f = Foo() + + # class method task foo = task(f.foo, name="custom_foo") + # regular function task @task(name="custom_bar") - def bar(state: dict) -> dict: - return "bar" + def bar(value: str) -> dict: + return value + "|bar" + + def baz(update: str, value: str) -> dict: + return value + f"|{update}" + + # partial function task (unnamed) + baz_task = task(functools.partial(baz, "baz")) + # partial function task (named_) + custom_baz_task = task(functools.partial(baz, "custom_baz"), name="custom_baz") + + class Qux: + def __call__(self, value: str) -> dict: + return value + "|qux" + + qux_task = task(Qux(), name="qux") @entrypoint() def workflow(inputs: dict) -> dict: fut_foo = foo(inputs) fut_bar = bar(fut_foo.result()) - return fut_bar.result() + fut_baz = baz_task(fut_bar.result()) + fut_custom_baz = custom_baz_task(fut_baz.result()) + fut_qux = qux_task(fut_custom_baz.result()) + return fut_qux.result() - assert list(workflow.stream({}, stream_mode="updates")) == [ + assert list(workflow.stream("", stream_mode="updates")) == [ {"custom_foo": "foo"}, - {"custom_bar": "bar"}, - {"workflow": "bar"}, + {"custom_bar": "foo|bar"}, + {"baz": "foo|bar|baz"}, + {"custom_baz": "foo|bar|baz|custom_baz"}, + {"qux": "foo|bar|baz|custom_baz|qux"}, + {"workflow": "foo|bar|baz|custom_baz|qux"}, ] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 0462454f2..192d31c11 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1,4 +1,5 @@ import asyncio +import functools import logging import operator import random @@ -7329,26 +7330,48 @@ async def foo(inputs, previous=None) -> Any: @NEEDS_CONTEXTVARS async def test_named_tasks_functional() -> None: - class Foo: - async def foo(self, state: dict) -> dict: - return "foo" + async def foo(self, value: str) -> dict: + return value + "foo" f = Foo() + + # class method task foo = task(f.foo, name="custom_foo") + # regular function task @task(name="custom_bar") - async def bar(state: dict) -> dict: - return "bar" + async def bar(value: str) -> dict: + return value + "|bar" + + async def baz(update: str, value: str) -> dict: + return value + f"|{update}" + + # partial function task (unnamed) + baz_task = task(functools.partial(baz, "baz")) + # partial function task (named_) + custom_baz_task = task(functools.partial(baz, "custom_baz"), name="custom_baz") + + class Qux: + def __call__(self, value: str) -> dict: + return value + "|qux" + + qux_task = task(Qux(), name="qux") @entrypoint() async def workflow(inputs: dict) -> dict: foo_result = await foo(inputs) bar_result = await bar(foo_result) - return bar_result + baz_result = await baz_task(bar_result) + custom_baz_result = await custom_baz_task(baz_result) + qux_result = await qux_task(custom_baz_result) + return qux_result - assert [c async for c in workflow.astream({}, stream_mode="updates")] == [ + assert [c async for c in workflow.astream("", stream_mode="updates")] == [ {"custom_foo": "foo"}, - {"custom_bar": "bar"}, - {"workflow": "bar"}, + {"custom_bar": "foo|bar"}, + {"baz": "foo|bar|baz"}, + {"custom_baz": "foo|bar|baz|custom_baz"}, + {"qux": "foo|bar|baz|custom_baz|qux"}, + {"workflow": "foo|bar|baz|custom_baz|qux"}, ]