diff --git a/crates/goose/src/agents/code_execution_extension.rs b/crates/goose/src/agents/code_execution_extension.rs index 4eaf966aeb14..5bbf6b547310 100644 --- a/crates/goose/src/agents/code_execution_extension.rs +++ b/crates/goose/src/agents/code_execution_extension.rs @@ -125,6 +125,41 @@ fn extract_type_from_schema(schema: &Value) -> Option { // type field (string or array) match schema.get("type") { + Some(Value::String(s)) if s == "array" => { + let item_type = schema + .get("items") + .and_then(extract_type_from_schema) + .unwrap_or_else(|| "any".to_string()); + Some(if item_type == "any" { + "array".into() + } else { + format!("{item_type}[]") + }) + } + Some(Value::String(s)) if s == "object" => { + let Some(props) = schema.get("properties").and_then(|p| p.as_object()) else { + return Some("object".to_string()); + }; + let required: Vec<_> = schema + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_default(); + let mut fields: Vec<_> = props + .iter() + .map(|(name, schema)| { + let ty = extract_type_from_schema(schema).unwrap_or_else(|| "any".into()); + let opt = if required.contains(&name.as_str()) { + "" + } else { + "?" + }; + format!("{name}{opt}: {ty}") + }) + .collect(); + fields.sort(); + Some(format!("{{ {} }}", fields.join(", "))) + } Some(Value::String(s)) => Some(s.clone()), Some(Value::Array(arr)) => { let non_null: Vec<_> = arr @@ -248,6 +283,13 @@ fn create_server_module(server_tools: &[&ToolInfo], ctx: &mut Context) -> Module ) } +fn parse_result_to_js(result: &str, ctx: &mut Context) -> JsValue { + serde_json::from_str::(result) + .ok() + .and_then(|v| JsValue::from_json(&v, ctx).ok()) + .unwrap_or_else(|| JsValue::from(js_string!(result))) +} + fn create_tool_function(full_tool_name: String) -> NativeFunction { NativeFunction::from_copy_closure_with_captures( |_this, args, full_name: &String, ctx| { @@ -274,7 +316,7 @@ fn create_tool_function(full_tool_name: String) -> NativeFunction { rx.blocking_recv() .map_err(|e| e.to_string()) .and_then(|r| r) - .map(|result| JsValue::from(js_string!(result.as_str()))) + .map(|result| parse_result_to_js(&result, ctx)) .map_err(|e| JsNativeError::error().with_message(e).into()) }, full_tool_name, @@ -616,15 +658,19 @@ impl CodeExecutionClient { .await { Ok(dispatch_result) => match dispatch_result.result.await { - Ok(result) => Ok(result - .content - .iter() - .filter_map(|c| match &c.raw { - RawContent::Text(t) => Some(t.text.clone()), - _ => None, - }) - .collect::>() - .join("\n")), + Ok(result) => Ok(if let Some(sc) = &result.structured_content { + serde_json::to_string(sc).unwrap_or_default() + } else { + result + .content + .iter() + .filter_map(|c| match &c.raw { + RawContent::Text(t) => Some(t.text.clone()), + _ => None, + }) + .collect::>() + .join("\n") + }), Err(e) => Err(format!("Tool error: {}", e.message)), }, Err(e) => Err(format!("Dispatch error: {e}")), @@ -1021,18 +1067,18 @@ mod tests { "no params, no output schema" )] #[test_case( - "filesystem__read_file", - serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]}), - Some(serde_json::json!({"type": "object"})), - "read_file({ path: string }): object - Read the complete contents of a file"; - "string param, object output" + "filesystem__read_text_file", + serde_json::json!({"type": "object", "properties": {"path": {"type": "string"}, "tail": {"type": "number"}, "head": {"type": "number"}}, "required": ["path"]}), + Some(serde_json::json!({"type": "object", "properties": {"content": {"type": "string"}}, "required": ["content"]})), + "read_text_file({ head?: number, path: string, tail?: number }): { content: string } - Read the complete contents of a file"; + "optional number params, object output" )] #[test_case( "memory__create_entities", - serde_json::json!({"type": "object", "properties": {"entities": {"type": "array"}}, "required": ["entities"]}), - Some(serde_json::json!({"type": "object"})), - "create_entities({ entities: array }): object - Create multiple new entities"; - "array param, object output" + serde_json::json!({"type": "object", "properties": {"entities": {"type": "array", "items": {"type": "object", "properties": {"name": {"type": "string"}, "entityType": {"type": "string"}, "observations": {"type": "array", "items": {"type": "string"}}}, "required": ["name", "entityType", "observations"]}}}, "required": ["entities"]}), + Some(serde_json::json!({"type": "object", "properties": {"entities": {"type": "array", "items": {"type": "object", "properties": {"name": {"type": "string"}, "entityType": {"type": "string"}, "observations": {"type": "array", "items": {"type": "string"}}}, "required": ["name", "entityType", "observations"]}}}, "required": ["entities"]})), + "create_entities({ entities: { entityType: string, name: string, observations: string[] }[] }): { entities: { entityType: string, name: string, observations: string[] }[] } - Create multiple new entities"; + "nested object array with typed props" )] #[test_case( "github__dismiss_notification", @@ -1081,4 +1127,47 @@ mod tests { let info = ToolInfo::from_mcp_tool(&tool).unwrap(); assert_eq!(info.to_signature(), expected); } + + #[test_case(serde_json::json!({"type": "string"}), "string"; "string")] + #[test_case(serde_json::json!({"type": "number"}), "number"; "number")] + #[test_case(serde_json::json!({"type": "boolean"}), "boolean"; "boolean")] + #[test_case(serde_json::json!({"type": "array"}), "array"; "array bare")] + #[test_case(serde_json::json!({"type": "array", "items": {"type": "string"}}), "string[]"; "array with items")] + #[test_case(serde_json::json!({"type": "object"}), "object"; "object bare")] + #[test_case(serde_json::json!({"type": "object", "properties": {"a": {"type": "string"}}, "required": ["a"]}), "{ a: string }"; "object with prop")] + #[test_case(serde_json::json!({"type": "object", "properties": {"a": {"type": "string"}}}), "{ a?: string }"; "object optional prop")] + #[test_case(serde_json::json!({"type": "object", "properties": {"a": {"type": "array", "items": {"type": "string"}}}, "required": ["a"]}), "{ a: string[] }"; "object with array prop")] + #[test_case(serde_json::json!({"enum": ["a", "b"]}), "\"a\" | \"b\""; "enum array")] + #[test_case(serde_json::json!({"oneOf": [{"const": "x"}, {"const": "y"}]}), "\"x\" | \"y\""; "oneOf const")] + fn test_extract_type_from_schema(schema: serde_json::Value, expected: &str) { + assert_eq!( + extract_type_from_schema(&schema), + Some(expected.to_string()) + ); + } + + fn eval_with_tools(code: &str, tools: &[(&str, &str)]) -> String { + let mut ctx = Context::default(); + for &(name, response) in tools { + let resp = response.to_string(); + let func = NativeFunction::from_copy_closure_with_captures( + |_this, _args, resp: &String, ctx| Ok(parse_result_to_js(resp, ctx)), + resp, + ); + ctx.register_global_callable(js_string!(name), 0, func) + .unwrap(); + } + ctx.eval(Source::from_bytes(code)) + .unwrap() + .display() + .to_string() + } + + #[test_case("2 + 2", &[], "4"; "pure_js")] + #[test_case("get_data({}).content", &[("get_data", r#"{"content":"hello"}"#)], "\"hello\""; "structured_property_access")] + #[test_case("typeof shell({})", &[("shell", "plain text")], "\"string\""; "plain_text_is_string")] + #[test_case("shell({}).content", &[("shell", "plain text")], "undefined"; "plain_text_no_property")] + fn test_tool_result(code: &str, tools: &[(&str, &str)], expected: &str) { + assert_eq!(eval_with_tools(code, tools), expected); + } }