feat: handle multiple tool calls #224

Open · wants to merge 1 commit into base: main
113 changes: 0 additions & 113 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 5 additions & 5 deletions rig-core/Cargo.toml
@@ -26,14 +26,11 @@ thiserror = "1.0.61"
 rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true }
 glob = "0.3.1"
 lopdf = { version = "0.34.0", optional = true }
-rayon = { version = "1.10.0", optional = true}
+rayon = { version = "1.10.0", optional = true }
 worker = { version = "0.5", optional = true }

 [dev-dependencies]
-anyhow = "1.0.75"
-assert_fs = "1.1.2"
-tokio = { version = "1.34.0", features = ["full"] }
-tracing-subscriber = "0.3.18"
+tracing-subscriber = { version = "0.3.18" }
 tokio-test = "0.4.4"

 [features]
@@ -66,3 +63,6 @@ required-features = ["derive"]
 [[example]]
 name = "xai_embeddings"
 required-features = ["derive"]
+anyhow = "1.0.75"
+assert_fs = "1.1.2"
+tokio = { version = "1.34.0", features = ["full"] }
146 changes: 146 additions & 0 deletions rig-core/examples/local_agent_with_tools.rs
@@ -0,0 +1,146 @@
use anyhow::Result;
use rig::{
    completion::{Chat, Message, Prompt, ToolDefinition},
    providers,
    tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::io::Write;
use tracing::{info_span, Instrument};

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Adding {} and {}", args.x, args.y);
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Subtracting {} from {}", args.y, args.x);
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Create local client
    let local = providers::openai::Client::from_url("", "http://192.168.0.10:11434/v1");

    let span = info_span!("calculator_agent");

    // Create agent with a single context prompt and two tools
    let calculator_agent = local
        .agent("c4ai-command-r7b-12-2024-abliterated")
        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
        .tool(Adder)
        .tool(Subtract)
        .max_tokens(1024)
        .build();

    // Initialize chat history
    let mut chat_history = Vec::new();
    println!("Calculator Agent: Ready to help with calculations! (Type 'quit' to exit)");

    loop {
        print!("\nYou: ");
        // Flush so the prompt is shown before blocking on input
        std::io::stdout().flush()?;
        let mut input = String::new();
        std::io::stdin().read_line(&mut input)?;
        let input = input.trim();

        if input.to_lowercase() == "quit" {
            break;
        }

        // Add user message to history
        chat_history.push(Message {
            role: "user".into(),
            content: input.into(),
        });

        // Get response from agent
        let response = calculator_agent
            .chat(input, chat_history.clone())
            .instrument(span.clone())
            .await?;

        // Add assistant's response to history
        chat_history.push(Message {
            role: "assistant".into(),
            content: response.clone(),
        });

        println!("Calculator Agent: {}", response);
    }

    println!("\nGoodbye!");
    Ok(())
}
13 changes: 13 additions & 0 deletions rig-core/src/agent.rs
@@ -268,6 +268,19 @@ impl<M: CompletionModel> Chat for Agent<M> {
                 choice: ModelChoice::ToolCall(toolname, _, args),
                 ..
             } => Ok(self.tools.call(&toolname, args.to_string()).await?),
+            CompletionResponse {
+                choice: ModelChoice::MultipleToolCalls(tool_calls),
+                ..
+            } => {
+                let mut results = Vec::new();
+                for tool_call in tool_calls {
+                    if let ModelChoice::ToolCall(toolname, _, args) = tool_call {
+                        let result = self.tools.call(&toolname, args.to_string()).await?;
+                        results.push(result);
+                    }
+                }
+                Ok(results.join("\n"))
+            }
         }
     }
 }
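With this arm in place, a response containing several tool calls is resolved by invoking each tool in turn and joining the individual outputs with newlines. A minimal caller-side sketch, reusing the calculator agent from the new example file; the prompt and the shown output are illustrative assumptions, not taken from this PR:

// Illustrative only: if the model answers with two tool calls (add and
// subtract), `chat` now returns both results joined by a newline.
let response = calculator_agent
    .chat("What is 2 + 3, and what is 10 - 4?", Vec::new())
    .await?;
// `response` would then contain something like:
// 5
// 6
println!("{response}");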
2 changes: 2 additions & 0 deletions rig-core/src/completion.rs
@@ -223,6 +223,8 @@ pub enum ModelChoice {
     /// Represents a completion response as a tool call of the form
     /// `ToolCall(function_name, id, function_params)`.
     ToolCall(String, String, serde_json::Value),
+    /// Represents a completion response with multiple tool calls
+    MultipleToolCalls(Vec<ModelChoice>),
 }

 /// Trait defining a completion model that can be used to generate completion responses.
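This hunk only adds the variant and the agent-side handling shown earlier; how a provider populates it is not part of this excerpt. As a hedged sketch of the intended shape (assumed helper and field names, written as if inside a provider module where `ModelChoice` is in scope):

// Hypothetical provider-side mapping, not taken from this PR: collapse a list of
// parsed (name, id, arguments) tool calls into a single ModelChoice. The plain
// text-message and empty cases are omitted from this sketch.
fn to_model_choice(mut parsed_calls: Vec<(String, String, serde_json::Value)>) -> ModelChoice {
    if parsed_calls.len() == 1 {
        let (name, id, arguments) = parsed_calls.remove(0);
        ModelChoice::ToolCall(name, id, arguments)
    } else {
        ModelChoice::MultipleToolCalls(
            parsed_calls
                .into_iter()
                .map(|(name, id, arguments)| ModelChoice::ToolCall(name, id, arguments))
                .collect(),
        )
    }
}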