diff --git a/rig-core/examples/agent_with_deepseek.rs b/rig-core/examples/agent_with_deepseek.rs index 0a609c7b..94fe7a20 100644 --- a/rig-core/examples/agent_with_deepseek.rs +++ b/rig-core/examples/agent_with_deepseek.rs @@ -8,6 +8,11 @@ use serde_json::json; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_target(false) + .init(); + let client = providers::deepseek::Client::from_env(); let agent = client .agent("deepseek-chat") diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index e3e8d64b..f32a61a0 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -18,7 +18,7 @@ use crate::{ use reqwest::Client as HttpClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{json, Value}; use super::openai::AssistantContent; @@ -30,18 +30,13 @@ const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com"; #[derive(Clone)] pub struct Client { pub base_url: String, - pub api_key: String, http_client: HttpClient, } impl Client { // Create a new DeepSeek client from an API key. pub fn new(api_key: &str) -> Self { - Self { - base_url: DEEPSEEK_API_BASE_URL.to_string(), - api_key: api_key.to_string(), - http_client: HttpClient::new(), - } + Self::from_url(api_key, DEEPSEEK_API_BASE_URL) } // If you prefer the environment variable approach: @@ -55,8 +50,19 @@ impl Client { // Possibly configure a custom HTTP client here if needed. Self { base_url: base_url.to_string(), - api_key: api_key.to_string(), - http_client: HttpClient::new(), + http_client: reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "Authorization", + format!("Bearer {}", api_key) + .parse() + .expect("Bearer token should parse"), + ); + headers + }) + .build() + .expect("OpenAI reqwest client should build"), } } @@ -259,15 +265,20 @@ impl CompletionModel for DeepSeekCompletionModel { .await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "DeepSeek completion error: {}", t); + let t: Value = response.json().await?; + tracing::debug!( + target: "rig", + "DeepSeek completion success: {}", + serde_json::to_string_pretty(&t).unwrap()); - match serde_json::from_str::>(&t)? { + match serde_json::from_value::>(t)? { ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let t = response.text().await?; + tracing::debug!(target: "rig", "DeepSeek completion error: {}", t); + Err(CompletionError::ProviderError(t)) } } }