Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions crates/goose/src/conversation_fixer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ impl ConversationFixer {
/// Fix a conversation that we're about to send to an LLM. So the last and first
/// messages should always be from the user.
pub fn fix_conversation(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (messages, empty_removed) = Self::remove_empty_messages(messages);
let (messages, tool_calling_fixed) = Self::fix_tool_calling(messages);
let (messages, messages_merged) = Self::merge_consecutive_messages(messages);
let (messages, lead_trail_fixed) = Self::fix_lead_trail(messages);
let (messages, populated_if_empty) = Self::populate_if_empty(messages);
let (messages_1, empty_removed) = Self::remove_empty_messages(messages);
let (messages_2, tool_calling_fixed) = Self::fix_tool_calling(messages_1);
let (messages_3, messages_merged) = Self::merge_consecutive_messages(messages_2);
let (messages_4, lead_trail_fixed) = Self::fix_lead_trail(messages_3);
let (messages_5, populated_if_empty) = Self::populate_if_empty(messages_4);

let mut issues = Vec::new();
issues.extend(empty_removed);
Expand All @@ -23,7 +23,7 @@ impl ConversationFixer {
issues.extend(lead_trail_fixed);
issues.extend(populated_if_empty);

(messages, issues)
(messages_5, issues)
}

fn remove_empty_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
Expand Down Expand Up @@ -145,13 +145,10 @@ impl ConversationFixer {

for message in messages {
if let Some(last) = merged_messages.last_mut() {
if last.role == message.role {
let effective = Self::effective_role(&message);
if Self::effective_role(last) == effective {
last.content.extend(message.content);
let role_name = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};
issues.push(format!("Merged consecutive {} messages", role_name));
issues.push(format!("Merged consecutive {} messages", effective));
continue;
}
}
Expand All @@ -161,6 +158,24 @@ impl ConversationFixer {
(merged_messages, issues)
}

fn has_tool_response(message: &Message) -> bool {
message
.content
.iter()
.any(|content| matches!(content, MessageContent::ToolResponse(_)))
}

fn effective_role(message: &Message) -> String {
if message.role == Role::User && Self::has_tool_response(message) {
"tool".to_string()
} else {
match message.role {
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
}
}
}

fn fix_lead_trail(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();

Expand Down Expand Up @@ -375,4 +390,19 @@ mod tests {
assert!(issues[0].contains("Removed orphaned tool request"));
assert!(issues[1].contains("Merged consecutive assistant messages"));
}

#[test]
fn test_tool_response_effective_role() {
let messages = vec![
Message::user().with_text("Search for something"),
Message::assistant()
.with_text("I'll search for you")
.with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))),
Message::user().with_tool_response("search_1", Ok(vec![])),
Message::user().with_text("Thanks!"),
];

let (fixed, issues) = run_verify(messages);
assert_eq!(issues.len(), 0);
}
}