Skip to content

Commit

Permalink
move sql-only functions to plpgsql instead of using SPI
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Nov 20, 2024
1 parent 1257c10 commit e8a679c
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 123 deletions.
2 changes: 1 addition & 1 deletion lantern_extras/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_extras"
version = "0.5.0"
version = "0.6.0"
edition = "2021"

[lib]
Expand Down
4 changes: 2 additions & 2 deletions lantern_extras/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ To add a new embedding job, use the `add_embedding_job` function:

```sql
SELECT add_embedding_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column for embeddings
dst_column => 'content_embedding', -- Destination column for embeddings (will be created automatically)
model => 'text-embedding-3-small', -- Model for runtime to use (default: 'text-embedding-3-small')
Expand Down Expand Up @@ -224,7 +224,7 @@ To add a new completion job, use the `add_completion_job` function:

```sql
SELECT add_completion_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column containing the text sent to the LLM
dst_column => 'content_summary', -- Destination column for llm response (will be created automatically)
system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '')
Expand Down
296 changes: 176 additions & 120 deletions lantern_extras/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,126 +226,182 @@ fn add_completion_job<'a>(
Ok(id.unwrap())
}

/// Reports the initialization status of a single embedding job.
///
/// Looks up `job_id` in `_lantern_extras_internal.embedding_generation_jobs` and yields
/// exactly one row `(status, progress, error)`:
///
/// * `status` is derived from the job's timestamps, checked in this order:
///   `init_failed_at` -> `'failed'`, `canceled_at` -> `'canceled'`,
///   `init_finished_at` -> `'enabled'`, `init_started_at` -> `'in_progress'`,
///   otherwise `'queued'`.
/// * `progress` is the job's `init_progress`.
/// * `error` is the job's `init_failure_reason`.
///
/// If the SPI call returns an error, a single all-`NULL` row is returned instead of
/// propagating it (presumably this also covers "no such job" -- confirm against the
/// pgrx `Spi::get_three_with_args` behavior for empty result sets).
// `stable`, not `immutable`: the result comes from a table lookup and changes as the job
// progresses, so the planner must not constant-fold or cache it across statements.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_embedding_job_status<'a>(
    job_id: i32,
) -> Result<
    TableIterator<
        'static,
        (
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    let tuple = Spi::get_three_with_args(
        r#"
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=$1;
"#,
        vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())],
    );

    // Deliberate best-effort: any SPI error degrades to one row of NULLs.
    Ok(TableIterator::once(tuple.unwrap_or((None, None, None))))
}

/// Lists the per-row failures recorded for a job.
///
/// Selects `row_id` and `value` from `_lantern_extras_internal.embedding_failure_info`
/// for every entry whose `job_id` matches the argument and returns them as a table.
///
/// # Errors
///
/// Propagates any SPI error raised while running the query or decoding a column.
// `stable`, not `immutable`: the output depends on table contents, which can change
// between statements.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_completion_job_failures<'a>(
    job_id: i32,
) -> Result<
    TableIterator<'static, (name!(row_id, Option<i32>), name!(value, Option<String>))>,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT row_id, value FROM _lantern_extras_internal.embedding_failure_info WHERE job_id=$1",
                None,
                Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
            )?
            .map(|row| Ok((row["row_id"].value()?, row["value"].value()?)))
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Lists every embedding-generation job together with its current status.
///
/// For each row of `_lantern_extras_internal.embedding_generation_jobs` whose `job_type`
/// is `'embedding_generation'`, returns `(id, status, progress, error)`, where the last
/// three columns are produced by the SQL function `get_embedding_job_status(id)`.
///
/// # Errors
///
/// Propagates any SPI error raised while running the query or decoding a column.
// `stable`, not `immutable`: the output depends on table contents, which can change
// between statements.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_embedding_jobs<'a>() -> Result<
    TableIterator<
        'static,
        (
            name!(id, Option<i32>),
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'embedding_generation'",
                None,
                None,
            )?
            .map(|row| {
                Ok((
                    row["id"].value()?,
                    row["status"].value()?,
                    row["progress"].value()?,
                    row["error"].value()?,
                ))
            })
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Lists every completion job together with its current status.
///
/// For each row of `_lantern_extras_internal.embedding_generation_jobs` whose `job_type`
/// is `'completion'`, returns `(id, status, progress, error)`, where the last three
/// columns are produced by the SQL function `get_embedding_job_status(id)`.
///
/// # Errors
///
/// Propagates any SPI error raised while running the query or decoding a column.
// `stable`, not `immutable`: the output depends on table contents, which can change
// between statements.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_completion_jobs<'a>() -> Result<
    TableIterator<
        'static,
        (
            name!(id, Option<i32>),
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'completion'",
                None,
                None,
            )?
            .map(|row| {
                Ok((
                    row["id"].value()?,
                    row["status"].value()?,
                    row["progress"].value()?,
                    row["error"].value()?,
                ))
            })
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Marks a job as canceled.
///
/// Sets `canceled_at` to `NOW()` on the row of
/// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
///
/// # Errors
///
/// Propagates any SPI error raised by the `UPDATE`.
// This function writes to a table, so it must be `volatile` and must not be declared
// parallel safe: an `immutable` call may be folded away or skipped by the planner, and
// writes are not permitted in parallel workers.
#[pg_extern(volatile, security_definer)]
fn cancel_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
    Spi::run_with_args(
        r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=$1;
"#,
        Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
    )?;

    Ok(())
}

