diff --git a/crates/http-api-bindings/src/completion/mistral.rs b/crates/http-api-bindings/src/completion/mistral.rs
index 9ac192f92810..4e076be8b5bc 100644
--- a/crates/http-api-bindings/src/completion/mistral.rs
+++ b/crates/http-api-bindings/src/completion/mistral.rs
@@ -5,6 +5,8 @@ use reqwest_eventsource::{Event, EventSource};
 use serde::{Deserialize, Serialize};
 use tabby_inference::{CompletionOptions, CompletionStream};
 
+use super::FIM_TOKEN;
+
 pub struct MistralFIMEngine {
     client: reqwest::Client,
     api_endpoint: String,
@@ -12,8 +14,14 @@ pub struct MistralFIMEngine {
     model_name: String,
 }
 
+const DEFAULT_API_ENDPOINT: &str = "https://api.mistral.ai";
+
 impl MistralFIMEngine {
-    pub fn create(api_endpoint: &str, api_key: Option<String>, model_name: Option<String>) -> Self {
+    pub fn create(
+        api_endpoint: Option<&str>,
+        api_key: Option<String>,
+        model_name: Option<String>,
+    ) -> Self {
         let client = reqwest::Client::new();
         let model_name = model_name.unwrap_or("codestral-latest".into());
         let api_key = api_key.expect("API key is required for mistral/completion");
@@ -21,7 +29,10 @@ impl MistralFIMEngine {
         Self {
             client,
             model_name,
-            api_endpoint: format!("{}/v1/fim/completions", api_endpoint),
+            api_endpoint: format!(
+                "{}/v1/fim/completions",
+                api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
+            ),
             api_key,
         }
     }
@@ -30,7 +41,7 @@ impl MistralFIMEngine {
 #[derive(Serialize)]
 struct FIMRequest {
     prompt: String,
-    suffix: String,
+    suffix: Option<String>,
     model: String,
     temperature: f32,
     max_tokens: i32,
@@ -57,10 +68,13 @@ struct FIMResponseDelta {
 #[async_trait]
 impl CompletionStream for MistralFIMEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        let parts: Vec<&str> = prompt.split("<|FIM|>").collect();
+        let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
         let request = FIMRequest {
             prompt: parts[0].to_owned(),
-            suffix: parts[1].to_owned(),
+            suffix: parts
+                .get(1)
+                .map(|x| x.to_string())
+                .filter(|x| !x.is_empty()),
             model: self.model_name.clone(),
             max_tokens: options.max_decoding_tokens,
             temperature: options.sampling_temperature,
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index c49b1db75a3e..9f8a698e9484 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -23,13 +23,9 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             Arc::new(engine)
         }
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
-
         "mistral/completion" => {
             let engine = MistralFIMEngine::create(
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
+                model.api_endpoint.as_deref(),
                 model.api_key.clone(),
                 model.model_name.clone(),
             );
@@ -46,7 +42,6 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             );
             Arc::new(engine)
         }
-
         unsupported_kind => panic!(
             "Unsupported model kind for http completion: {}",
             unsupported_kind
@@ -54,9 +49,12 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
     }
 }
 
+const FIM_TOKEN: &str = "<|FIM|>";
+const FIM_TEMPLATE: &str = "{prefix}<|FIM|>{suffix}";
+
 pub fn build_completion_prompt(model: &HttpModelConfig) -> (Option<String>, Option<String>) {
-    if model.kind == "mistral/completion" {
-        (Some("{prefix}<|FIM|>{suffix}".to_owned()), None)
+    if model.kind == "mistral/completion" || model.kind == "openai/completion" {
+        (Some(FIM_TEMPLATE.to_owned()), None)
     } else {
         (model.prompt_template.clone(), model.chat_template.clone())
     }
diff --git a/crates/http-api-bindings/src/completion/openai.rs b/crates/http-api-bindings/src/completion/openai.rs
index 3606d88a80ab..1445237b418a 100644
--- a/crates/http-api-bindings/src/completion/openai.rs
+++ b/crates/http-api-bindings/src/completion/openai.rs
@@ -5,6 +5,8 @@ use reqwest_eventsource::{Event, EventSource};
 use serde::{Deserialize, Serialize};
 use tabby_inference::{CompletionOptions, CompletionStream};
 
+use super::FIM_TOKEN;
+
 pub struct OpenAICompletionEngine {
     client: reqwest::Client,
     model_name: String,
@@ -14,7 +16,7 @@ pub struct OpenAICompletionEngine {
 
 impl OpenAICompletionEngine {
     pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
-        let model_name = model_name.unwrap();
+        let model_name = model_name.expect("model_name is required for openai/completion");
         let client = reqwest::Client::new();
 
         Self {
@@ -30,6 +32,7 @@ impl OpenAICompletionEngine {
 struct CompletionRequest {
     model: String,
     prompt: String,
+    suffix: Option<String>,
     max_tokens: i32,
     temperature: f32,
     stream: bool,
@@ -50,9 +53,14 @@ struct CompletionResponseChoice {
 #[async_trait]
 impl CompletionStream for OpenAICompletionEngine {
     async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        let parts = prompt.splitn(2, FIM_TOKEN).collect::<Vec<_>>();
         let request = CompletionRequest {
             model: self.model_name.clone(),
-            prompt: prompt.to_owned(),
+            prompt: parts[0].to_owned(),
+            suffix: parts
+                .get(1)
+                .map(|x| x.to_string())
+                .filter(|x| !x.is_empty()),
             max_tokens: options.max_decoding_tokens,
             temperature: options.sampling_temperature,
             stream: true,