diff --git a/Cargo.lock b/Cargo.lock index 36b2c5d965bc..e7e59b6486ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3141,6 +3141,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "unbinder", "unicode-normalization", "url", "urlencoding", @@ -8485,6 +8486,18 @@ dependencies = [ "zip 2.5.0", ] +[[package]] +name = "unbinder" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911624cb458604f41ac58db43ab8c2ccc9ad530791c1cd5d0ff8d824a38eeaa7" +dependencies = [ + "cfg-if", + "rustc-hash 2.1.1", + "serde", + "serde_json", +] + [[package]] name = "unicase" version = "2.8.1" diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index 3e258af0c794..3d125f3f2a03 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -44,7 +44,7 @@ pub enum SaveAsFormat { pub struct WebScrapeParams { /// The URL to fetch content from pub url: String, - /// How to interpret and save the content + /// Format of the response. #[serde(default)] pub save_as: SaveAsFormat, } @@ -479,8 +479,7 @@ impl ComputerControllerServer { - text (for HTML pages) - json (for API responses) - binary (for images and other files) - The content is cached locally and can be accessed later using the cache_path - returned in the response. + Returns 'Content saved to: '. Use cache to read the content. " )] pub async fn web_scrape( diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index fd359e61418f..eb7e2b4c4d16 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -115,6 +115,7 @@ ignore = { workspace = true } which = { workspace = true} boa_engine = "0.21.0" boa_gc = "0.21" +unbinder = "0.1.7" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/agents/code_execution_extension.rs b/crates/goose/src/agents/code_execution_extension.rs index f6d351646346..4eaf966aeb14 100644 --- a/crates/goose/src/agents/code_execution_extension.rs +++ b/crates/goose/src/agents/code_execution_extension.rs @@ -69,12 +69,86 @@ struct InputSchema { required: Vec, } +fn quote_join(vals: &[&str]) -> String { + format!("\"{}\"", vals.join("\" | \"")) +} + +fn infer_type(schema: &Value) -> Option { + if schema.get("properties").is_some() { + Some("object".to_string()) + } else if schema.get("items").is_some() { + Some("array".to_string()) + } else { + None + } +} + +fn extract_type_from_schema(schema: &Value) -> Option { + // enum array (github-mcp style) + if let Some(arr) = schema.get("enum").and_then(|e| e.as_array()) { + let vals: Vec<_> = arr.iter().filter_map(|v| v.as_str()).collect(); + if !vals.is_empty() { + return Some(quote_join(&vals)); + } + } + + // oneOf with const (schemars enums) + if let Some(arr) = schema.get("oneOf").and_then(|o| o.as_array()) { + let vals: Vec<_> = arr + .iter() + .filter_map(|v| v.get("const")?.as_str()) + .collect(); + if !vals.is_empty() { + return Some(quote_join(&vals)); + } + } + + // anyOf (Option or unions) + if let Some(arr) = schema.get("anyOf").and_then(|o| o.as_array()) { + let non_null: Vec<_> = arr + .iter() + .filter(|v| v.get("type").and_then(|t| t.as_str()) != Some("null")) + .collect(); + if non_null.len() == 1 { + return extract_type_from_schema(non_null[0]).or_else(|| infer_type(non_null[0])); + } + if non_null.len() > 1 { + let types: Vec<_> = non_null + .iter() + .filter_map(|v| extract_type_from_schema(v).or_else(|| infer_type(v))) + .collect(); + if !types.is_empty() { + return Some(types.join(" | ")); + } + } + } + + // type field (string or array) + match schema.get("type") { + Some(Value::String(s)) => Some(s.clone()), + Some(Value::Array(arr)) => { + let non_null: Vec<_> = arr + .iter() + .filter_map(|v| v.as_str()) + .filter(|s| *s != "null") + .collect(); + match non_null.len() { + 0 => None, + 1 => Some(non_null[0].to_string()), + _ => Some(non_null.join(" | ")), + } + } + _ => None, + } +} + struct ToolInfo { server_name: String, tool_name: String, full_name: String, description: String, params: Vec<(String, String, bool)>, + return_type: String, } impl ToolInfo { @@ -82,9 +156,9 @@ impl ToolInfo { let (server_name, tool_name) = tool.name.as_ref().split_once("__")?; let param_names = get_parameter_names(tool); - let schema: InputSchema = - serde_json::from_value(Value::Object(tool.input_schema.as_ref().clone())) - .unwrap_or_default(); + let mut schema_value = Value::Object(tool.input_schema.as_ref().clone()); + let _ = unbinder::dereference_schema(&mut schema_value, unbinder::Options::default()); + let schema: InputSchema = serde_json::from_value(schema_value).unwrap_or_default(); let params = param_names .iter() @@ -92,14 +166,24 @@ impl ToolInfo { let ty = schema .properties .get(name) - .and_then(|p| p.get("type")) - .and_then(|t| t.as_str()) - .unwrap_or("any"); + .and_then(extract_type_from_schema) + .unwrap_or_else(|| "any".to_string()); let required = schema.required.contains(name); - (name.clone(), ty.to_string(), required) + (name.clone(), ty, required) }) .collect(); + let return_type = tool + .output_schema + .as_ref() + .and_then(|schema| { + let mut schema_value = Value::Object(schema.as_ref().clone()); + let _ = + unbinder::dereference_schema(&mut schema_value, unbinder::Options::default()); + extract_type_from_schema(&schema_value) + }) + .unwrap_or_else(|| "string".to_string()); + Some(Self { server_name: server_name.to_string(), tool_name: tool_name.to_string(), @@ -110,6 +194,7 @@ impl ToolInfo { .map(|d| d.as_ref().to_string()) .unwrap_or_default(), params, + return_type, }) } @@ -121,7 +206,10 @@ impl ToolInfo { .collect::>() .join(", "); let desc = self.description.lines().next().unwrap_or(""); - format!("{}({{ {params} }}): string - {desc}", self.tool_name) + format!( + "{}({{ {params} }}): {} - {desc}", + self.tool_name, self.return_type + ) } } @@ -306,7 +394,7 @@ impl CodeExecutionClient { - RIGHT: One execute_code call with a script that calls all needed tools Workflow: - 1. Use read_module("server") to discover tools and signatures + 1. Use the read_module tool to discover tools and signatures 2. Write ONE script that imports and calls ALL tools needed for the task 3. Chain results: use output from one tool as input to the next "#}.to_string()), @@ -498,6 +586,7 @@ impl CodeExecutionClient { if !matching_tools.is_empty() { output.push_str("## Matching Tools\n"); + output.push_str("Use the read_module tool for full signature and import syntax\n\n"); for tool in &matching_tools { output.push_str(&format!( "- {}/{}: {}\n", @@ -611,7 +700,7 @@ impl McpClientTrait for CodeExecutionClient { - Last expression is the result - No comments in code - BEFORE CALLING: Use read_module("server") to check required parameters. + BEFORE CALLING: Use the read_module tool to check required parameters. "#} .to_string(), schema::(), @@ -656,9 +745,9 @@ impl McpClientTrait for CodeExecutionClient { Search for tools by name or description across all available modules. USAGE: - - Single term: search_modules({ terms: "file" }) - - Multiple terms: search_modules({ terms: ["git", "shell"] }) - - Regex patterns: search_modules({ terms: "sh.*", regex: true }) + - Single term: search_modules with terms="file" + - Multiple terms: search_modules with terms=["git", "shell"] + - Regex patterns: search_modules with terms="sh.*", regex=true Returns matching servers and tools with descriptions. Use this when you don't know which module contains the tool you need. @@ -745,7 +834,7 @@ impl McpClientTrait for CodeExecutionClient { Modules: {} - Use read_module("name") to see tool signatures before calling unfamiliar tools. + Use the read_module tool to see signatures before calling unfamiliar tools. "#}, server_list.join(", ") )) @@ -755,6 +844,8 @@ impl McpClientTrait for CodeExecutionClient { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + use test_case::test_case; #[tokio::test] async fn test_execute_code_simple() { @@ -809,6 +900,7 @@ mod tests { full_name: "developer__shell".to_string(), description: "Execute shell commands".to_string(), params: vec![("command".to_string(), "string".to_string(), true)], + return_type: "string".to_string(), }, ToolInfo { server_name: "developer".to_string(), @@ -816,6 +908,7 @@ mod tests { full_name: "developer__text_editor".to_string(), description: "Edit text files".to_string(), params: vec![("path".to_string(), "string".to_string(), true)], + return_type: "string".to_string(), }, ToolInfo { server_name: "git".to_string(), @@ -823,6 +916,7 @@ mod tests { full_name: "git__commit".to_string(), description: "Commit changes to git".to_string(), params: vec![("message".to_string(), "string".to_string(), true)], + return_type: "string".to_string(), }, ]; @@ -883,6 +977,7 @@ mod tests { full_name: "developer__shell".to_string(), description: "Execute shell commands".to_string(), params: vec![], + return_type: "string".to_string(), }, ToolInfo { server_name: "developer".to_string(), @@ -890,6 +985,7 @@ mod tests { full_name: "developer__text_editor".to_string(), description: "Edit text files".to_string(), params: vec![], + return_type: "string".to_string(), }, ]; @@ -916,4 +1012,73 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().contains("Invalid regex")); } + + #[test_case( + "github__get_me", + serde_json::json!({"type": "object", "properties": {}}), + None, + "get_me({ }): string - Get details of the authenticated user"; + "no params, no output schema" + )] + #[test_case( + "filesystem__read_file", + serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]}), + Some(serde_json::json!({"type": "object"})), + "read_file({ path: string }): object - Read the complete contents of a file"; + "string param, object output" + )] + #[test_case( + "memory__create_entities", + serde_json::json!({"type": "object", "properties": {"entities": {"type": "array"}}, "required": ["entities"]}), + Some(serde_json::json!({"type": "object"})), + "create_entities({ entities: array }): object - Create multiple new entities"; + "array param, object output" + )] + #[test_case( + "github__dismiss_notification", + serde_json::json!({"type": "object", "properties": { + "threadID": {"type": "string"}, + "state": {"type": "string", "enum": ["read", "done"]} + }, "required": ["threadID", "state"]}), + None, + "dismiss_notification({ state: \"read\" | \"done\", threadID: string }): string - Dismiss a notification"; + "enum param, no output schema" + )] + #[test_case( + "computercontroller__web_scrape", + serde_json::json!({"type": "object", "properties": { + "url": {"type": "string"}, + "save_as": {"oneOf": [{"const": "text"}, {"const": "json"}, {"const": "binary"}]} + }, "required": ["url"]}), + None, + "web_scrape({ save_as?: \"text\" | \"json\" | \"binary\", url: string }): string - Scrape content from URL"; + "oneOf const param (schemars), no output schema" + )] + fn test_mcp_tool_signature( + name: &str, + input: serde_json::Value, + output: Option, + expected: &str, + ) { + let input_schema: serde_json::Map = + serde_json::from_value(input).unwrap(); + let output_schema = output.map(|v| { + Arc::new( + serde_json::from_value::>(v).unwrap(), + ) + }); + let desc = expected.split(" - ").nth(1).unwrap_or("").to_string(); + let tool = McpTool { + name: name.to_string().into(), + title: None, + description: Some(desc.into()), + input_schema: Arc::new(input_schema), + output_schema, + annotations: None, + icons: None, + meta: None, + }; + let info = ToolInfo::from_mcp_tool(&tool).unwrap(); + assert_eq!(info.to_signature(), expected); + } } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index abee7d7a2443..9110fb167d18 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -174,11 +174,14 @@ fn require_str_parameter<'a>(v: &'a serde_json::Value, name: &str) -> Result<&'a } pub fn get_parameter_names(tool: &Tool) -> Vec { - tool.input_schema + let mut names: Vec = tool + .input_schema .get("properties") .and_then(|props| props.as_object()) .map(|props| props.keys().cloned().collect()) - .unwrap_or_default() + .unwrap_or_default(); + names.sort(); + names } impl Default for ExtensionManager {