diff --git a/rig-core/examples/agent_with_groq.rs b/rig-core/examples/agent_with_groq.rs new file mode 100644 index 00000000..7d55d833 --- /dev/null +++ b/rig-core/examples/agent_with_groq.rs @@ -0,0 +1,25 @@ +use std::env; + +use rig::{ + completion::Prompt, + providers::{self, groq::DEEPSEEK_R1_DISTILL_LLAMA_70B}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create Groq client + let client = + providers::groq::Client::new(&env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set")); + + // Create agent with a single context prompt + let comedian_agent = client + .agent(DEEPSEEK_R1_DISTILL_LLAMA_70B) + .preamble("You are a comedian here to entertain the user using humour and jokes.") + .build(); + + // Prompt the agent and print the response + let response = comedian_agent.prompt("Entertain me!").await?; + println!("{}", response); + + Ok(()) +} diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs new file mode 100644 index 00000000..e1b5574b --- /dev/null +++ b/rig-core/src/providers/groq.rs @@ -0,0 +1,336 @@ +//! Groq API client and Rig integration +//! +//! # Example +//! ``` +//! use rig::providers::groq; +//! +//! let client = groq::Client::new("YOUR_API_KEY"); +//! +//! let model = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B); +//! 
``` +use crate::{ + agent::AgentBuilder, + completion::{self, CompletionError, CompletionRequest}, + extractor::ExtractorBuilder, + json_utils, + message::{self, MessageError}, + providers::openai::ToolDefinition, + OneOrMany, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use super::openai::CompletionResponse; + +// ================================================================ +// Main Groq Client +// ================================================================ +const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1"; + +#[derive(Clone)] +pub struct Client { + base_url: String, + http_client: reqwest::Client, +} + +impl Client { + /// Create a new Groq client with the given API key. + pub fn new(api_key: &str) -> Self { + Self::from_url(api_key, GROQ_API_BASE_URL) + } + + /// Create a new Groq client with the given API key and base API URL. + pub fn from_url(api_key: &str, base_url: &str) -> Self { + Self { + base_url: base_url.to_string(), + http_client: reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", api_key) + .parse() + .expect("Bearer token should parse"), + ); + headers + }) + .build() + .expect("Groq reqwest client should build"), + } + } + + /// Create a new Groq client from the `GROQ_API_KEY` environment variable. + /// Panics if the environment variable is not set. + pub fn from_env() -> Self { + let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set"); + Self::new(&api_key) + } + + fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client.post(url) + } + + /// Create a completion model with the given name. 
+ /// + /// # Example + /// ``` + /// use rig::providers::groq::{Client, self}; + /// + /// // Initialize the Groq client + /// let groq = Client::new("your-groq-api-key"); + /// + /// let model = groq.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B); + /// ``` + pub fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.clone(), model) + } + + /// Create an agent builder with the given completion model. + /// + /// # Example + /// ``` + /// use rig::providers::groq::{Client, self}; + /// + /// // Initialize the Groq client + /// let groq = Client::new("your-groq-api-key"); + /// + /// let agent = groq.agent(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B) + /// .preamble("You are comedian AI with a mission to make people laugh.") + /// .temperature(0.0) + /// .build(); + /// ``` + pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> { + AgentBuilder::new(self.completion_model(model)) + } + + /// Create an extractor builder with the given completion model. + pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>( + &self, + model: &str, + ) -> ExtractorBuilder<T, CompletionModel> { + ExtractorBuilder::new(self.completion_model(model)) + } +} + +#[derive(Debug, Deserialize)] +struct ApiErrorResponse { + message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum ApiResponse<T> { + Ok(T), + Err(ApiErrorResponse), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: Option<String>, +} + +impl TryFrom<Message> for message::Message { + type Error = message::MessageError; + + fn try_from(message: Message) -> Result<Self, Self::Error> { + match message.role.as_str() { + "user" => Ok(Self::User { + content: OneOrMany::one( + message + .content + .map(|content| message::UserContent::text(&content)) + .ok_or_else(|| { + message::MessageError::ConversionError("Empty user message".to_string()) + })?, + ), + }), + "assistant" => Ok(Self::Assistant { + content: OneOrMany::one( + message + .content + .map(|content| message::AssistantContent::text(&content)) + .ok_or_else(|| { + 
+ message::MessageError::ConversionError( + "Empty assistant message".to_string(), + ) + })?, + ), + }), + _ => Err(message::MessageError::ConversionError(format!( + "Unknown role: {}", + message.role + ))), + } + } +} + +impl TryFrom<message::Message> for Message { + type Error = message::MessageError; + + fn try_from(message: message::Message) -> Result<Self, Self::Error> { + match message { + message::Message::User { content } => Ok(Self { + role: "user".to_string(), + content: content.iter().find_map(|c| match c { + message::UserContent::Text(text) => Some(text.text.clone()), + _ => None, + }), + }), + message::Message::Assistant { content } => { + let mut text_content: Option<String> = None; + + for c in content.iter() { + match c { + message::AssistantContent::Text(text) => { + text_content = Some( + text_content + .map(|mut existing| { + existing.push('\n'); + existing.push_str(&text.text); + existing + }) + .unwrap_or_else(|| text.text.clone()), + ); + } + message::AssistantContent::ToolCall(_tool_call) => { + return Err(MessageError::ConversionError( + "Tool calls do not exist on this message".into(), + )) + } + } + } + + Ok(Self { + role: "assistant".to_string(), + content: text_content, + }) + } + } + } +} + +// ================================================================ +// Groq Completion API +// ================================================================ +/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion. +pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b"; +/// The `gemma2-9b-it` model. Used for chat completion. +pub const GEMMA2_9B_IT: &str = "gemma2-9b-it"; +/// The `llama-3.1-8b-instant` model. Used for chat completion. +pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant"; +/// The `llama-3.2-11b-vision-preview` model. Used for chat completion. +pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview"; +/// The `llama-3.2-1b-preview` model. Used for chat completion. 
+pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview"; +/// The `llama-3.2-3b-preview` model. Used for chat completion. +pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview"; +/// The `llama-3.2-90b-vision-preview` model. Used for chat completion. +pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview"; +/// The `llama-3.2-70b-specdec` model. Used for chat completion. +pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec"; +/// The `llama-3.2-70b-versatile` model. Used for chat completion. +pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile"; +/// The `llama-guard-3-8b` model. Used for chat completion. +pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b"; +/// The `llama3-70b-8192` model. Used for chat completion. +pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192"; +/// The `llama3-8b-8192` model. Used for chat completion. +pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192"; +/// The `mixtral-8x7b-32768` model. Used for chat completion. 
+pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768"; + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + /// Name of the model (e.g.: deepseek-r1-distill-llama-70b) + pub model: String, +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } +} + +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: CompletionRequest, + ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> { + // Add preamble to chat history (if available) + let mut full_history: Vec<Message> = match &completion_request.preamble { + Some(preamble) => vec![Message { + role: "system".to_string(), + content: Some(preamble.to_string()), + }], + None => vec![], + }; + + // Convert prompt to user message + let prompt: Message = completion_request.prompt_with_context().try_into()?; + + // Convert existing chat history + let chat_history: Vec<Message> = completion_request + .chat_history + .into_iter() + .map(|message| message.try_into()) + .collect::<Result<Vec<Message>, _>>()?; + + // Combine all messages into a single history + full_history.extend(chat_history); + full_history.push(prompt); + + let request = if completion_request.tools.is_empty() { + json!({ + "model": self.model, + "messages": full_history, + "temperature": completion_request.temperature, + }) + } else { + json!({ + "model": self.model, + "messages": full_history, + "temperature": completion_request.temperature, + "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(), + "tool_choice": "auto", + }) + }; + + let response = self + .client + .post("/chat/completions") + .json( + &if let Some(params) = completion_request.additional_params { + json_utils::merge(request, params) + } else { + request + }, + ) + .send() + .await?; + + if response.status().is_success() { + match response.json::<ApiResponse<CompletionResponse>>().await? 
{ + ApiResponse::Ok(response) => { + tracing::info!(target: "rig", + "groq completion token usage: {:?}", + response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) + ); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) + } + } +} diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 138f0a59..2308cdc5 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -50,6 +50,7 @@ pub mod cohere; pub mod deepseek; pub mod galadriel; pub mod gemini; +pub mod groq; pub mod hyperbolic; pub mod moonshot; pub mod openai;