/// Clears the canceled mark on a job.
///
/// Sets `canceled_at` to `NULL` on the row of
/// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
///
/// # Errors
///
/// Propagates any SPI error raised by the `UPDATE`.
// This function writes to a table, so it must be `volatile` and must not be declared
// parallel safe: an `immutable` call may be folded away or skipped by the planner, and
// writes are not permitted in parallel workers.
#[pg_extern(volatile, security_definer)]
fn resume_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
    Spi::run_with_args(
        r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=$1;
"#,
        Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
    )?;

    Ok(())
}
// Registers the plpgsql function `get_embedding_job_status(job_id INT)`.
//
// It returns `(status, progress, error)` for the row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
// `status` is derived from the job's timestamps, checked in this order:
// init_failed_at -> 'failed', canceled_at -> 'canceled', init_finished_at -> 'enabled',
// init_started_at -> 'in_progress', otherwise 'queued'.
//
// The function is declared STABLE rather than IMMUTABLE because it reads a table whose
// contents change as jobs progress; IMMUTABLE would allow PostgreSQL to pre-evaluate or
// cache the result and return stale statuses.
extension_sql!(
    r#"
CREATE OR REPLACE FUNCTION get_embedding_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=job_id;
END
$$;
"#,
    name = "get_embedding_job_status"
);

// Registers the plpgsql function `get_completion_job_status(job_id INT)`.
//
// It is a thin wrapper that returns whatever `get_embedding_job_status(job_id)` returns,
// and is therefore emitted after that SQL block (`requires`).
//
// The function is declared STABLE rather than IMMUTABLE because the function it delegates
// to reads a table; IMMUTABLE would allow PostgreSQL to pre-evaluate or cache the result.
extension_sql!(
    r#"
CREATE OR REPLACE FUNCTION get_completion_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT * FROM get_embedding_job_status(job_id);
END
$$;
"#,
    name = "get_completion_job_status",
    requires = ["get_embedding_job_status"]
);

// Registers the plpgsql function `get_completion_job_failures(job_id INT)`.
//
// It returns `(row_id, value)` for every entry of
// `_lantern_extras_internal.embedding_failure_info` whose `job_id` matches the argument.
// The parameter is referenced as `get_completion_job_failures.job_id` to disambiguate it
// from the table column of the same name.
//
// The function is declared STABLE rather than IMMUTABLE because it reads a table;
// IMMUTABLE would allow PostgreSQL to pre-evaluate or cache the result.
extension_sql!(
    r#"
CREATE OR REPLACE FUNCTION get_completion_job_failures(job_id INT)
RETURNS TABLE (row_id INT, value TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT info.row_id, info.value
FROM _lantern_extras_internal.embedding_failure_info info
WHERE info.job_id=get_completion_job_failures.job_id;
END
$$;
"#,
    name = "get_completion_job_failures",
);

// Registers the plpgsql function `get_embedding_jobs()`.
//
// It returns `(id, status, progress, error)` for every row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `job_type` is
// 'embedding_generation'; the status columns come from `get_embedding_job_status`,
// so this block is emitted after that one (`requires`).
//
// The function is declared STABLE rather than IMMUTABLE because it reads a table;
// IMMUTABLE would allow PostgreSQL to pre-evaluate or cache the job list.
extension_sql!(
    r#"
CREATE OR REPLACE FUNCTION get_embedding_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_embedding_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'embedding_generation';
END
$$;
"#,
    name = "get_embedding_jobs",
    requires = ["get_embedding_job_status"]
);

// Registers the plpgsql function `get_completion_jobs()`.
//
// It returns `(id, status, progress, error)` for every row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `job_type` is 'completion';
// the status columns come from `get_completion_job_status`, so this block is emitted
// after that one (`requires`).
//
// The function is declared STABLE rather than IMMUTABLE because it reads a table;
// IMMUTABLE would allow PostgreSQL to pre-evaluate or cache the job list.
extension_sql!(
    r#"
CREATE OR REPLACE FUNCTION get_completion_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_completion_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'completion';
END
$$;
"#,
    name = "get_completion_jobs",
    requires = ["get_completion_job_status"]
);

// Registers the plpgsql function `cancel_embedding_job(job_id INT)`, which returns VOID.
//
// It sets `canceled_at` to NOW() on the row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
// Declared STRICT VOLATILE (it writes to a table) and SECURITY DEFINER.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_embedding_job",
);

// Registers the plpgsql function `cancel_completion_job(job_id INT)`, which returns VOID.
//
// It sets `canceled_at` to NOW() on the row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
// Declared STRICT VOLATILE (it writes to a table) and SECURITY DEFINER.
//
// NOTE(review): the UPDATE does not filter on `job_type`, so it will cancel any job with
// this id regardless of kind -- presumably intentional since both job kinds live in the
// same table; confirm.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_completion_job",
);

// Registers the plpgsql function `resume_embedding_job(job_id INT)`, which returns VOID.
//
// It resets `canceled_at` to NULL on the row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
// Declared STRICT VOLATILE (it writes to a table) and SECURITY DEFINER.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_embedding_job",
);

// Registers the plpgsql function `resume_completion_job(job_id INT)`, which returns VOID.
//
// It resets `canceled_at` to NULL on the row of
// `_lantern_extras_internal.embedding_generation_jobs` whose `id` equals `job_id`.
// Declared STRICT VOLATILE (it writes to a table) and SECURITY DEFINER.
//
// NOTE(review): the UPDATE does not filter on `job_type`, so it will resume any job with
// this id regardless of kind -- presumably intentional since both job kinds live in the
// same table; confirm.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_completion_job",
);

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
Expand Down

0 comments on commit e8a679c

Please sign in to comment.