From e8a679cccee0aa8389bf6364d3492ed45e01894f Mon Sep 17 00:00:00 2001 From: Varik Matevosyan Date: Wed, 20 Nov 2024 19:14:53 +0400 Subject: [PATCH] move sql-only functions to plpgsql instead of using SPI --- lantern_extras/Cargo.toml | 2 +- lantern_extras/README.md | 4 +- lantern_extras/src/daemon.rs | 296 +++++++++++++++++++++-------------- 3 files changed, 179 insertions(+), 123 deletions(-) diff --git a/lantern_extras/Cargo.toml b/lantern_extras/Cargo.toml index d7810235..142a6933 100644 --- a/lantern_extras/Cargo.toml +++ b/lantern_extras/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lantern_extras" -version = "0.5.0" +version = "0.6.0" edition = "2021" [lib] diff --git a/lantern_extras/README.md b/lantern_extras/README.md index 8e2dda5e..9d07dd52 100644 --- a/lantern_extras/README.md +++ b/lantern_extras/README.md @@ -178,7 +178,7 @@ To add a new embedding job, use the `add_embedding_job` function: ```sql SELECT add_embedding_job( - table => 'articles', -- Name of the table + table_name => 'articles', -- Name of the table src_column => 'content', -- Source column for embeddings dst_column => 'content_embedding', -- Destination column for embeddings (will be created automatically) model => 'text-embedding-3-small', -- Model for runtime to use (default: 'text-embedding-3-small') @@ -224,7 +224,7 @@ To add a new completion job, use the `add_completion_job` function: ```sql SELECT add_completion_job( - table => 'articles', -- Name of the table + table_name => 'articles', -- Name of the table src_column => 'content', -- Source column for embeddings dst_column => 'content_summary', -- Destination column for llm response (will be created automatically) system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '') diff --git a/lantern_extras/src/daemon.rs b/lantern_extras/src/daemon.rs index 6a9522dc..1b6c8928 100644 --- a/lantern_extras/src/daemon.rs +++ b/lantern_extras/src/daemon.rs @@ -226,126 +226,182 @@ fn add_completion_job<'a>( Ok(id.unwrap()) } -#[pg_extern(immutable, parallel_safe, security_definer)] -fn get_embedding_job_status<'a>( - job_id: i32, -) -> Result< - TableIterator< - 'static, - ( - name!(status, Option), - name!(progress, Option), - name!(error, Option), - ), - >, - anyhow::Error, -> { - let tuple = Spi::get_three_with_args( - r#" - SELECT - CASE - WHEN init_failed_at IS NOT NULL THEN 'failed' - WHEN canceled_at IS NOT NULL THEN 'canceled' - WHEN init_finished_at IS NOT NULL THEN 'enabled' - WHEN init_started_at IS NOT NULL THEN 'in_progress' - ELSE 'queued' - END AS status, - init_progress as progress, - init_failure_reason as error - FROM _lantern_extras_internal.embedding_generation_jobs - WHERE id=$1; - "#, - vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())], - ); - - if tuple.is_err() { - return Ok(TableIterator::once((None, None, None))); - } - - Ok(TableIterator::once(tuple.unwrap())) -} - -#[pg_extern(immutable, parallel_safe, security_definer)] -fn get_completion_job_failures<'a>( - job_id: i32, -) -> Result< - TableIterator<'static, (name!(row_id, Option), name!(value, Option))>, - anyhow::Error, -> { - Spi::connect(|client| { - client.select("SELECT row_id, value FROM _lantern_extras_internal.embedding_failure_info WHERE job_id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]))? - .map(|row| Ok((row["row_id"].value()?, row["value"].value()?))) - .collect::, _>>() - }).map(TableIterator::new) -} - -#[pg_extern(immutable, parallel_safe, security_definer)] -fn get_embedding_jobs<'a>() -> Result< - TableIterator< - 'static, - ( - name!(id, Option), - name!(status, Option), - name!(progress, Option), - name!(error, Option), - ), - >, - anyhow::Error, -> { - Spi::connect(|client| { - client.select("SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'embedding_generation'", None, None)? - .map(|row| Ok((row["id"].value()?, row["status"].value()?, row["progress"].value()?, row["error"].value()?))) - .collect::, _>>() - }).map(TableIterator::new) -} - -#[pg_extern(immutable, parallel_safe, security_definer)] -fn get_completion_jobs<'a>() -> Result< - TableIterator< - 'static, - ( - name!(id, Option), - name!(status, Option), - name!(progress, Option), - name!(error, Option), - ), - >, - anyhow::Error, -> { - Spi::connect(|client| { - client.select("SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'completion'", None, None)? - .map(|row| Ok((row["id"].value()?, row["status"].value()?, row["progress"].value()?, row["error"].value()?))) - .collect::, _>>() - }).map(TableIterator::new) -} - -#[pg_extern(immutable, parallel_safe, security_definer)] -fn cancel_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult { - Spi::run_with_args( - r#" - UPDATE _lantern_extras_internal.embedding_generation_jobs - SET canceled_at=NOW() - WHERE id=$1; - "#, - Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]), - )?; - - Ok(()) -} - -#[pg_extern(immutable, parallel_safe, security_definer)] -fn resume_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult { - Spi::run_with_args( - r#" - UPDATE _lantern_extras_internal.embedding_generation_jobs - SET canceled_at=NULL - WHERE id=$1; - "#, - Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]), - )?; - - Ok(()) -} +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION get_embedding_job_status(job_id INT) +RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT) +STRICT IMMUTABLE PARALLEL SAFE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY + SELECT + CASE + WHEN init_failed_at IS NOT NULL THEN 'failed' + WHEN canceled_at IS NOT NULL THEN 'canceled' + WHEN init_finished_at IS NOT NULL THEN 'enabled' + WHEN init_started_at IS NOT NULL THEN 'in_progress' + ELSE 'queued' + END AS status, + init_progress as progress, + init_failure_reason as error + FROM _lantern_extras_internal.embedding_generation_jobs + WHERE id=job_id; +END +$$; +"#, + name = "get_embedding_job_status" +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION get_completion_job_status(job_id INT) +RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT) +STRICT IMMUTABLE PARALLEL SAFE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY + SELECT * FROM get_embedding_job_status(job_id); +END +$$; +"#, + name = "get_completion_job_status", + requires = ["get_embedding_job_status"] +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION get_completion_job_failures(job_id INT) +RETURNS TABLE (row_id INT, value TEXT) +STRICT IMMUTABLE PARALLEL SAFE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY + SELECT info.row_id, info.value + FROM _lantern_extras_internal.embedding_failure_info info + WHERE info.job_id=get_completion_job_failures.job_id; +END +$$; +"#, + name = "get_completion_job_failures", +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION get_embedding_jobs() +RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT) +STRICT IMMUTABLE PARALLEL SAFE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY + SELECT jobs.id, (get_embedding_job_status(jobs.id)).* + FROM _lantern_extras_internal.embedding_generation_jobs jobs + WHERE jobs.job_type = 'embedding_generation'; +END +$$; +"#, + name = "get_embedding_jobs", + requires = ["get_embedding_job_status"] +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION get_completion_jobs() +RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT) +STRICT IMMUTABLE PARALLEL SAFE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN QUERY + SELECT jobs.id, (get_completion_job_status(jobs.id)).* + FROM _lantern_extras_internal.embedding_generation_jobs jobs + WHERE jobs.job_type = 'completion'; +END +$$; +"#, + name = "get_completion_jobs", + requires = ["get_completion_job_status"] +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION cancel_embedding_job(job_id INT) +RETURNS VOID +STRICT VOLATILE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE _lantern_extras_internal.embedding_generation_jobs + SET canceled_at=NOW() + WHERE id=job_id; +END +$$; +"#, + name = "cancel_embedding_job", +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION cancel_completion_job(job_id INT) +RETURNS VOID +STRICT VOLATILE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE _lantern_extras_internal.embedding_generation_jobs + SET canceled_at=NOW() + WHERE id=job_id; +END +$$; +"#, + name = "cancel_completion_job", +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION resume_embedding_job(job_id INT) +RETURNS VOID +STRICT VOLATILE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE _lantern_extras_internal.embedding_generation_jobs + SET canceled_at=NULL + WHERE id=job_id; +END +$$; +"#, + name = "resume_embedding_job", +); + +extension_sql!( + r#" +CREATE OR REPLACE FUNCTION resume_completion_job(job_id INT) +RETURNS VOID +STRICT VOLATILE +SECURITY DEFINER +LANGUAGE plpgsql +AS $$ +BEGIN + UPDATE _lantern_extras_internal.embedding_generation_jobs + SET canceled_at=NULL + WHERE id=job_id; +END +$$; +"#, + name = "resume_completion_job", +); #[cfg(any(test, feature = "pg_test"))] #[pg_schema]