Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node caching #2917

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,40 @@ def put(
)
return next_config

def get_writes(
    self,
    task_id: str,
) -> Sequence[tuple[str, Any]]:
    """Retrieve cached writes for a specific task ID.

    Args:
        task_id (str): The identifier for the task.

    Returns:
        Sequence[tuple[str, Any]]: A list of (channel, value) pairs
        representing the cached writes for the task.
    """
    # Fetch the raw rows for this task from the checkpoint_writes table.
    with self._cursor() as cur:
        cur.execute(self.GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL, (task_id,))
        fetched = cur.fetchall()

    # _load_writes expects byte-string keys alongside the raw blob payload.
    raw_writes = []
    for record in fetched:
        raw_writes.append(
            (
                record["task_id"].encode(),
                record["channel"].encode(),
                record["type"].encode(),
                record["blob"],
            )
        )

    # _load_writes yields (task_id, channel, value); callers only need
    # the (channel, value) pairs, so the task_id element is dropped.
    return [(ch, val) for _, ch, val in self._load_writes(raw_writes)]

def put_writes(
self,
config: RunnableConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@
blob = EXCLUDED.blob;
"""

# Fetch all cached writes for a task. ORDER BY idx preserves the
# insertion order recorded by INSERT_CHECKPOINT_WRITES_SQL (which writes
# an explicit idx column); without it, row order from the database is
# unspecified and replayed writes could be applied out of order.
GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL = """
    SELECT task_id, channel, type, blob
    FROM checkpoint_writes
    WHERE task_id = %s
    ORDER BY idx
"""

INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
Expand All @@ -143,6 +149,7 @@ class BasePostgresSaver(BaseCheckpointSaver[str]):
UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL
UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL = GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL

jsonplus_serde = JsonPlusSerializer()
supports_pipeline: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
blob = EXCLUDED.blob;
"""

# Fetch all cached writes for a task. ORDER BY idx preserves the
# insertion order recorded by INSERT_CHECKPOINT_WRITES_SQL (which writes
# an explicit idx column); without it, row order from the database is
# unspecified and replayed writes could be applied out of order.
GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL = """
    SELECT task_id, channel, type, blob
    FROM checkpoint_writes
    WHERE task_id = %s
    ORDER BY idx
"""

INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
Expand Down Expand Up @@ -175,6 +181,7 @@ class ShallowPostgresSaver(BasePostgresSaver):
UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL
UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL = GET_CHECKPOINT_WRITES_BY_TASK_ID_SQL

lock: threading.Lock

Expand Down
16 changes: 16 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,22 @@ def put_writes(
"""
raise NotImplementedError

def get_writes(
    self,
    task_id: str,
) -> Sequence[tuple[str, Any]]:
    """Retrieve cached writes for a specific task ID.

    Concrete checkpoint savers override this to return the writes
    previously stored for ``task_id``.

    Args:
        task_id (str): The identifier for the task.

    Returns:
        Sequence[tuple[str, Any]]: A list of (channel, value) pairs
        representing the cached writes for the task.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""Asynchronously fetch a checkpoint using the given configuration.

Expand Down
Empty file.
40 changes: 40 additions & 0 deletions libs/langgraph/langgraph/caching/cache_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from datetime import datetime, timedelta, timezone
from hashlib import sha1
from typing import Optional

class CacheControl:
    """Caching configuration for a graph node.

    Combines a user-supplied cache key with an optional TTL so that
    identical invocations within the same TTL window produce the same
    task ID (and therefore hit the same cached writes).
    """

    def __init__(self, cache_key: str, ttl: Optional[int] = None):
        """
        Args:
            cache_key (str): The unique identifier for caching.
            ttl (Optional[int]): Time-to-live in seconds, after which the
                cache is invalid. ``None`` (or 0) means entries never expire.
        """
        self.cache_key = cache_key
        self.ttl = ttl

    def compute_time_bucket(self) -> str:
        """Return a stable identifier for the current TTL window.

        Returns:
            str: ``'no-ttl'`` when no TTL is configured, otherwise the
            ISO-8601 UTC timestamp of the start of the current TTL bucket.
        """
        if not self.ttl:
            # No TTL configured (None or 0): a single, never-expiring bucket.
            return 'no-ttl'
        now = datetime.now(timezone.utc)
        # Use integer arithmetic to find the bucket start. The previous
        # float-based form, `now - timedelta(seconds=now.timestamp() % ttl)`,
        # could jitter by a microsecond near bucket edges (float64 loses
        # sub-microsecond precision at current epoch magnitudes), yielding
        # different strings for the same logical bucket and causing
        # spurious cache misses.
        bucket_index = int(now.timestamp()) // self.ttl
        bucket_start = datetime.fromtimestamp(
            bucket_index * self.ttl, tz=timezone.utc
        )
        return bucket_start.isoformat()

    def generate_task_id(self, input_data: str) -> str:
        """
        Generate a cache-enabled task ID based on cache key, stringified input data, and TTL bucket.

        Args:
            input_data (str): The stringified input data to be hashed for uniqueness.

        Returns:
            str: A unique task ID that includes the cache key, input hash, and TTL bucket.
        """
        # Hash the input data to create a unique identifier.
        # NOTE: sha1 is fine for cache-key dedup; it is not used for security.
        input_hash = sha1(input_data.encode()).hexdigest()

        # Compute the TTL bucket so the ID changes when the window rolls over.
        ttl_bucket = self.compute_time_bucket()

        # Combine the cache key, input hash, and TTL bucket to form the task ID
        return f"{self.cache_key}:{input_hash}:{ttl_bucket}"
8 changes: 8 additions & 0 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing_extensions import Self

from langgraph._api.deprecation import LangGraphDeprecationWarning
from langgraph.caching.cache_control import CacheControl
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.dynamic_barrier_value import DynamicBarrierValue, WaitForNames
Expand Down Expand Up @@ -92,6 +93,7 @@ class StateNodeSpec(NamedTuple):
input: Type[Any]
retry_policy: Optional[RetryPolicy]
ends: Optional[tuple[str, ...]] = EMPTY_SEQ
cache: Optional[CacheControl] = None


class StateGraph(Graph):
Expand Down Expand Up @@ -232,6 +234,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CacheControl] = None,
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand All @@ -256,6 +259,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CacheControl] = None,
) -> Self:
"""Adds a new node to the state graph.

Expand All @@ -279,6 +283,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CacheControl] = None,
) -> Self:
"""Adds a new node to the state graph.

