diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1fb2165aca50..4cbc6d87dc4f 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -33,6 +33,7 @@ use crate::conversation::message::{ ActionRequiredData, Message, MessageContent, ProviderMetadata, SystemNotificationType, ToolRequest, }; +use crate::conversation::tool_result_serde::call_tool_result; use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; use crate::mcp_utils::ToolResult; use crate::permission::permission_inspector::PermissionInspector; @@ -1123,6 +1124,8 @@ impl Agent { match item { ToolStreamItem::Result(output) => { + let output = call_tool_result::validate(output); + if enable_extension_request_ids.contains(&request_id) && output.is_err() { diff --git a/crates/goose/src/conversation/mod.rs b/crates/goose/src/conversation/mod.rs index 5f86c660aede..14ec1ef7bcac 100644 --- a/crates/goose/src/conversation/mod.rs +++ b/crates/goose/src/conversation/mod.rs @@ -6,7 +6,7 @@ use thiserror::Error; use utoipa::ToSchema; pub mod message; -mod tool_result_serde; +pub mod tool_result_serde; #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)] pub struct Conversation(Vec); diff --git a/crates/goose/src/conversation/tool_result_serde.rs b/crates/goose/src/conversation/tool_result_serde.rs index ef359e280127..1aea142929b5 100644 --- a/crates/goose/src/conversation/tool_result_serde.rs +++ b/crates/goose/src/conversation/tool_result_serde.rs @@ -102,7 +102,17 @@ pub mod call_tool_result { }, } - let format = ResultFormat::deserialize(deserializer)?; + let original_value = serde_json::Value::deserialize(deserializer)?; + + let format = ResultFormat::deserialize(&original_value).map_err(|e| { + tracing::debug!( + "Failed to deserialize call_tool_result: {}. Original data: {}", + e, + serde_json::to_string(&original_value) + .unwrap_or_else(|_| "".to_string()) + ); + serde::de::Error::custom(e) + })?; match format { ResultFormat::NewSuccess { status, value } => { @@ -141,4 +151,87 @@ pub mod call_tool_result { } } } + + pub fn validate(result: ToolResult) -> ToolResult { + match &result { + Ok(call_tool_result) => match serde_json::to_string(call_tool_result) { + Ok(json_str) => match serde_json::from_str::(&json_str) { + Ok(_) => result, + Err(e) => { + tracing::error!("CallToolResult failed validation by deserialization: {}. Original data: {}", e, json_str); + Err(ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: Cow::from(format!("Tool result validation failed: {}", e)), + data: None, + }) + } + }, + Err(e) => { + tracing::error!("CallToolResult failed serialization: {}", e); + Err(ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: Cow::from(format!("Tool result serialization failed: {}", e)), + data: None, + }) + } + }, + Err(_) => result, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rmcp::model::{CallToolResult, Content, ErrorCode, ErrorData}; + use std::borrow::Cow; + #[test] + fn test_validate_accepts_valid_call_tool_result() { + let valid_result = CallToolResult { + content: vec![Content::text("test")], + is_error: Some(false), + structured_content: None, + meta: None, + }; + + let tool_result: ToolResult = Ok(valid_result); + let validated = call_tool_result::validate(tool_result); + + assert!( + validated.is_ok(), + "Expected validation to pass for valid CallToolResult" + ); + } + #[test] + fn test_validate_returns_error_for_invalid_calltoolresult() { + let valid_result = CallToolResult { + content: vec![], + is_error: Some(false), + structured_content: None, + meta: None, + }; + + let tool_result: ToolResult = Ok(valid_result); + let validated = call_tool_result::validate(tool_result); + + assert!(validated.is_err()); + assert!(validated + .unwrap_err() + .message + .contains("Tool result validation failed")) + } + + #[test] + fn test_validate_passes_through_errors() { + let error_result: ToolResult = Err(ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: Cow::from("test error"), + data: None, + }); + + let validated = call_tool_result::validate(error_result.clone()); + + assert!(validated.is_err()); + assert_eq!(validated.unwrap_err().message, "test error"); + } }