Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add openai completion API to SQL and daemon #341

Merged
merged 15 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions lantern_cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,6 @@ lantern-cli create-embeddings --model 'clip/ViT-B-32-textual' --uri 'postgresq

> The output database, table and column names can be specified via `--out-table`, `--out-uri`, `--out-column` arguments. Check `help` for more info.

or you can export to csv file

```bash
lantern-cli create-embeddings --model 'clip/ViT-B-32-textual' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "articles" --column "description" --out-column embedding --out-csv "embeddings.csv" --schema "public"
```

### Image Embedding Example

1. Create table with image uris data
Expand All @@ -105,10 +99,10 @@ Lantern CLI also supports generating OpenAI and Cohere embeddings via API. For t

```bash
# OpenAI
lantern-cli create-embeddings --model 'openai/text-embedding-ada-002' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "images" --column "url" --out-column "embedding" --schema "public" --runtime openai --runtime-params '{ "api_token": "sk-xxx-xxxx" }'
lantern-cli create-embeddings --model 'text-embedding-ada-002' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "images" --column "url" --out-column "embedding" --schema "public" --runtime openai --runtime-params '{ "api_token": "sk-xxx-xxxx" }'

# Cohere
lantern-cli create-embeddings --model 'openai/text-embedding-ada-002' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "images" --column "url" --out-column "embedding" --schema "public" --runtime cohere --runtime-params '{ "api_token": "xxx-xxxx" }'
lantern-cli create-embeddings --model 'embed-english-v3.0' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "images" --column "url" --out-column "embedding" --schema "public" --runtime cohere --runtime-params '{ "api_token": "xxx-xxxx" }'
```

> To get available runtimes use `lantern-cli show-runtimes`
Expand Down
2 changes: 2 additions & 0 deletions lantern_cli/src/daemon/autotune_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ pub async fn start(
None,
None,
None,
None,
None,
&notification_channel,
logger.clone(),
)
Expand Down
4 changes: 4 additions & 0 deletions lantern_cli/src/daemon/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ pub struct DaemonArgs {
/// Log level
#[arg(long, value_enum, default_value_t = LogLevel::Info)] // arg_enum here
pub log_level: LogLevel,

/// Is being run inside postgres
#[arg(long, default_value_t = false)]
pub inside_postgres: bool,
}
41 changes: 37 additions & 4 deletions lantern_cli/src/daemon/embedding_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use super::types::{
JobEventHandlersMap, JobInsertNotification, JobRunArgs, JobUpdateNotification,
};
use crate::daemon::helpers::anyhow_wrap_connection;
use crate::embeddings::cli::EmbeddingArgs;
use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType};
use crate::embeddings::get_default_batch_size;
use crate::logger::Logger;
use crate::utils::{get_common_embedding_ignore_filters, get_full_table_name, quote_ident};
use crate::{embeddings, types::*};
use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -31,8 +32,11 @@ pub const JOB_TABLE_DEFINITION: &'static str = r#"
"table" text NOT NULL,
"pk" text NOT NULL DEFAULT 'id',
"label" text NULL,
"job_type" text DEFAULT 'embedding_generation',
"column_type" text DEFAULT 'REAL[]',
"runtime" text NOT NULL DEFAULT 'ort',
"runtime_params" jsonb,
"batch_size" int NULL,
"src_column" text NOT NULL,
"dst_column" text NOT NULL,
"embedding_model" text NOT NULL,
Expand All @@ -55,7 +59,16 @@ pub const USAGE_TABLE_DEFINITION: &'static str = r#"
"created_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
"#;

/// Column list for the failure-info table (`EMB_FAILURE_TABLE_NAME`), created
/// by the daemon startup hook alongside an index on (job_id, row_id).
/// Completion jobs record here the rows whose returned value could not be
/// cast to the destination column type; `value` holds the offending text.
pub const FAILURE_TABLE_DEFINITION: &'static str = r#"
"id" SERIAL PRIMARY KEY,
"job_id" INT NOT NULL,
"row_id" INT NOT NULL,
"value" TEXT,
"created_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
"#;

