Skip to content

Commit

Permalink
feat(tabby): support voyage embedding api (#2355)
Browse files Browse the repository at this point in the history
* feat(tabby): support voyage embedding api

Signed-off-by: TennyZhuang <[email protected]>

* [autofix.ci] apply automated fixes

* exit if necessary args are missing

Signed-off-by: TennyZhuang <[email protected]>

---------

Signed-off-by: TennyZhuang <[email protected]>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
TennyZhuang and autofix-ci[bot] authored Jun 6, 2024
1 parent 2d8d746 commit beb7d90
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 10 deletions.
18 changes: 16 additions & 2 deletions crates/http-api-bindings/src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod llama;
mod openai;
mod voyage;

use core::panic;
use std::sync::Arc;
Expand All @@ -8,7 +9,7 @@ use llama::LlamaCppEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::Embedding;

use self::openai::OpenAIEmbeddingEngine;
use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};

pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
match config.kind.as_str() {
Expand All @@ -25,7 +26,20 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
Arc::new(engine)
}
"ollama/embedding" => ollama_api_bindings::create_embedding(config).await,

"voyage/embedding" => {
let engine = VoyageEmbeddingEngine::create(
&config.api_endpoint,
config
.model_name
.as_deref()
.expect("model_name must be set for voyage/embedding"),
config
.api_key
.clone()
.expect("api_key must be set for voyage/embedding"),
);
Arc::new(engine)
}
unsupported_kind => panic!(
"Unsupported kind for http embedding model: {}",
unsupported_kind
Expand Down
98 changes: 98 additions & 0 deletions crates/http-api-bindings/src/embedding/voyage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use anyhow::Context;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;

/// Base URL used by `VoyageEmbeddingEngine::create` when it is given an empty endpoint.
const DEFAULT_VOYAGE_API_ENDPOINT: &str = "https://api.voyageai.com";

/// Embedding engine that talks to a Voyage-style HTTP embeddings API.
pub struct VoyageEmbeddingEngine {
    /// HTTP client used to send every embedding request.
    client: Client,
    /// Full URL of the embeddings route (base endpoint plus `/v1/embeddings`).
    api_endpoint: String,
    /// Key sent as a bearer token with each request.
    api_key: String,
    /// Model identifier placed in the `model` field of each request.
    model_name: String,
}

impl VoyageEmbeddingEngine {
    /// Builds an engine that posts to `<api_endpoint>/v1/embeddings`.
    ///
    /// * `api_endpoint` - base URL of the service. An empty value (or one made
    ///   up solely of `/`) selects `DEFAULT_VOYAGE_API_ENDPOINT`.
    /// * `model_name` - model identifier sent with every request.
    /// * `api_key` - key sent as a bearer token with every request.
    ///
    /// Trailing slashes on `api_endpoint` are ignored, so `https://host` and
    /// `https://host/` resolve to the same request URL.
    pub fn create(api_endpoint: &str, model_name: &str, api_key: String) -> Self {
        // Strip trailing slashes before joining the route; otherwise a
        // configured "https://host/" would yield "https://host//v1/embeddings".
        let base = match api_endpoint.trim_end_matches('/') {
            "" => DEFAULT_VOYAGE_API_ENDPOINT,
            trimmed => trimmed,
        };

        Self {
            client: Client::new(),
            api_endpoint: format!("{}/v1/embeddings", base),
            api_key,
            model_name: model_name.to_owned(),
        }
    }
}

/// JSON body sent to the embeddings route.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Texts to embed; `embed` always sends exactly one.
    input: Vec<String>,
    /// Model identifier the service should use.
    model: String,
}

/// Subset of the embeddings response that this client reads.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// Returned embedding entries; only the first one is consumed.
    data: Vec<EmbeddingData>,
}

/// One entry of the response's `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector itself.
    embedding: Vec<f32>,
}

#[async_trait]
impl Embedding for VoyageEmbeddingEngine {
    /// Requests an embedding for `prompt` and returns the first vector in the
    /// response.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be sent, when the service
    /// answers with any non-success status (the status and response body are
    /// included in the message), when the body is not the expected JSON, or
    /// when the response contains no embedding entries.
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
        let request = EmbeddingRequest {
            input: vec![prompt.to_owned()],
            model: self.model_name.clone(),
        };

