-
Notifications
You must be signed in to change notification settings - Fork 562
feat: implement Tool for Agent #704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| use anyhow::Result; | ||
| use rig::prelude::*; | ||
| use rig::{ | ||
| completion::{Prompt, ToolDefinition}, | ||
| providers, | ||
| tool::Tool, | ||
| }; | ||
| use serde::{Deserialize, Serialize}; | ||
| use serde_json::json; | ||
|
|
||
| #[derive(Deserialize)] | ||
| struct OperationArgs { | ||
| x: i32, | ||
| y: i32, | ||
| } | ||
|
|
||
| #[derive(Debug, thiserror::Error)] | ||
| #[error("Math error")] | ||
| struct MathError; | ||
|
|
||
| #[derive(Deserialize, Serialize)] | ||
| struct Adder; | ||
| impl Tool for Adder { | ||
| const NAME: &'static str = "add"; | ||
| type Error = MathError; | ||
| type Args = OperationArgs; | ||
| type Output = i32; | ||
|
|
||
| async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
| ToolDefinition { | ||
| name: "add".to_string(), | ||
| description: "Add x and y together".to_string(), | ||
| parameters: json!({ | ||
| "type": "object", | ||
| "properties": { | ||
| "x": { | ||
| "type": "number", | ||
| "description": "The first number to add" | ||
| }, | ||
| "y": { | ||
| "type": "number", | ||
| "description": "The second number to add" | ||
| } | ||
| }, | ||
| "required": ["x", "y"], | ||
| }), | ||
| } | ||
| } | ||
|
|
||
| async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
| println!("[tool-call] Adding {} and {}", args.x, args.y); | ||
| let result = args.x + args.y; | ||
| Ok(result) | ||
| } | ||
| } | ||
|
|
||
| #[derive(Deserialize, Serialize)] | ||
| struct Subtract; | ||
|
|
||
| impl Tool for Subtract { | ||
| const NAME: &'static str = "subtract"; | ||
| type Error = MathError; | ||
| type Args = OperationArgs; | ||
| type Output = i32; | ||
|
|
||
| async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
| serde_json::from_value(json!({ | ||
| "name": "subtract", | ||
| "description": "Subtract y from x (i.e.: x - y)", | ||
| "parameters": { | ||
| "type": "object", | ||
| "properties": { | ||
| "x": { | ||
| "type": "number", | ||
| "description": "The number to subtract from" | ||
| }, | ||
| "y": { | ||
| "type": "number", | ||
| "description": "The number to subtract" | ||
| } | ||
| }, | ||
| "required": ["x", "y"], | ||
| }, | ||
| })) | ||
| .expect("Tool Definition") | ||
| } | ||
|
|
||
| async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
| println!("[tool-call] Subtracting {} from {}", args.y, args.x); | ||
| let result = args.x - args.y; | ||
| Ok(result) | ||
| } | ||
| } | ||
|
|
||
| #[tokio::main] | ||
| async fn main() -> Result<(), anyhow::Error> { | ||
| tracing_subscriber::fmt() | ||
| .with_max_level(tracing::Level::DEBUG) | ||
| .with_target(false) | ||
| .init(); | ||
|
|
||
| // Create OpenAI client | ||
| let openai_client = providers::openai::Client::from_env(); | ||
|
|
||
| // Create agent with a single context prompt and two tools | ||
| let calculator_agent = openai_client | ||
| .agent(providers::openai::GPT_4O) | ||
| .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") | ||
| .max_tokens(1024) | ||
| .tool(Adder) | ||
| .tool(Subtract) | ||
| .build(); | ||
|
|
||
| // Create agent which has the calculator_agent as a tool | ||
| let agent_using_agent = openai_client | ||
| .agent(providers::openai::GPT_4O) | ||
| .preamble("You are a helpful assistant that can solve problems. Use the tool provided to answer the user's question.") | ||
| .max_tokens(1024) | ||
| .tool(calculator_agent) | ||
| .build(); | ||
|
|
||
| // Prompt the agent and print the response | ||
| println!("Calculate 2 - 5"); | ||
|
|
||
| println!( | ||
| "OpenAI Agent-Using Agent: {}", | ||
| agent_using_agent.prompt("Calculate 2 - 5").await? | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| use crate::{ | ||
| agent::Agent, | ||
| completion::{CompletionModel, Prompt, PromptError, ToolDefinition}, | ||
| tool::Tool, | ||
| }; | ||
| use schemars::{JsonSchema, schema_for}; | ||
| use serde::{Deserialize, Serialize}; | ||
|
|
||
| #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] | ||
| pub struct AgentToolArgs { | ||
| /// The prompt for the agent to call. | ||
| prompt: String, | ||
| } | ||
|
|
||
| impl<M: CompletionModel> Tool for Agent<M> { | ||
| const NAME: &'static str = "agent_tool"; | ||
|
|
||
| type Error = PromptError; | ||
| type Args = AgentToolArgs; | ||
| type Output = String; | ||
|
|
||
| async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
| ToolDefinition { | ||
| name: <Self as Tool>::name(self), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is distinguish between the |
||
| description: format!( | ||
| "A tool that allows the agent to call another agent by prompting it. The preamble | ||
| of that agent follows: | ||
| --- | ||
| {}", | ||
| self.preamble.clone() | ||
| ), | ||
| parameters: serde_json::to_value(schema_for!(AgentToolArgs)) | ||
| .expect("converting JSON schema to JSON value should never fail"), | ||
| } | ||
| } | ||
|
|
||
| async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
| self.prompt(args.prompt).await | ||
| } | ||
|
|
||
| fn name(&self) -> String { | ||
| Self::NAME.to_string() | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.