Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions rig-core/examples/agent_with_agent_tool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use anyhow::Result;
use rig::prelude::*;
use rig::{
completion::{Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"],
}),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
}

#[derive(Deserialize, Serialize)]
struct Subtract;

impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;

async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"],
},
}))
.expect("Tool Definition")
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
Ok(result)
}
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

// Create OpenAI client
let openai_client = providers::openai::Client::from_env();

// Create agent with a single context prompt and two tools
let calculator_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();

// Create agent which has the calculator_agent as a tool
let agent_using_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble("You are a helpful assistant that can solve problems. Use the tool provided to answer the user's question.")
.max_tokens(1024)
.tool(calculator_agent)
.build();

// Prompt the agent and print the response
println!("Calculate 2 - 5");

println!(
"OpenAI Agent-Using Agent: {}",
agent_using_agent.prompt("Calculate 2 - 5").await?
);

Ok(())
}
1 change: 1 addition & 0 deletions rig-core/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
mod builder;
mod completion;
mod prompt_request;
mod tool;

pub use builder::AgentBuilder;
pub use completion::Agent;
Expand Down
44 changes: 44 additions & 0 deletions rig-core/src/agent/tool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate::{
agent::Agent,
completion::{CompletionModel, Prompt, PromptError, ToolDefinition},
tool::Tool,
};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentToolArgs {
/// The prompt for the agent to call.
prompt: String,
}

impl<M: CompletionModel> Tool for Agent<M> {
const NAME: &'static str = "agent_tool";

type Error = PromptError;
type Args = AgentToolArgs;
type Output = String;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: <Self as Tool>::name(self),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is distinguish between the Agent::name method and the Tool::name trait method. The Tool::name trait method implementation here returns a valid Tool name.
This is important because the Agent::name method could return an invalid Tool name - either an invalid one set by the user, or, if not set, the default name "Unnamed Agent" (which is invalid due to the whitespace).

description: format!(
"A tool that allows the agent to call another agent by prompting it. The preamble
of that agent follows:
---
{}",
self.preamble.clone()
),
parameters: serde_json::to_value(schema_for!(AgentToolArgs))
.expect("converting JSON schema to JSON value should never fail"),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
self.prompt(args.prompt).await
}

fn name(&self) -> String {
Self::NAME.to_string()
}
}