Expand All @@ -290,6 +295,7 @@ def add_node(
metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None)
input (Optional[Type[Any]]): The input schema for the node. (default: the graph's input schema)
retry (Optional[RetryPolicy]): The policy for retrying the node. (default: None)
cache (Optional[CacheControl]): The caching configuration for the node. (default: None)
Raises:
ValueError: If the key is already being used as a state key.

Expand Down Expand Up @@ -394,6 +400,7 @@ def add_node(
input=input or self.schema,
retry_policy=retry,
ends=ends,
cache=cache,
)
return self

Expand Down Expand Up @@ -719,6 +726,7 @@ def _get_updates(
],
metadata=node.metadata,
retry_policy=node.retry_policy,
cache=node.cache,
bound=node.runnable,
)
else:
Expand Down
95 changes: 58 additions & 37 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def prepare_single_task(
configurable = config.get(CONF, {})
parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")

node_cache = None
if task_path[0] == PUSH and isinstance(task_path[-1], Call):
# (PUSH, parent task path, idx of PUSH write, id of parent task, Call)
task_path_t = cast(tuple[str, tuple, int, str, Call], task_path)
Expand All @@ -494,15 +495,21 @@ def prepare_single_task(
# create task id
triggers = [PUSH]
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)

pregel_node = processes.get(name)
if pregel_node and pregel_node.cache:
node_cache = pregel_node.cache
task_id = node_cache.generate_task_id(str(call.input))
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
metadata = {
"langgraph_step": step,
Expand Down Expand Up @@ -564,7 +571,7 @@ def prepare_single_task(
),
triggers,
call.retry,
None,
node_cache,
task_id,
task_path[:3],
)
Expand Down Expand Up @@ -593,14 +600,19 @@ def prepare_single_task(
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
str(idx),
)
proc = processes.get(packet.node)
if proc and proc.cache:
node_cache = proc.cache
task_id = node_cache.generate_task_id(str(packet.arg))
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
str(idx),
)
elif len(task_path) >= 4:
# new PUSH tasks, executed in superstep n
# (PUSH, parent task path, idx of PUSH write, id of parent task)
Expand Down Expand Up @@ -629,15 +641,20 @@ def prepare_single_task(
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
proc = processes.get(packet.node)
if proc and proc.cache:
node_cache = proc.cache
task_id = node_cache.generate_task_id(str(packet.arg))
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
else:
logger.warning(f"Ignoring invalid PUSH task path {task_path}")
return
Expand Down Expand Up @@ -713,7 +730,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
node_cache,
task_id,
task_path[:3],
writers=proc.flat_writers,
Expand Down Expand Up @@ -756,14 +773,18 @@ def prepare_single_task(

# create task id
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PULL,
*triggers,
)
if proc and proc.cache:
node_cache = proc.cache
task_id = node_cache.generate_task_id(str(val))
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PULL,
*triggers,
)
task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}"
metadata = {
"langgraph_step": step,
Expand Down Expand Up @@ -837,7 +858,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
node_cache,
task_id,
task_path[:3],
writers=proc.flat_writers,
Expand Down
6 changes: 6 additions & 0 deletions libs/langgraph/langgraph/pregel/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from langchain_core.runnables.base import Input, Other, coerce_to_runnable
from langchain_core.runnables.utils import ConfigurableFieldSpec

from langgraph.caching.cache_control import CacheControl
from langgraph.constants import CONF, CONFIG_KEY_READ
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.write import ChannelWrite
Expand Down Expand Up @@ -145,6 +146,9 @@ class PregelNode(Runnable):
metadata: Optional[Mapping[str, Any]]
"""Metadata to attach to the node for tracing."""

cache: Optional[CacheControl]
"""The caching configuration for the node."""

def __init__(
self,
*,
Expand All @@ -156,6 +160,7 @@ def __init__(
metadata: Optional[Mapping[str, Any]] = None,
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
cache: Optional[CacheControl] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
Expand All @@ -165,6 +170,7 @@ def __init__(
self.retry_policy = retry_policy
self.tags = tags
self.metadata = metadata
self.cache = cache

def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
Expand Down
Loading