diff --git a/crates/goose/src/conversation_fixer.rs b/crates/goose/src/conversation_fixer.rs index 9eb2b80b539e..1a585ae699d1 100644 --- a/crates/goose/src/conversation_fixer.rs +++ b/crates/goose/src/conversation_fixer.rs @@ -10,11 +10,11 @@ impl ConversationFixer { /// Fix a conversation that we're about to send to an LLM. So the last and first /// messages should always be from the user. pub fn fix_conversation(messages: Vec) -> (Vec, Vec) { - let (messages, empty_removed) = Self::remove_empty_messages(messages); - let (messages, tool_calling_fixed) = Self::fix_tool_calling(messages); - let (messages, messages_merged) = Self::merge_consecutive_messages(messages); - let (messages, lead_trail_fixed) = Self::fix_lead_trail(messages); - let (messages, populated_if_empty) = Self::populate_if_empty(messages); + let (messages_1, empty_removed) = Self::remove_empty_messages(messages); + let (messages_2, tool_calling_fixed) = Self::fix_tool_calling(messages_1); + let (messages_3, messages_merged) = Self::merge_consecutive_messages(messages_2); + let (messages_4, lead_trail_fixed) = Self::fix_lead_trail(messages_3); + let (messages_5, populated_if_empty) = Self::populate_if_empty(messages_4); let mut issues = Vec::new(); issues.extend(empty_removed); @@ -23,7 +23,7 @@ impl ConversationFixer { issues.extend(lead_trail_fixed); issues.extend(populated_if_empty); - (messages, issues) + (messages_5, issues) } fn remove_empty_messages(messages: Vec) -> (Vec, Vec) { @@ -145,13 +145,10 @@ impl ConversationFixer { for message in messages { if let Some(last) = merged_messages.last_mut() { - if last.role == message.role { + let effective = Self::effective_role(&message); + if Self::effective_role(last) == effective { last.content.extend(message.content); - let role_name = match message.role { - Role::User => "user", - Role::Assistant => "assistant", - }; - issues.push(format!("Merged consecutive {} messages", role_name)); + issues.push(format!("Merged consecutive {} messages", effective)); continue; } } @@ -161,6 +158,24 @@ impl ConversationFixer { (merged_messages, issues) } + fn has_tool_response(message: &Message) -> bool { + message + .content + .iter() + .any(|content| matches!(content, MessageContent::ToolResponse(_))) + } + + fn effective_role(message: &Message) -> String { + if message.role == Role::User && Self::has_tool_response(message) { + "tool".to_string() + } else { + match message.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + } + } + } + fn fix_lead_trail(mut messages: Vec) -> (Vec, Vec) { let mut issues = Vec::new(); @@ -375,4 +390,19 @@ mod tests { assert!(issues[0].contains("Removed orphaned tool request")); assert!(issues[1].contains("Merged consecutive assistant messages")); } + + #[test] + fn test_tool_response_effective_role() { + let messages = vec![ + Message::user().with_text("Search for something"), + Message::assistant() + .with_text("I'll search for you") + .with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))), + Message::user().with_tool_response("search_1", Ok(vec![])), + Message::user().with_text("Thanks!"), + ]; + + let (fixed, issues) = run_verify(messages); + assert_eq!(issues.len(), 0); + } }