Skip to content

Commit

Permalink
Create the checkpoints table in a migration.
Browse files Browse the repository at this point in the history
  • Loading branch information
bakar-io committed Mar 20, 2024
1 parent 6972b36 commit b17c425
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 30 deletions.
32 changes: 2 additions & 30 deletions backend/app/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pickle
from typing import Optional

import asyncpg
from langchain.pydantic_v1 import Field
from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.utils import ConfigurableFieldSpec
from langgraph.checkpoint import BaseCheckpointSaver
Expand All @@ -12,9 +10,6 @@


class PostgresCheckpoint(BaseCheckpointSaver):
pg_pool: Optional[asyncpg.Pool] = None
is_setup: bool = Field(False, init=False, repr=False)

class Config:
arbitrary_types_allowed = True

Expand All @@ -31,46 +26,23 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
),
]

async def setup(self) -> None:
if self.is_setup:
return

if self.pg_pool is None:
self.pg_pool = get_pg_pool()

try:
async with self.pg_pool.acquire() as conn:
await conn.execute(
"""
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT PRIMARY KEY,
checkpoint BYTEA
);
"""
)
self.is_setup = True
except BaseException as e:
raise e

def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
raise NotImplementedError

def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
raise NotImplementedError

async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
await self.setup()
thread_id = config["configurable"]["thread_id"]
async with self.pg_pool.acquire() as conn:
async with get_pg_pool().acquire() as conn:
if value := await conn.fetchrow(
"SELECT checkpoint FROM checkpoints WHERE thread_id = $1", thread_id
):
return pickle.loads(value[0])

async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
await self.setup()
thread_id = config["configurable"]["thread_id"]
async with self.pg_pool.acquire() as conn:
async with get_pg_pool().acquire() as conn:
await conn.execute(
(
"INSERT INTO checkpoints (thread_id, checkpoint) "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@ CREATE TABLE IF NOT EXISTS thread (
user_id VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC')
);

CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT PRIMARY KEY,
checkpoint BYTEA
);

0 comments on commit b17c425

Please sign in to comment.