Skip to content

Commit

Permalink
fix: deepseek client auth (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlos-verdes authored Feb 7, 2025
1 parent 0ef7bfd commit df3231b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
5 changes: 5 additions & 0 deletions rig-core/examples/agent_with_deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

let client = providers::deepseek::Client::from_env();
let agent = client
.agent("deepseek-chat")
Expand Down
37 changes: 24 additions & 13 deletions rig-core/src/providers/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
use reqwest::Client as HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};

use super::openai::AssistantContent;

Expand All @@ -30,18 +30,13 @@ const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
#[derive(Clone)]
pub struct Client {
pub base_url: String,
pub api_key: String,
http_client: HttpClient,
}

impl Client {
// Create a new DeepSeek client from an API key.
pub fn new(api_key: &str) -> Self {
Self {
base_url: DEEPSEEK_API_BASE_URL.to_string(),
api_key: api_key.to_string(),
http_client: HttpClient::new(),
}
Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
}

// If you prefer the environment variable approach:
Expand All @@ -55,8 +50,19 @@ impl Client {
// Possibly configure a custom HTTP client here if needed.
Self {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
http_client: HttpClient::new(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenAI reqwest client should build"),
}
}

Expand Down Expand Up @@ -259,15 +265,20 @@ impl CompletionModel for DeepSeekCompletionModel {
.await?;

if response.status().is_success() {
let t = response.text().await?;
tracing::debug!(target: "rig", "DeepSeek completion error: {}", t);
let t: Value = response.json().await?;
tracing::debug!(
target: "rig",
"DeepSeek completion success: {}",
serde_json::to_string_pretty(&t).unwrap());

match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
match serde_json::from_value::<ApiResponse<CompletionResponse>>(t)? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
let t = response.text().await?;
tracing::debug!(target: "rig", "DeepSeek completion error: {}", t);
Err(CompletionError::ProviderError(t))
}
}
}
Expand Down

0 comments on commit df3231b

Please sign in to comment.