Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions crates/goose/src/conversation/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1438,4 +1438,87 @@ mod tests {
panic!("Expected ToolResponse content");
}
}

#[test]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these tests are rather repetitive - if you wanted to do this, turn them into parameterized tests that only have the difference between the cases rather than repeating yourself.

but we should maybe think one deeper on how we avoid breaking this in the future - it happened before when we change the message format. adding tests like this is just very specifically going to test for this, doesn't seem that likely that we'll catch future issues.

wouldn't it be better to just check in a bunch of historic conversations and verify that we can load them?

fn test_tool_request_with_value_arguments_backward_compatibility() {
struct TestCase {
name: &'static str,
arguments_json: &'static str,
expected: Option<Value>,
}

let test_cases = [
TestCase {
name: "string",
arguments_json: r#""string_argument""#,
expected: Some(serde_json::json!({"value": "string_argument"})),
},
TestCase {
name: "array",
arguments_json: r#"["a", "b", "c"]"#,
expected: Some(serde_json::json!({"value": ["a", "b", "c"]})),
},
TestCase {
name: "number",
arguments_json: "42",
expected: Some(serde_json::json!({"value": 42})),
},
TestCase {
name: "null",
arguments_json: "null",
expected: None,
},
TestCase {
name: "object",
arguments_json: r#"{"key": "value", "number": 123}"#,
expected: Some(serde_json::json!({"key": "value", "number": 123})),
},
];

for tc in test_cases {
let json = format!(
r#"{{
"role": "assistant",
"created": 1640995200,
"content": [{{
"type": "toolRequest",
"id": "tool123",
"toolCall": {{
"status": "success",
"value": {{
"name": "test_tool",
"arguments": {}
}}
}}
}}],
"metadata": {{ "agentVisible": true, "userVisible": true }}
}}"#,
tc.arguments_json
);

let message: Message = serde_json::from_str(&json)
.unwrap_or_else(|e| panic!("{}: parse failed: {}", tc.name, e));

let MessageContent::ToolRequest(request) = &message.content[0] else {
panic!("{}: expected ToolRequest content", tc.name);
};

let Ok(tool_call) = &request.tool_call else {
panic!("{}: expected successful tool call", tc.name);
};

assert_eq!(tool_call.name, "test_tool", "{}: wrong tool name", tc.name);

match (&tool_call.arguments, &tc.expected) {
(None, None) => {}
(Some(args), Some(expected)) => {
let args_value = serde_json::to_value(args).unwrap();
assert_eq!(&args_value, expected, "{}: arguments mismatch", tc.name);
}
(actual, expected) => {
panic!("{}: expected {:?}, got {:?}", tc.name, expected, actual);
}
}
}
}
}
65 changes: 54 additions & 11 deletions crates/goose/src/conversation/tool_result_serde.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::mcp_utils::ToolResult;
use rmcp::model::{ErrorCode, ErrorData};
use rmcp::model::{CallToolRequestParam, ErrorCode, ErrorData, JsonObject};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Cow;
Expand All @@ -25,22 +25,55 @@ where
}
}

pub fn deserialize<'de, T, D>(deserializer: D) -> Result<ToolResult<T>, D::Error>
#[derive(Deserialize)]
struct ToolCallWithValueArguments {
name: String,
arguments: serde_json::Value,
}

impl ToolCallWithValueArguments {
fn into_call_tool_request_param(self) -> CallToolRequestParam {
let arguments = match self.arguments {
serde_json::Value::Object(map) => Some(map),
serde_json::Value::Null => None,
other => {
let mut map = JsonObject::new();
map.insert("value".to_string(), other);
Some(map)
}
};
CallToolRequestParam {
name: Cow::Owned(self.name),
arguments,
}
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<ToolResult<CallToolRequestParam>, D::Error>
where
T: Deserialize<'de>,
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum ResultFormat<T> {
Success { status: String, value: T },
Error { status: String, error: String },
enum ResultFormat {
SuccessWithCallToolRequestParam {
status: String,
value: CallToolRequestParam,
},
SuccessWithToolCallValueArguments {
status: String,
value: ToolCallWithValueArguments,
},
Error {
status: String,
error: String,
},
}

let format = ResultFormat::deserialize(deserializer)?;

match format {
ResultFormat::Success { status, value } => {
ResultFormat::SuccessWithCallToolRequestParam { status, value } => {
if status == "success" {
Ok(Ok(value))
} else {
Expand All @@ -50,6 +83,16 @@ where
)))
}
}
ResultFormat::SuccessWithToolCallValueArguments { status, value } => {
if status == "success" {
Ok(Ok(value.into_call_tool_request_param()))
} else {
Err(serde::de::Error::custom(format!(
"Expected status 'success', got '{}'",
status
)))
}
}
ResultFormat::Error { status, error } => {
if status == "error" {
Ok(Err(ErrorData {
Expand Down Expand Up @@ -88,11 +131,11 @@ pub mod call_tool_result {
#[derive(Deserialize)]
#[serde(untagged)]
enum ResultFormat {
NewSuccess {
SuccessWithCallToolResult {
status: String,
value: CallToolResult,
},
LegacySuccess {
SuccessWithContentVec {
status: String,
value: Vec<Content>,
},
Expand All @@ -105,7 +148,7 @@ pub mod call_tool_result {
let format = ResultFormat::deserialize(deserializer)?;

match format {
ResultFormat::NewSuccess { status, value } => {
ResultFormat::SuccessWithCallToolResult { status, value } => {
if status == "success" {
Ok(Ok(value))
} else {
Expand All @@ -115,7 +158,7 @@ pub mod call_tool_result {
)))
}
}
ResultFormat::LegacySuccess { status, value } => {
ResultFormat::SuccessWithContentVec { status, value } => {
if status == "success" {
Ok(Ok(CallToolResult::success(value)))
} else {
Expand Down
Loading