diff --git a/crates/goose/src/conversation/message.rs b/crates/goose/src/conversation/message.rs index a257f7a1d75d..a2bb3402b3cd 100644 --- a/crates/goose/src/conversation/message.rs +++ b/crates/goose/src/conversation/message.rs @@ -1438,4 +1438,87 @@ mod tests { panic!("Expected ToolResponse content"); } } + + #[test] + fn test_tool_request_with_value_arguments_backward_compatibility() { + struct TestCase { + name: &'static str, + arguments_json: &'static str, + expected: Option, + } + + let test_cases = [ + TestCase { + name: "string", + arguments_json: r#""string_argument""#, + expected: Some(serde_json::json!({"value": "string_argument"})), + }, + TestCase { + name: "array", + arguments_json: r#"["a", "b", "c"]"#, + expected: Some(serde_json::json!({"value": ["a", "b", "c"]})), + }, + TestCase { + name: "number", + arguments_json: "42", + expected: Some(serde_json::json!({"value": 42})), + }, + TestCase { + name: "null", + arguments_json: "null", + expected: None, + }, + TestCase { + name: "object", + arguments_json: r#"{"key": "value", "number": 123}"#, + expected: Some(serde_json::json!({"key": "value", "number": 123})), + }, + ]; + + for tc in test_cases { + let json = format!( + r#"{{ + "role": "assistant", + "created": 1640995200, + "content": [{{ + "type": "toolRequest", + "id": "tool123", + "toolCall": {{ + "status": "success", + "value": {{ + "name": "test_tool", + "arguments": {} + }} + }} + }}], + "metadata": {{ "agentVisible": true, "userVisible": true }} + }}"#, + tc.arguments_json + ); + + let message: Message = serde_json::from_str(&json) + .unwrap_or_else(|e| panic!("{}: parse failed: {}", tc.name, e)); + + let MessageContent::ToolRequest(request) = &message.content[0] else { + panic!("{}: expected ToolRequest content", tc.name); + }; + + let Ok(tool_call) = &request.tool_call else { + panic!("{}: expected successful tool call", tc.name); + }; + + assert_eq!(tool_call.name, "test_tool", "{}: wrong tool name", tc.name); + + match (&tool_call.arguments, &tc.expected) { + (None, None) => {} + (Some(args), Some(expected)) => { + let args_value = serde_json::to_value(args).unwrap(); + assert_eq!(&args_value, expected, "{}: arguments mismatch", tc.name); + } + (actual, expected) => { + panic!("{}: expected {:?}, got {:?}", tc.name, expected, actual); + } + } + } + } } diff --git a/crates/goose/src/conversation/tool_result_serde.rs b/crates/goose/src/conversation/tool_result_serde.rs index ef359e280127..38886cfb4c77 100644 --- a/crates/goose/src/conversation/tool_result_serde.rs +++ b/crates/goose/src/conversation/tool_result_serde.rs @@ -1,5 +1,5 @@ use crate::mcp_utils::ToolResult; -use rmcp::model::{ErrorCode, ErrorData}; +use rmcp::model::{CallToolRequestParam, ErrorCode, ErrorData, JsonObject}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Cow; @@ -25,22 +25,55 @@ where } } -pub fn deserialize<'de, T, D>(deserializer: D) -> Result, D::Error> +#[derive(Deserialize)] +struct ToolCallWithValueArguments { + name: String, + arguments: serde_json::Value, +} + +impl ToolCallWithValueArguments { + fn into_call_tool_request_param(self) -> CallToolRequestParam { + let arguments = match self.arguments { + serde_json::Value::Object(map) => Some(map), + serde_json::Value::Null => None, + other => { + let mut map = JsonObject::new(); + map.insert("value".to_string(), other); + Some(map) + } + }; + CallToolRequestParam { + name: Cow::Owned(self.name), + arguments, + } + } +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where - T: Deserialize<'de>, D: Deserializer<'de>, { #[derive(Deserialize)] #[serde(untagged)] - enum ResultFormat { - Success { status: String, value: T }, - Error { status: String, error: String }, + enum ResultFormat { + SuccessWithCallToolRequestParam { + status: String, + value: CallToolRequestParam, + }, + SuccessWithToolCallValueArguments { + status: String, + value: ToolCallWithValueArguments, + }, + Error { + status: String, + error: String, + }, } let format = ResultFormat::deserialize(deserializer)?; match format { - ResultFormat::Success { status, value } => { + ResultFormat::SuccessWithCallToolRequestParam { status, value } => { if status == "success" { Ok(Ok(value)) } else { @@ -50,6 +83,16 @@ where ))) } } + ResultFormat::SuccessWithToolCallValueArguments { status, value } => { + if status == "success" { + Ok(Ok(value.into_call_tool_request_param())) + } else { + Err(serde::de::Error::custom(format!( + "Expected status 'success', got '{}'", + status + ))) + } + } ResultFormat::Error { status, error } => { if status == "error" { Ok(Err(ErrorData { @@ -88,11 +131,11 @@ pub mod call_tool_result { #[derive(Deserialize)] #[serde(untagged)] enum ResultFormat { - NewSuccess { + SuccessWithCallToolResult { status: String, value: CallToolResult, }, - LegacySuccess { + SuccessWithContentVec { status: String, value: Vec, }, @@ -105,7 +148,7 @@ pub mod call_tool_result { let format = ResultFormat::deserialize(deserializer)?; match format { - ResultFormat::NewSuccess { status, value } => { + ResultFormat::SuccessWithCallToolResult { status, value } => { if status == "success" { Ok(Ok(value)) } else { @@ -115,7 +158,7 @@ pub mod call_tool_result { ))) } } - ResultFormat::LegacySuccess { status, value } => { + ResultFormat::SuccessWithContentVec { status, value } => { if status == "success" { Ok(Ok(CallToolResult::success(value))) } else {