Skip to content

Commit

Permalink
optionally truncate llm_completion messages to not exceed context window
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Oct 19, 2024
1 parent 5ef59f1 commit 7b0568d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 16 deletions.
45 changes: 34 additions & 11 deletions lantern_cli/src/embeddings/core/openai_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,28 @@ impl ModelInfo {
"gpt-4" => Ok(Self {
name,
tokenizer: cl100k_base()?,
sequence_len: 128000,
sequence_len: 127000,
dimensions: 0,
var_dimension: false,
}),
"gpt-4o" => Ok(Self {
name,
tokenizer: cl100k_base()?,
sequence_len: 128000,
sequence_len: 127000,
dimensions: 0,
var_dimension: false,
}),
"gpt-4o-mini" => Ok(Self {
name,
tokenizer: cl100k_base()?,
sequence_len: 128000,
sequence_len: 127000,
dimensions: 0,
var_dimension: false,
}),
"gpt-4-turbo" => Ok(Self {
name,
tokenizer: cl100k_base()?,
sequence_len: 128000,
sequence_len: 127000,
dimensions: 0,
var_dimension: false,
}),
Expand Down Expand Up @@ -188,6 +188,7 @@ pub struct OpenAiRuntime<'a> {
headers: Vec<(String, String)>,
context: serde_json::Value,
dimensions: Option<usize>,
truncate: bool,
#[allow(dead_code)]
logger: &'a LoggerFn,
}
Expand All @@ -200,6 +201,7 @@ pub struct OpenAiRuntimeParams {
pub azure_entra_token: Option<String>,
pub context: Option<String>,
pub dimensions: Option<usize>,
pub truncate: Option<bool>,
}

impl<'a> OpenAiRuntime<'a> {
Expand Down Expand Up @@ -254,6 +256,7 @@ impl<'a> OpenAiRuntime<'a> {
],
dimensions: runtime_params.dimensions,
context,
truncate: runtime_params.truncate.unwrap_or(false),
})
}

Expand Down Expand Up @@ -364,20 +367,40 @@ impl<'a> OpenAiRuntime<'a> {
let model_map = COMPLETION_MODEL_INFO_MAP.read().await;
let model_info = check_and_get_model!(model_map, model_name);

let body = if self.truncate {
let mut tokens = model_info.tokenizer.encode_with_special_tokens(query);
let context_tokens = model_info
.tokenizer
.encode_with_special_tokens(&self.context["content"].to_string());
if (tokens.len() + context_tokens.len()) > model_info.sequence_len {
tokens.truncate(model_info.sequence_len - context_tokens.len());
}
serde_json::to_string(&json!({
"model": model_info.name,
"messages": [
self.context,
{ "role": "user", "content": model_info.tokenizer.decode(tokens)? }
]
}))?
} else {
serde_json::to_string(&json!({
"model": model_info.name,
"messages": [
self.context,
{ "role": "user", "content": query }
]
}))?
};

let client = Arc::new(self.get_client()?);
let url = Url::parse(&self.base_url)?
.join("/v1/chat/completions")?
.to_string();

let completion_response: CompletionResult = post_with_retries(
client,
url,
serde_json::to_string(&json!({
"model": model_info.name,
"messages": [
self.context,
{ "role": "user", "content": query }
]
}))?,
body,
Box::new(Self::get_completion_response),
retries.unwrap_or(5),
)
Expand Down
3 changes: 2 additions & 1 deletion lantern_extras/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ SELECT add_completion_job(
'column_type', -- Target column type to be used for destination (default: TEXT)
'model', -- LLM model to use (default: 'gpt-4o')
'batch_size', -- Batch size to use when sending batch requests (default: 2)
'truncate', -- Whether to truncate the input message so it fits within the model's maximum context window (default: false)
'runtime', -- Runtime environment (default: 'openai')
'runtime_params', -- Runtime parameters (default: '{}' inferred from GUC variables)
'pk', -- Primary key column (default: 'id')
Expand Down Expand Up @@ -259,5 +260,5 @@ This will return a table with the following columns:
***Calling LLM Completion API***
```sql
SET lantern_extras.llm_token='xxxx';
SELECT llm_completion(query, [model, context, base_url, runtime]);
SELECT llm_completion(query, [model, context, truncate, base_url, runtime]);
```
Binary file not shown.
5 changes: 3 additions & 2 deletions lantern_extras/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ fn add_embedding_job<'a>(
if params == "{}" {
match runtime {
"openai" => {
params = get_openai_runtime_params("", "", 1536)?;
params = get_openai_runtime_params("", "", 1536, true)?;
}
"cohere" => {
params = get_cohere_runtime_params("search_document")?;
Expand Down Expand Up @@ -177,6 +177,7 @@ fn add_completion_job<'a>(
column_type: default!(&'a str, "'TEXT'"),
embedding_model: default!(&'a str, "'gpt-4o'"),
batch_size: default!(i32, -1),
truncate: default!(bool, "false"),
runtime: default!(&'a str, "'openai'"),
runtime_params: default!(&'a str, "'{}'"),
pk: default!(&'a str, "'id'"),
Expand All @@ -186,7 +187,7 @@ fn add_completion_job<'a>(
if params == "{}" {
match runtime {
"openai" => {
params = get_openai_runtime_params("", context, 0)?;
params = get_openai_runtime_params("", context, 0, truncate)?;
}
_ => anyhow::bail!("Runtime {runtime} does not support completion jobs"),
}
Expand Down
7 changes: 5 additions & 2 deletions lantern_extras/src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn get_openai_runtime_params(
base_url: &str,
context: &str,
dimensions: i32,
truncate: bool,
) -> Result<String, anyhow::Error> {
if OPENAI_TOKEN.get().is_none()
&& LLM_TOKEN.get().is_none()
Expand Down Expand Up @@ -93,6 +94,7 @@ pub fn get_openai_runtime_params(
azure_api_token,
azure_entra_token,
context,
truncate: Some(truncate),
})?;

Ok(params)
Expand Down Expand Up @@ -142,7 +144,7 @@ fn openai_embedding<'a>(
base_url: default!(&'a str, "''"),
dimensions: default!(i32, 1536),
) -> Result<Vec<f32>, anyhow::Error> {
let runtime_params = get_openai_runtime_params(base_url, "", dimensions)?;
let runtime_params = get_openai_runtime_params(base_url, "", dimensions, true)?;
let runtime = EmbeddingRuntime::new(
&Runtime::OpenAi,
Some(&(notice_fn as LoggerFn)),
Expand Down Expand Up @@ -231,10 +233,11 @@ fn llm_completion<'a>(
text: &'a str,
model_name: default!(&'a str, "'gpt-4o'"),
context: default!(&'a str, "''"),
truncate: default!(bool, "true"),
base_url: default!(&'a str, "''"),
runtime: default!(&'a str, "'openai'"),
) -> Result<String, anyhow::Error> {
let runtime_params = get_openai_runtime_params(base_url, context, 0)?;
let runtime_params = get_openai_runtime_params(base_url, context, 0, truncate)?;

let runtime = Runtime::try_from(runtime)?;
let embedding_runtime =
Expand Down

0 comments on commit 7b0568d

Please sign in to comment.