// Table tracking embedding API usage per job; layout is USAGE_TABLE_DEFINITION.
const EMB_USAGE_TABLE_NAME: &'static str = "embedding_usage_info";
// Table recording rows that failed the destination-type cast in completion
// jobs; layout is FAILURE_TABLE_DEFINITION.
const EMB_FAILURE_TABLE_NAME: &'static str = "embedding_failure_info";
// Lock table used to coordinate row-level processing between workers
// (presumably consumed by `lock_row` below — confirm against its body).
const EMB_LOCK_TABLE_NAME: &'static str = "_lantern_emb_job_locks";

async fn lock_row(
Expand Down Expand Up @@ -302,7 +315,9 @@ async fn stream_job(
let mut progress = 0;
let mut processed_rows = 0;

let batch_size = embeddings::get_default_batch_size(&job.model) as i32;
let batch_size =
job.batch_size
.unwrap_or(embeddings::get_default_batch_size(&job.model)) as i32;
loop {
// poll batch_size rows from portal and send it to embedding thread via channel
let rows = job_client
Expand Down Expand Up @@ -436,6 +451,16 @@ async fn embedding_worker(
logger.level.clone(),
);
let job_clone = job.clone();
let mut failed_rows_table = None;
let mut check_column_type = false;

match job_clone.job_type {
EmbeddingJobType::Completion => {
failed_rows_table = Some(EMB_FAILURE_TABLE_NAME.to_owned());
check_column_type = true;
},
_ => {}
};

let (tx, mut rx) = mpsc::channel(1);
embedding_processor_tx.send(
Expand All @@ -456,9 +481,15 @@ async fn embedding_worker(
visual: false,
stream: true,
create_column: false,
out_csv: None,
filter: job_clone.filter.clone(),
limit: None,
job_type: Some(job_clone.job_type.clone()),
column_type: Some(job_clone.column_type.clone()),
create_cast_fn: false,
check_column_type,
job_id: job_clone.id,
internal_schema: schema.deref().clone(),
failed_rows_table
},
tx,
task_logger
Expand Down Expand Up @@ -618,7 +649,7 @@ async fn job_insert_processor(
// batch jobs for the rows. This will optimize embedding generation as if there will be lots of
// inserts to the table between 10 seconds all that rows will be batched.
let full_table_name = Arc::new(get_full_table_name(&schema, &table));
let job_query_sql = Arc::new(format!("SELECT id, pk, label, src_column as \"column\", dst_column, \"table\", \"schema\", embedding_model as model, runtime, runtime_params::text, init_finished_at FROM {0}", &full_table_name));
let job_query_sql = Arc::new(format!("SELECT id, pk, label, src_column as \"column\", dst_column, \"table\", \"schema\", embedding_model as model, runtime, runtime_params::text, init_finished_at, job_type, column_type, batch_size FROM {0}", &full_table_name));

let db_uri_r1 = db_uri.clone();
let full_table_name_r1 = full_table_name.clone();
Expand Down Expand Up @@ -1049,6 +1080,8 @@ pub async fn start(
None,
Some(EMB_USAGE_TABLE_NAME),
Some(USAGE_TABLE_DEFINITION),
Some(EMB_FAILURE_TABLE_NAME),
Some(FAILURE_TABLE_DEFINITION),
None,
&notification_channel,
logger.clone(),
Expand Down
2 changes: 2 additions & 0 deletions lantern_cli/src/daemon/external_index_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ pub async fn start(
None,
None,
None,
None,
None,
&notification_channel,
logger.clone(),
)
Expand Down
18 changes: 18 additions & 0 deletions lantern_cli/src/daemon/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::types::{
EmbeddingJob, JobEvent, JobEventHandlersMap, JobInsertNotification, JobTaskEventTx,
JobUpdateNotification,
};
use crate::embeddings::get_try_cast_fn_sql;
use crate::logger::Logger;
use crate::types::{AnyhowVoidResult, JOB_CANCELLED_MESSAGE};
use crate::utils::{get_common_embedding_ignore_filters, get_full_table_name, quote_ident};
Expand Down Expand Up @@ -147,6 +148,8 @@ pub async fn startup_hook(
results_table_def: Option<&str>,
usage_table_name: Option<&str>,
usage_table_def: Option<&str>,
failure_table_name: Option<&str>,
failure_table_def: Option<&str>,
migration: Option<&str>,
channel: &str,
logger: Arc<Logger>,
Expand Down Expand Up @@ -271,6 +274,21 @@ pub async fn startup_hook(
.await?;
}

if failure_table_name.is_some() && failure_table_def.is_some() {
let failure_table_name = get_full_table_name(schema, failure_table_name.unwrap());
let failure_table_def = failure_table_def.unwrap();
transaction
.batch_execute(&format!(
"
-- this function is used for completion jobs
{ldb_try_cast_fn}
CREATE TABLE IF NOT EXISTS {failure_table_name} ({failure_table_def});
CREATE INDEX IF NOT EXISTS embedding_failures_job_id_row_id ON {failure_table_name}(job_id, row_id);",
ldb_try_cast_fn = get_try_cast_fn_sql(&schema)
))
.await?;
}

transaction.commit().await?;

Ok(())
Expand Down
12 changes: 9 additions & 3 deletions lantern_cli/src/daemon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,15 @@ async fn spawn_job(
let mut last_retry = Instant::now();

loop {
let cancel_token = parent_cancel_token.child_token();
// If run inside postgres, in case of error
// We will cancel the parent task and let bgworker
// handle the task restart
let cancel_token = if args.inside_postgres {
parent_cancel_token.clone()
} else {
parent_cancel_token.child_token()
};

let mut jobs = JOBS.write().await;
jobs.insert(target_db.name.clone(), cancel_token.clone());
drop(jobs);
Expand Down Expand Up @@ -415,7 +423,5 @@ pub async fn start(
.await?;
}

cancel_token.cancelled().await;

Ok(())
}
23 changes: 20 additions & 3 deletions lantern_cli/src/daemon/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::embeddings::cli::{EmbeddingArgs, Runtime};
use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType, Runtime};
use crate::embeddings::core::utils::get_clean_model_name;
use crate::external_index::cli::CreateIndexArgs;
use crate::index_autotune::cli::IndexAutotuneArgs;
use crate::logger::Logger;
Expand Down Expand Up @@ -64,6 +65,8 @@ pub struct EmbeddingJob {
pub pk: String,
pub filter: Option<String>,
pub label: Option<String>,
pub job_type: EmbeddingJobType,
pub column_type: String,
pub out_column: String,
pub model: String,
pub runtime_params: String,
Expand All @@ -82,6 +85,12 @@ impl EmbeddingJob {
.unwrap_or("{}".to_owned())
};

let batch_size = if let Some(batch_size) = row.get::<&str, Option<i32>>("batch_size") {
Some(batch_size as usize)
} else {
None
};

Ok(Self {
id: row.get::<&str, i32>("id"),
pk: row.get::<&str, String>("pk"),
Expand All @@ -91,13 +100,20 @@ impl EmbeddingJob {
table: row.get::<&str, String>("table"),
column: row.get::<&str, String>("column"),
out_column: row.get::<&str, String>("dst_column"),
model: row.get::<&str, String>("model"),
model: get_clean_model_name(row.get::<&str, &str>("model"), runtime),
runtime,
runtime_params,
filter: None,
row_ids: None,
is_init: true,
batch_size: None,
batch_size,
job_type: EmbeddingJobType::try_from(
row.get::<&str, Option<&str>>("job_type")
.unwrap_or("embedding"),
)?,
column_type: row
.get::<&str, Option<String>>("column_type")
.unwrap_or("REAL[]".to_owned()),
})
}

Expand Down Expand Up @@ -208,6 +224,7 @@ pub struct JobInsertNotification {
pub generate_missing: bool,
pub row_id: Option<String>,
pub filter: Option<String>,
#[allow(dead_code)]
pub limit: Option<u32>,
}

Expand Down
55 changes: 50 additions & 5 deletions lantern_cli/src/embeddings/cli.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
pub use super::core::Runtime;
use clap::Parser;
use clap::{Parser, ValueEnum};

/// Kind of work an embedding job performs: generate vector embeddings, or
/// call a completion endpoint (selected via `--job-type` on `EmbeddingArgs`).
// Fieldless two-variant enum: derive `Copy` so call sites don't need
// `.clone()`, and `PartialEq`/`Eq` so it can be compared/matched directly.
// All additions are backward compatible with existing `Clone` users.
#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingJobType {
    EmbeddingGeneration,
    Completion,
}

impl TryFrom<&str> for EmbeddingJobType {
    type Error = anyhow::Error;

    /// Parses a stored `job_type` column value into an [`EmbeddingJobType`].
    ///
    /// Accepts `"embedding"` in addition to `"embedding_generation"`: the
    /// daemon substitutes `"embedding"` when the column is NULL
    /// (`row.get::<_, Option<&str>>("job_type").unwrap_or("embedding")`), and
    /// without this alias every such row would fail to parse.
    ///
    /// # Errors
    /// Returns an error for any other string.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "embedding" | "embedding_generation" => Ok(EmbeddingJobType::EmbeddingGeneration),
            "completion" => Ok(EmbeddingJobType::Completion),
            _ => anyhow::bail!("Invalid job_type {value}"),
        }
    }
}

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -44,6 +61,14 @@ pub struct EmbeddingArgs {
#[arg(short, long)]
pub batch_size: Option<usize>,

/// Generate embeddings or get chat completion
#[arg(long)]
pub job_type: Option<EmbeddingJobType>,

/// Type of destination column
#[arg(long)]
pub column_type: Option<String>,

/// Runtime
#[arg(long, default_value_t = Runtime::Ort)]
pub runtime: Runtime,
Expand All @@ -56,10 +81,6 @@ pub struct EmbeddingArgs {
#[arg(long, default_value_t = false)]
pub visual: bool,

/// Output csv path. If specified result will be written in csv instead of database
#[arg(short, long)]
pub out_csv: Option<String>,

/// Filter which will be used when getting data from source table
#[arg(short, long)]
pub filter: Option<String>,
Expand All @@ -75,6 +96,26 @@ pub struct EmbeddingArgs {
/// Create destination column if not exists
#[arg(long, default_value_t = false)]
pub create_column: bool,

/// Create ldb_try_cast function for type checking
#[arg(long, default_value_t = false)]
pub create_cast_fn: bool,

/// Check column type for each row before inserting to the table
#[arg(long, default_value_t = false)]
pub check_column_type: bool,

/// Schema name where the failed rows table and ldb_try_cast function are located
#[arg(long, default_value = "public")]
pub internal_schema: String,

/// Table to insert the rows which were not casted successfully
#[arg(long)]
pub failed_rows_table: Option<String>,

/// Job ID is only needed when run from daemon
#[arg(long, default_value_t = 0)]
pub job_id: i32,
}

impl EmbeddingArgs {
Expand All @@ -100,6 +141,10 @@ pub struct ShowModelsArgs {
/// Runtime Params JSON string
#[arg(long, default_value = "{}")]
pub runtime_params: String,

/// Generate embeddings or get chat completion
#[arg(long)]
pub job_type: Option<EmbeddingJobType>,
}

#[derive(Parser, Debug)]
Expand Down
Loading
Loading