add openai completion API to SQL and daemon #341

Merged
15 commits merged on Oct 20, 2024
Changes from 1 commit
6 changes: 0 additions & 6 deletions lantern_cli/README.md
@@ -78,12 +78,6 @@ lantern-cli create-embeddings --model 'clip/ViT-B-32-textual' --uri 'postgresq

> The output database, table and column names can be specified via `--out-table`, `--out-uri`, `--out-column` arguments. Check `help` for more info.

or you can export to csv file

```bash
lantern-cli create-embeddings --model 'clip/ViT-B-32-textual' --uri 'postgresql://postgres:postgres@localhost:5432/test' --table "articles" --column "description" --out-column embedding --out-csv "embeddings.csv" --schema "public"
```

### Image Embedding Example

1. Create table with image uris data
1 change: 0 additions & 1 deletion lantern_cli/src/daemon/embedding_jobs.rs
@@ -456,7 +456,6 @@ async fn embedding_worker(
visual: false,
stream: true,
create_column: false,
out_csv: None,
filter: job_clone.filter.clone(),
limit: None,
},
1 change: 1 addition & 0 deletions lantern_cli/src/daemon/types.rs
@@ -208,6 +208,7 @@ pub struct JobInsertNotification {
pub generate_missing: bool,
pub row_id: Option<String>,
pub filter: Option<String>,
#[allow(dead_code)]
pub limit: Option<u32>,
}

4 changes: 0 additions & 4 deletions lantern_cli/src/embeddings/cli.rs
@@ -56,10 +56,6 @@ pub struct EmbeddingArgs {
#[arg(long, default_value_t = false)]
pub visual: bool,

/// Output csv path. If specified result will be written in csv instead of database
#[arg(short, long)]
pub out_csv: Option<String>,

/// Filter which will be used when getting data from source table
#[arg(short, long)]
pub filter: Option<String>,
2 changes: 1 addition & 1 deletion lantern_cli/src/embeddings/core/http_runtime.rs
@@ -53,7 +53,7 @@ macro_rules! HTTPRuntime {
for request_body in self.chunk_inputs(model_name, inputs)? {
let client = client.clone();
let url = url.clone();
let embedding_response =
let embedding_response: super::runtime::EmbeddingResult =
post_with_retries(client, url, request_body, Box::new($a::get_response), 5)
.await?;
processed_tokens_clone
30 changes: 28 additions & 2 deletions lantern_cli/src/embeddings/core/mod.rs
@@ -5,15 +5,15 @@ pub mod ort_runtime;
pub mod runtime;
pub mod utils;

use std::str::FromStr;
use std::{str::FromStr, sync::Arc};
use strum::{EnumIter, IntoEnumIterator};

use cohere_runtime::CohereRuntime;
use openai_runtime::OpenAiRuntime;
use ort_runtime::OrtRuntime;
use runtime::EmbeddingRuntimeT;

use self::runtime::EmbeddingResult;
use self::runtime::{BatchCompletionResult, CompletionResult, EmbeddingResult};

fn default_logger(text: &str) {
println!("{}", text);
@@ -96,6 +96,32 @@ impl<'a> EmbeddingRuntime<'a> {
}
}

pub async fn completion(
&self,
model_name: &str,
query: &str,
) -> Result<CompletionResult, anyhow::Error> {
match self {
EmbeddingRuntime::OpenAi(runtime) => {
runtime.completion(model_name, query, Some(1)).await
}
_ => anyhow::bail!("completion is not available for this runtime"),
}
}

pub async fn batch_completion(
&self,
model_name: &str,
queries: &Vec<&str>,
) -> Result<BatchCompletionResult, anyhow::Error> {
match self {
EmbeddingRuntime::OpenAi(runtime) => {
OpenAiRuntime::batch_completion(Arc::new(runtime), model_name, queries).await
}
_ => anyhow::bail!("completion is not available for this runtime"),
}
}

pub async fn get_available_models(&self) -> (String, Vec<(String, bool)>) {
match self {
EmbeddingRuntime::Cohere(runtime) => runtime.get_available_models().await,
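
The two methods above expose chat completions through the same `EmbeddingRuntime` enum used for embeddings; any runtime other than OpenAI bails. A minimal usage sketch follows (the runtime construction is elided and `gpt-4o` is a placeholder model name, neither comes from this diff):

```rust
// Sketch only: assumes `runtime` is an EmbeddingRuntime::OpenAi(...) value
// constructed elsewhere; "gpt-4o" is a placeholder model name.
async fn ask(runtime: &EmbeddingRuntime<'_>) -> Result<(), anyhow::Error> {
    // Single completion: first choice's message plus total token usage.
    let single = runtime
        .completion("gpt-4o", "Summarize pgvector in one sentence")
        .await?;
    println!("{} ({} tokens)", single.message, single.processed_tokens);

    // Batch completion: one message per query; a failed query is recorded
    // as an "Error: ..." string rather than aborting the whole batch.
    let queries = vec!["What is HNSW?", "What is IVFFlat?"];
    let batch = runtime.batch_completion("gpt-4o", &queries).await?;
    for message in batch.messages {
        println!("{message}");
    }
    Ok(())
}
```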
133 changes: 130 additions & 3 deletions lantern_cli/src/embeddings/core/openai_runtime.rs
@@ -1,9 +1,10 @@
use itertools::Itertools;
use regex::Regex;
use serde_json::json;
use std::{collections::HashMap, sync::RwLock};

use super::{
runtime::{EmbeddingResult, EmbeddingRuntimeT},
runtime::{BatchCompletionResult, CompletionResult, EmbeddingResult, EmbeddingRuntimeT},
LoggerFn,
};
use crate::HTTPRuntime;
@@ -34,6 +35,33 @@ struct OpenAiResponse {
usage: OpenAiUsage,
}

#[derive(Deserialize, Debug)]
struct OpenAiChatMessage {
#[allow(dead_code)]
role: String,
content: String,
}

impl OpenAiChatMessage {
fn new() -> OpenAiChatMessage {
OpenAiChatMessage {
role: "system".to_owned(),
content: "".to_owned(),
}
}
}

#[derive(Deserialize, Debug)]
struct OpenAiChatChoice {
message: OpenAiChatMessage,
}

#[derive(Deserialize, Debug)]
struct OpenAiChatResponse {
choices: Vec<OpenAiChatChoice>,
usage: OpenAiUsage,
}

enum OpenAiDeployment {
Azure,
OpenAi,
@@ -93,6 +121,7 @@ pub struct OpenAiRuntime<'a> {
request_timeout: u64,
base_url: String,
headers: Vec<(String, String)>,
context: serde_json::Value,
dimensions: Option<usize>,
#[allow(dead_code)]
logger: &'a LoggerFn,
@@ -104,6 +133,7 @@ pub struct OpenAiRuntimeParams {
pub api_token: Option<String>,
pub azure_api_token: Option<String>,
pub azure_entra_token: Option<String>,
pub context: Option<String>,
pub dimensions: Option<usize>,
}

@@ -144,6 +174,11 @@ impl<'a> OpenAiRuntime<'a> {
}
};

let context = match &runtime_params.context {
Some(system_prompt) => json!({ "role": "system", "content": system_prompt.clone()}),
None => json!({ "role": "system", "content": "" }),
};

Ok(Self {
base_url,
logger,
Expand All @@ -153,6 +188,7 @@ impl<'a> OpenAiRuntime<'a> {
auth_header,
],
dimensions: runtime_params.dimensions,
context,
})
}

@@ -162,7 +198,7 @@ impl<'a> OpenAiRuntime<'a> {
if base_url.is_none() {
return Ok((
OpenAiDeployment::OpenAi,
"https://api.openai.com/v1/embeddings".to_owned(),
"https://api.openai.com".to_owned(),
));
}

@@ -263,6 +299,70 @@ impl<'a> OpenAiRuntime<'a> {
Ok(batch_tokens)
}

pub async fn completion(
&self,
model_name: &str,
query: &str,
retries: Option<usize>,
) -> Result<CompletionResult, anyhow::Error> {
let client = Arc::new(self.get_client()?);
let url = Url::parse(&self.base_url)?
.join("/v1/chat/completions")?
.to_string();
let completion_response: CompletionResult = post_with_retries(
client,
url,
serde_json::to_string(&json!({
"model": model_name,
"messages": [
self.context,
{ "role": "user", "content": query }
]
}))?,
Box::new(Self::get_completion_response),
retries.unwrap_or(5),
)
.await?;

Ok(completion_response)
}

pub async fn batch_completion(
self: Arc<&Self>,
model_name: &str,
queries: &Vec<&str>,
) -> Result<BatchCompletionResult, anyhow::Error> {
let mut processed_tokens = 0;

let completion_futures = queries.into_iter().map(|query| {
let self_clone = Arc::clone(&self);
let model_name_clone = model_name.to_owned();
async move {
self_clone
.completion(&model_name_clone, &query, Some(5))
.await
}
});

let results = futures::future::join_all(completion_futures).await;

let mut responses = Vec::with_capacity(results.len());
for result in results {
match result {
Ok(msg) => {
processed_tokens += msg.processed_tokens;
responses.push(msg.message);
}
Err(e) => responses.push(format!("Error: {e}")),
}
}

Ok(BatchCompletionResult {
messages: responses,
processed_tokens,
})
}

// Static functions
pub fn get_response(body: Vec<u8>) -> Result<EmbeddingResult, anyhow::Error> {
let result: Result<OpenAiResponse, serde_json::Error> = serde_json::from_slice(&body);
@@ -284,6 +384,31 @@ impl<'a> OpenAiRuntime<'a> {
.collect(),
})
}

pub fn get_completion_response(body: Vec<u8>) -> Result<CompletionResult, anyhow::Error> {
let result: Result<OpenAiChatResponse, serde_json::Error> = serde_json::from_slice(&body);
if let Err(e) = result {
anyhow::bail!(
"Error: {e}. OpenAI response: {:?}",
serde_json::from_slice::<serde_json::Value>(&body)?
);
}

let result = result.unwrap();

Ok(CompletionResult {
processed_tokens: result.usage.total_tokens,
message: result
.choices
.first()
.unwrap_or(&OpenAiChatChoice {
message: OpenAiChatMessage::new(),
})
.message
.content
.clone(),
})
}
}

impl<'a> EmbeddingRuntimeT for OpenAiRuntime<'a> {
@@ -292,7 +417,8 @@ impl<'a> EmbeddingRuntimeT for OpenAiRuntime<'a> {
model_name: &str,
inputs: &Vec<&str>,
) -> Result<EmbeddingResult, anyhow::Error> {
self.post_request("", model_name, inputs).await
self.post_request("/v1/embeddings", model_name, inputs)
.await
}

async fn get_available_models(&self) -> (String, Vec<(String, bool)>) {
@@ -310,4 +436,5 @@ impl<'a> EmbeddingRuntimeT for OpenAiRuntime<'a> {
return (res, models);
}
}

HTTPRuntime!(OpenAiRuntime);
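
The wire format the new completion path relies on is not spelled out in the diff; below is a rough sketch of both sides, based on the `json!` body built in `completion` and the fields the `OpenAiChat*` structs read. Model name and values are placeholders, and real API responses carry additional fields beyond the ones shown (serde ignores unknown fields unless told otherwise).

```rust
use serde_json::json;

fn main() {
    // Request body assembled by `completion` above: the configured system
    // prompt (`context`) followed by the user query.
    let request = json!({
        "model": "gpt-4o",
        "messages": [
            { "role": "system", "content": "You are terse." },
            { "role": "user", "content": "What is Lantern?" }
        ]
    });

    // A response trimmed to the fields `get_completion_response` actually
    // reads: choices[0].message.content and usage.total_tokens.
    let response = json!({
        "choices": [
            { "message": { "role": "assistant", "content": "A Postgres vector toolkit." } }
        ],
        "usage": { "total_tokens": 42 }
    });

    println!("{request}\n{response}");
}
```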
10 changes: 10 additions & 0 deletions lantern_cli/src/embeddings/core/runtime.rs
@@ -3,6 +3,16 @@ pub struct EmbeddingResult {
pub processed_tokens: usize,
}

pub struct CompletionResult {
pub message: String,
pub processed_tokens: usize,
}

pub struct BatchCompletionResult {
pub messages: Vec<String>,
pub processed_tokens: usize,
}

pub trait EmbeddingRuntimeT {
fn process(
&self,
10 changes: 4 additions & 6 deletions lantern_cli/src/embeddings/core/utils.rs
@@ -6,9 +6,7 @@ use std::sync::Arc;
use std::{fs::create_dir_all, io::Cursor, time::Duration};
use sysinfo::{System, SystemExt};

use super::runtime::EmbeddingResult;

type GetResponseFn = Box<dyn Fn(Vec<u8>) -> Result<EmbeddingResult, anyhow::Error> + Send + Sync>;
type GetResponseFn<T> = Box<dyn Fn(Vec<u8>) -> Result<T, anyhow::Error> + Send + Sync>;

pub async fn download_file(url: &str, path: &PathBuf) -> Result<(), anyhow::Error> {
let client = reqwest::Client::builder()
@@ -78,13 +76,13 @@ pub fn percent_gpu_memory_used() -> Result<f64, anyhow::Error> {
Ok((mem_info.used as f64 / mem_info.total as f64) * 100.0)
}

pub async fn post_with_retries(
pub async fn post_with_retries<T>(
client: Arc<reqwest::Client>,
url: String,
body: String,
get_response_fn: GetResponseFn,
get_response_fn: GetResponseFn<T>,
max_retries: usize,
) -> Result<EmbeddingResult, anyhow::Error> {
) -> Result<T, anyhow::Error> {
let starting_interval = 4000; // ms
let mut last_error = "".to_string();

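
Making `post_with_retries` generic over its parsed result lets the embedding and completion paths share one retry loop; callers now choose the parser and, as the `http_runtime.rs` hunk above shows, may need an explicit result-type annotation. A rough illustration, under the assumption that the client, URLs, and request bodies are built as in the surrounding runtime code:

```rust
// Illustration only: the client, URLs, and bodies are assumed to come from
// the surrounding runtime code; they are not constructed here.
async fn example(
    client: std::sync::Arc<reqwest::Client>,
    embeddings_url: String,
    embeddings_body: String,
    chat_url: String,
    chat_body: String,
) -> Result<(), anyhow::Error> {
    // Same helper, two different parsers and result types.
    let embedding: EmbeddingResult = post_with_retries(
        client.clone(),
        embeddings_url,
        embeddings_body,
        Box::new(OpenAiRuntime::get_response),
        5,
    )
    .await?;

    let completion: CompletionResult = post_with_retries(
        client,
        chat_url,
        chat_body,
        Box::new(OpenAiRuntime::get_completion_response),
        5,
    )
    .await?;

    println!(
        "embedding tokens: {}, completion tokens: {}",
        embedding.processed_tokens, completion.processed_tokens
    );
    Ok(())
}
```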
1 change: 0 additions & 1 deletion lantern_cli/src/embeddings/measure_speed.rs
@@ -40,7 +40,6 @@ async fn measure_model_speed(
schema: SCHEMA_NAME.to_owned(),
table: table_name.to_owned(),
out_uri: None,
out_csv: None,
out_table: None,
runtime: runtime.clone(),
runtime_params: runtime_params.to_owned(),