        // `.json()` already sets the JSON content type, so no explicit
        // content-type header is added here.
        let response = self
            .client
            .post(&self.api_endpoint)
            .json(&request)
            .bearer_auth(&self.api_key)
            .send()
            .await
            .context("Failed to send embedding request")?;

        // Reject every non-2xx status, not just 5xx: an invalid key (401) or a
        // rate limit (429) would otherwise surface as a misleading JSON parse
        // failure further down.
        let status = response.status();
        if !status.is_success() {
            // Best effort: still report the status if the body cannot be read.
            let error = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "Error from server ({}): {}",
                status,
                error
            ));
        }

        let response_body = response
            .json::<EmbeddingResponse>()
            .await
            .context("Failed to parse response body")?;

        // Exactly one input was sent, so only the first entry is relevant.
        response_body
            .data
            .into_iter()
            .next()
            .map(|data| data.embedding)
            .ok_or_else(|| anyhow::anyhow!("No embedding data found"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Live round trip against the Voyage API; ignored unless requested explicitly.
    /// The VOYAGE_API_KEY environment variable must be set before running it.
    #[tokio::test]
    #[ignore]
    async fn test_voyage_embedding() {
        let key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
        let voyage =
            VoyageEmbeddingEngine::create(DEFAULT_VOYAGE_API_ENDPOINT, "voyage-code-2", key);

        let vector = voyage.embed("Hello, world!").await.unwrap();

        assert_eq!(vector.len(), 1536);
    }
}
16 changes: 8 additions & 8 deletions crates/tabby-common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
terminal::{HeaderFormat, InfoMessage},
};

#[derive(Serialize, Deserialize, Default, Clone)]
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct Config {
#[serde(default)]
pub repositories: Vec<RepositoryConfig>,
Expand Down Expand Up @@ -124,7 +124,7 @@ fn sanitize_name(s: &str) -> String {
sanitized.into_iter().collect()
}

#[derive(Serialize, Deserialize, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ServerConfig {
/// The timeout in seconds for the /v1/completion api.
pub completion_timeout: u64,
Expand All @@ -138,21 +138,21 @@ impl Default for ServerConfig {
}
}

/// Optional model configurations, one slot each for completion, chat and embedding.
#[derive(Serialize, Deserialize, Default, Clone)]
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ModelConfigGroup {
pub completion: Option<ModelConfig>,
pub chat: Option<ModelConfig>,
pub embedding: Option<ModelConfig>,
}

/// Configuration of a single model: either the `Http` or the `Local` variant.
/// Variant names are (de)serialized in snake_case.
#[derive(Serialize, Deserialize, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ModelConfig {
Http(HttpModelConfig),
Local(LocalModelConfig),
}

#[derive(Serialize, Deserialize, Builder, Clone)]
#[derive(Serialize, Deserialize, Builder, Debug, Clone)]
pub struct HttpModelConfig {
/// The kind of model, we have three groups of models:
/// 1. Completion API [CompletionStream](tabby_inference::CompletionStream)
Expand Down Expand Up @@ -181,7 +181,7 @@ pub struct HttpModelConfig {
pub chat_template: Option<String>,
}

#[derive(Serialize, Deserialize, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LocalModelConfig {
pub model_id: String,

Expand All @@ -200,12 +200,12 @@ fn default_num_gpu_layers() -> u16 {
9999
}

/// Experimental settings; holds the optional doc index configuration.
#[derive(Serialize, Deserialize, Default, Clone)]
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct ExperimentalConfig {
pub doc: Option<DocIndexConfig>,
}

/// Doc index settings: the list of start URLs.
/// NOTE(review): presumably these are the pages doc indexing begins from — confirm against the indexer.
#[derive(Serialize, Deserialize, Default, Clone)]
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub struct DocIndexConfig {
pub start_urls: Vec<String>,
}
Expand Down

0 comments on commit beb7d90

Please sign in to comment.