-
Notifications
You must be signed in to change notification settings - Fork 283
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: groq integration * refactor: add example * refactor: message type not OpenAI chat completion compatible * chore: satisfy ci
- Loading branch information
1 parent
530a327
commit 1b71738
Showing
3 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
use std::env; | ||
|
||
use rig::{ | ||
completion::Prompt, | ||
providers::{self, groq::DEEPSEEK_R1_DISTILL_LLAMA_70B}, | ||
}; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Create OpenAI client | ||
let client = | ||
providers::groq::Client::new(&env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set")); | ||
|
||
// Create agent with a single context prompt | ||
let comedian_agent = client | ||
.agent(DEEPSEEK_R1_DISTILL_LLAMA_70B) | ||
.preamble("You are a comedian here to entertain the user using humour and jokes.") | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let response = comedian_agent.prompt("Entertain me!").await?; | ||
println!("{}", response); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
//! Groq API client and Rig integration | ||
//! | ||
//! # Example | ||
//! ``` | ||
//! use rig::providers::groq; | ||
//! | ||
//! let client = groq::Client::new("YOUR_API_KEY"); | ||
//! | ||
//! let gpt4o = client.completion_model(groq::GPT_4O); | ||
//! ``` | ||
use crate::{ | ||
agent::AgentBuilder, | ||
completion::{self, CompletionError, CompletionRequest}, | ||
extractor::ExtractorBuilder, | ||
json_utils, | ||
message::{self, MessageError}, | ||
providers::openai::ToolDefinition, | ||
OneOrMany, | ||
}; | ||
use schemars::JsonSchema; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::json; | ||
|
||
use super::openai::CompletionResponse; | ||
|
||
// ================================================================ | ||
// Main Groq Client | ||
// ================================================================ | ||
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1"; | ||
|
||
#[derive(Clone)] | ||
pub struct Client { | ||
base_url: String, | ||
http_client: reqwest::Client, | ||
} | ||
|
||
impl Client { | ||
/// Create a new Groq client with the given API key. | ||
pub fn new(api_key: &str) -> Self { | ||
Self::from_url(api_key, GROQ_API_BASE_URL) | ||
} | ||
|
||
/// Create a new Groq client with the given API key and base API URL. | ||
pub fn from_url(api_key: &str, base_url: &str) -> Self { | ||
Self { | ||
base_url: base_url.to_string(), | ||
http_client: reqwest::Client::builder() | ||
.default_headers({ | ||
let mut headers = reqwest::header::HeaderMap::new(); | ||
headers.insert( | ||
"Authorization", | ||
format!("Bearer {}", api_key) | ||
.parse() | ||
.expect("Bearer token should parse"), | ||
); | ||
headers | ||
}) | ||
.build() | ||
.expect("Groq reqwest client should build"), | ||
} | ||
} | ||
|
||
/// Create a new Groq client from the `GROQ_API_KEY` environment variable. | ||
/// Panics if the environment variable is not set. | ||
pub fn from_env() -> Self { | ||
let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set"); | ||
Self::new(&api_key) | ||
} | ||
|
||
fn post(&self, path: &str) -> reqwest::RequestBuilder { | ||
let url = format!("{}/{}", self.base_url, path).replace("//", "/"); | ||
self.http_client.post(url) | ||
} | ||
|
||
/// Create a completion model with the given name. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::groq::{Client, self}; | ||
/// | ||
/// // Initialize the Groq client | ||
/// let groq = Client::new("your-groq-api-key"); | ||
/// | ||
/// let gpt4 = groq.completion_model(groq::GPT_4); | ||
/// ``` | ||
pub fn completion_model(&self, model: &str) -> CompletionModel { | ||
CompletionModel::new(self.clone(), model) | ||
} | ||
|
||
/// Create an agent builder with the given completion model. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::groq::{Client, self}; | ||
/// | ||
/// // Initialize the Groq client | ||
/// let groq = Client::new("your-groq-api-key"); | ||
/// | ||
/// let agent = groq.agent(groq::GPT_4) | ||
/// .preamble("You are comedian AI with a mission to make people laugh.") | ||
/// .temperature(0.0) | ||
/// .build(); | ||
/// ``` | ||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> { | ||
AgentBuilder::new(self.completion_model(model)) | ||
} | ||
|
||
/// Create an extractor builder with the given completion model. | ||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>( | ||
&self, | ||
model: &str, | ||
) -> ExtractorBuilder<T, CompletionModel> { | ||
ExtractorBuilder::new(self.completion_model(model)) | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct ApiErrorResponse { | ||
message: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
#[serde(untagged)] | ||
enum ApiResponse<T> { | ||
Ok(T), | ||
Err(ApiErrorResponse), | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct Message { | ||
pub role: String, | ||
pub content: Option<String>, | ||
} | ||
|
||
impl TryFrom<Message> for message::Message { | ||
type Error = message::MessageError; | ||
|
||
fn try_from(message: Message) -> Result<Self, Self::Error> { | ||
match message.role.as_str() { | ||
"user" => Ok(Self::User { | ||
content: OneOrMany::one( | ||
message | ||
.content | ||
.map(|content| message::UserContent::text(&content)) | ||
.ok_or_else(|| { | ||
message::MessageError::ConversionError("Empty user message".to_string()) | ||
})?, | ||
), | ||
}), | ||
"assistant" => Ok(Self::Assistant { | ||
content: OneOrMany::one( | ||
message | ||
.content | ||
.map(|content| message::AssistantContent::text(&content)) | ||
.ok_or_else(|| { | ||
message::MessageError::ConversionError( | ||
"Empty assistant message".to_string(), | ||
) | ||
})?, | ||
), | ||
}), | ||
_ => Err(message::MessageError::ConversionError(format!( | ||
"Unknown role: {}", | ||
message.role | ||
))), | ||
} | ||
} | ||
} | ||
|
||
impl TryFrom<message::Message> for Message { | ||
type Error = message::MessageError; | ||
|
||
fn try_from(message: message::Message) -> Result<Self, Self::Error> { | ||
match message { | ||
message::Message::User { content } => Ok(Self { | ||
role: "user".to_string(), | ||
content: content.iter().find_map(|c| match c { | ||
message::UserContent::Text(text) => Some(text.text.clone()), | ||
_ => None, | ||
}), | ||
}), | ||
message::Message::Assistant { content } => { | ||
let mut text_content: Option<String> = None; | ||
|
||
for c in content.iter() { | ||
match c { | ||
message::AssistantContent::Text(text) => { | ||
text_content = Some( | ||
text_content | ||
.map(|mut existing| { | ||
existing.push('\n'); | ||
existing.push_str(&text.text); | ||
existing | ||
}) | ||
.unwrap_or_else(|| text.text.clone()), | ||
); | ||
} | ||
message::AssistantContent::ToolCall(_tool_call) => { | ||
return Err(MessageError::ConversionError( | ||
"Tool calls do not exist on this message".into(), | ||
)) | ||
} | ||
} | ||
} | ||
|
||
Ok(Self { | ||
role: "assistant".to_string(), | ||
content: text_content, | ||
}) | ||
} | ||
} | ||
} | ||
} | ||
|
||
// ================================================================ | ||
// Groq Completion API | ||
// ================================================================ | ||
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion. | ||
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b"; | ||
/// The `gemma2-9b-it` model. Used for chat completion. | ||
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it"; | ||
/// The `llama-3.1-8b-instant` model. Used for chat completion. | ||
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant"; | ||
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion. | ||
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview"; | ||
/// The `llama-3.2-1b-preview` model. Used for chat completion. | ||
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview"; | ||
/// The `llama-3.2-3b-preview` model. Used for chat completion. | ||
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview"; | ||
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion. | ||
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview"; | ||
/// The `llama-3.2-70b-specdec` model. Used for chat completion. | ||
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec"; | ||
/// The `llama-3.2-70b-versatile` model. Used for chat completion. | ||
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile"; | ||
/// The `llama-guard-3-8b` model. Used for chat completion. | ||
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b"; | ||
/// The `llama3-70b-8192` model. Used for chat completion. | ||
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192"; | ||
/// The `llama3-8b-8192` model. Used for chat completion. | ||
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192"; | ||
/// The `mixtral-8x7b-32768` model. Used for chat completion. | ||
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768"; | ||
|
||
#[derive(Clone)] | ||
pub struct CompletionModel { | ||
client: Client, | ||
/// Name of the model (e.g.: deepseek-r1-distill-llama-70b) | ||
pub model: String, | ||
} | ||
|
||
impl CompletionModel { | ||
pub fn new(client: Client, model: &str) -> Self { | ||
Self { | ||
client, | ||
model: model.to_string(), | ||
} | ||
} | ||
} | ||
|
||
impl completion::CompletionModel for CompletionModel { | ||
type Response = CompletionResponse; | ||
|
||
#[cfg_attr(feature = "worker", worker::send)] | ||
async fn completion( | ||
&self, | ||
completion_request: CompletionRequest, | ||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> { | ||
// Add preamble to chat history (if available) | ||
let mut full_history: Vec<Message> = match &completion_request.preamble { | ||
Some(preamble) => vec![Message { | ||
role: "system".to_string(), | ||
content: Some(preamble.to_string()), | ||
}], | ||
None => vec![], | ||
}; | ||
|
||
// Convert prompt to user message | ||
let prompt: Message = completion_request.prompt_with_context().try_into()?; | ||
|
||
// Convert existing chat history | ||
let chat_history: Vec<Message> = completion_request | ||
.chat_history | ||
.into_iter() | ||
.map(|message| message.try_into()) | ||
.collect::<Result<Vec<Message>, _>>()?; | ||
|
||
// Combine all messages into a single history | ||
full_history.extend(chat_history); | ||
full_history.push(prompt); | ||
|
||
let request = if completion_request.tools.is_empty() { | ||
json!({ | ||
"model": self.model, | ||
"messages": full_history, | ||
"temperature": completion_request.temperature, | ||
}) | ||
} else { | ||
json!({ | ||
"model": self.model, | ||
"messages": full_history, | ||
"temperature": completion_request.temperature, | ||
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(), | ||
"tool_choice": "auto", | ||
}) | ||
}; | ||
|
||
let response = self | ||
.client | ||
.post("/chat/completions") | ||
.json( | ||
&if let Some(params) = completion_request.additional_params { | ||
json_utils::merge(request, params) | ||
} else { | ||
request | ||
}, | ||
) | ||
.send() | ||
.await?; | ||
|
||
if response.status().is_success() { | ||
match response.json::<ApiResponse<CompletionResponse>>().await? { | ||
ApiResponse::Ok(response) => { | ||
tracing::info!(target: "rig", | ||
"groq completion token usage: {:?}", | ||
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) | ||
); | ||
response.try_into() | ||
} | ||
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), | ||
} | ||
} else { | ||
Err(CompletionError::ProviderError(response.text().await?)) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters