Skip to content

Commit

Permalink
feat: groq integration (#263)
Browse files Browse the repository at this point in the history
* feat: groq integration

* refactor: add example

* refactor: message type not OpenAI chat completion compatible

* chore: satisfy ci
  • Loading branch information
joshua-mo-143 authored Feb 12, 2025
1 parent 530a327 commit 1b71738
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 0 deletions.
25 changes: 25 additions & 0 deletions rig-core/examples/agent_with_groq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::env;

use rig::{
completion::Prompt,
providers::{self, groq::DEEPSEEK_R1_DISTILL_LLAMA_70B},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let client =
providers::groq::Client::new(&env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set"));

// Create agent with a single context prompt
let comedian_agent = client
.agent(DEEPSEEK_R1_DISTILL_LLAMA_70B)
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

// Prompt the agent and print the response
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);

Ok(())
}
336 changes: 336 additions & 0 deletions rig-core/src/providers/groq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let gpt4o = client.completion_model(groq::GPT_4O);
//! ```
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
message::{self, MessageError},
providers::openai::ToolDefinition,
OneOrMany,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

use super::openai::CompletionResponse;

// ================================================================
// Main Groq Client
// ================================================================
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
/// Create a new Groq client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, GROQ_API_BASE_URL)
}

/// Create a new Groq client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("Groq reqwest client should build"),
}
}

/// Create a new Groq client from the `GROQ_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
Self::new(&api_key)
}

fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}

/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, self};
///
/// // Initialize the Groq client
/// let groq = Client::new("your-groq-api-key");
///
/// let gpt4 = groq.completion_model(groq::GPT_4);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, self};
///
/// // Initialize the Groq client
/// let groq = Client::new("your-groq-api-key");
///
/// let agent = groq.agent(groq::GPT_4)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
}

impl TryFrom<Message> for message::Message {
type Error = message::MessageError;

fn try_from(message: Message) -> Result<Self, Self::Error> {
match message.role.as_str() {
"user" => Ok(Self::User {
content: OneOrMany::one(
message
.content
.map(|content| message::UserContent::text(&content))
.ok_or_else(|| {
message::MessageError::ConversionError("Empty user message".to_string())
})?,
),
}),
"assistant" => Ok(Self::Assistant {
content: OneOrMany::one(
message
.content
.map(|content| message::AssistantContent::text(&content))
.ok_or_else(|| {
message::MessageError::ConversionError(
"Empty assistant message".to_string(),
)
})?,
),
}),
_ => Err(message::MessageError::ConversionError(format!(
"Unknown role: {}",
message.role
))),
}
}
}

impl TryFrom<message::Message> for Message {
type Error = message::MessageError;

fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => Ok(Self {
role: "user".to_string(),
content: content.iter().find_map(|c| match c {
message::UserContent::Text(text) => Some(text.text.clone()),
_ => None,
}),
}),
message::Message::Assistant { content } => {
let mut text_content: Option<String> = None;

for c in content.iter() {
match c {
message::AssistantContent::Text(text) => {
text_content = Some(
text_content
.map(|mut existing| {
existing.push('\n');
existing.push_str(&text.text);
existing
})
.unwrap_or_else(|| text.text.clone()),
);
}
message::AssistantContent::ToolCall(_tool_call) => {
return Err(MessageError::ConversionError(
"Tool calls do not exist on this message".into(),
))
}
}
}

Ok(Self {
role: "assistant".to_string(),
content: text_content,
})
}
}
}
}

// ================================================================
// Groq Completion API
// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

#[derive(Clone)]
pub struct CompletionModel {
client: Client,
/// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
pub model: String,
}

impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}

impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;

#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message {
role: "system".to_string(),
content: Some(preamble.to_string()),
}],
None => vec![],
};

// Convert prompt to user message
let prompt: Message = completion_request.prompt_with_context().try_into()?;

// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Message>, _>>()?;

// Combine all messages into a single history
full_history.extend(chat_history);
full_history.push(prompt);

let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.send()
.await?;

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"groq completion token usage: {:?}",
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub mod cohere;
pub mod deepseek;
pub mod galadriel;
pub mod gemini;
pub mod groq;
pub mod hyperbolic;
pub mod moonshot;
pub mod openai;
Expand Down

0 comments on commit 1b71738

Please sign in to comment.