Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions crates/goose/src/conversation/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ impl From<PromptMessage> for Message {
}
}

#[derive(ToSchema, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[derive(ToSchema, Clone, Copy, PartialEq, Serialize, Deserialize, Debug)]
/// Metadata for message visibility
#[serde(rename_all = "camelCase")]
pub struct MessageMetadata {
Expand Down Expand Up @@ -456,7 +456,7 @@ fn default_true() -> bool {
true
}

#[derive(ToSchema, Clone, PartialEq, Serialize, Deserialize)]
#[derive(ToSchema, Clone, PartialEq, Serialize, Deserialize, Debug)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
Expand All @@ -470,19 +470,6 @@ pub struct Message {
pub metadata: MessageMetadata,
}

impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let joined_content: String = self
.content
.iter()
.map(|c| format!("{c}"))
.collect::<Vec<_>>()
.join(" ");

write!(f, "{:?}: {}", self.role, joined_content)
}
}

fn default_created() -> i64 {
0 // old messages do not have timestamps.
}
Expand Down
222 changes: 190 additions & 32 deletions crates/goose/src/conversation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,28 +149,73 @@ pub fn fix_conversation(conversation: Conversation) -> (Conversation, Vec<String
}

fn fix_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (messages_1, empty_removed) = remove_empty_messages(messages);
let (messages_2, tool_calling_fixed) = fix_tool_calling(messages_1);
let (messages_3, messages_merged) = merge_consecutive_messages(messages_2);
let (messages_4, lead_trail_fixed) = fix_lead_trail(messages_3);
let (messages_5, populated_if_empty) = populate_if_empty(messages_4);
[
merge_text_content_items,
remove_empty_messages,
fix_tool_calling,
merge_consecutive_messages,
fix_lead_trail,
populate_if_empty,
]
.into_iter()
.fold(
(messages, Vec::new()),
|(msgs, mut all_issues), processor| {
let (new_msgs, issues) = processor(msgs);
all_issues.extend(issues);
(new_msgs, all_issues)
},
)
}

let mut issues = Vec::new();
issues.extend(empty_removed);
issues.extend(tool_calling_fixed);
issues.extend(messages_merged);
issues.extend(lead_trail_fixed);
issues.extend(populated_if_empty);
fn merge_text_content_in_message(mut msg: Message) -> Message {
if msg.role != Role::Assistant {
return msg;
}
msg.content = msg
.content
.into_iter()
.fold(Vec::new(), |mut content, item| {
match item {
MessageContent::Text(text) => {
if let Some(MessageContent::Text(ref mut last)) = content.last_mut() {
last.text.push_str(&text.text);
} else {
content.push(MessageContent::Text(text));
}
}
other => content.push(other),
}
content
});
msg
}

(messages_5, issues)
fn merge_text_content_items(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
messages.into_iter().fold(
(Vec::new(), Vec::new()),
|(mut messages, mut issues), message| {
let content_len = message.content.len();
let message = merge_text_content_in_message(message);
if content_len != message.content.len() {
issues.push(String::from("Merged text content"))
}
messages.push(message);
(messages, issues)
},
)
}

fn remove_empty_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
let filtered_messages = messages
.into_iter()
.filter(|msg| {
if msg.content.is_empty() {
if msg
.content
.iter()
.all(|c| c.as_text().is_some_and(str::is_empty))
{
issues.push("Removed empty message".to_string());
false
} else {
Expand Down Expand Up @@ -384,6 +429,24 @@ mod tests {
use rmcp::model::Role;
use serde_json::json;

macro_rules! assert_has_issues_unordered {
($fixed:expr, $issues:expr, $($expected:expr),+ $(,)?) => {
{
let mut expected: Vec<&str> = vec![$($expected),+];
let mut actual: Vec<&str> = $issues.iter().map(|s| s.as_str()).collect();
expected.sort();
actual.sort();

if actual != expected {
panic!(
"assertion failed: issues don't match\nexpected: {:?}\n actual: {:?}. Fixed conversation is:\n{:#?}",
expected, $issues, $fixed,
);
}
}
};
}

fn run_verify(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (fixed, issues) = fix_conversation(Conversation::new_unvalidated(messages.clone()));

Expand Down Expand Up @@ -462,17 +525,15 @@ mod tests {
let (fixed, issues) = run_verify(messages);

assert_eq!(fixed.len(), 3);
assert_eq!(issues.len(), 4);

assert!(issues
.iter()
.any(|i| i.contains("Merged consecutive user messages")));
assert!(issues
.iter()
.any(|i| i.contains("Removed tool response 'orphan_1' from assistant message")));
assert!(issues
.iter()
.any(|i| i.contains("Removed tool request 'bad_req' from user message")));

assert_has_issues_unordered!(
fixed,
issues,
"Merged consecutive assistant messages",
"Merged consecutive user messages",
"Removed tool response 'orphan_1' from assistant message",
"Removed tool request 'bad_req' from user message",
);

assert_eq!(fixed[0].role, Role::User);
assert_eq!(fixed[1].role, Role::Assistant);
Expand Down Expand Up @@ -501,10 +562,18 @@ mod tests {

assert_eq!(fixed.len(), 1);

assert!(issues.iter().any(|i| i.contains("Removed empty message")));
assert!(issues
.iter()
.any(|i| i.contains("Removed orphaned tool response 'wrong_id'")));
assert_has_issues_unordered!(
fixed,
issues,
"Removed empty message",
"Removed orphaned tool response 'wrong_id'",
"Removed orphaned tool request 'search_1'",
"Removed orphaned tool request 'search_2'",
"Removed empty message",
"Removed empty message",
"Removed leading assistant message",
"Added placeholder user message to empty conversation",
);

assert_eq!(fixed[0].role, Role::User);
assert_eq!(fixed[0].as_concat_text(), "Hello");
Expand All @@ -530,9 +599,12 @@ mod tests {
let (fixed, issues) = fix_conversation(conversation);

assert_eq!(fixed.len(), 5);
assert_eq!(issues.len(), 2);
assert!(issues[0].contains("Removed orphaned tool request"));
assert!(issues[1].contains("Merged consecutive assistant messages"));
assert_has_issues_unordered!(
fixed,
issues,
"Removed orphaned tool request 'toolu_bdrk_018adWbP4X26CfoJU5hkhu3i'",
"Merged consecutive assistant messages"
)
}

#[test]
Expand All @@ -547,6 +619,92 @@ mod tests {
];

let (_fixed, issues) = run_verify(messages);
assert_eq!(issues.len(), 0);
assert!(issues.is_empty());
}

#[test]
fn test_merge_text_content_items() {
use crate::conversation::message::MessageContent;
use rmcp::model::{AnnotateAble, RawTextContent};

let mut message = Message::assistant().with_text("Hello");

message.content.push(MessageContent::Text(
RawTextContent {
text: " world".to_string(),
meta: None,
}
.no_annotation(),
));
message.content.push(MessageContent::Text(
RawTextContent {
text: "!".to_string(),
meta: None,
}
.no_annotation(),
));

let messages = vec![
Message::user().with_text("hello"),
message,
Message::user().with_text("thanks"),
];

let (fixed, issues) = run_verify(messages);

assert_eq!(fixed.len(), 3);
assert_has_issues_unordered!(fixed, issues, "Merged text content");

let fixed_msg = &fixed[1];
assert_eq!(fixed_msg.content.len(), 1);

if let MessageContent::Text(text_content) = &fixed_msg.content[0] {
assert_eq!(text_content.text, "Hello world!");
} else {
panic!("Expected text content");
}
}

#[test]
fn test_merge_text_content_items_with_mixed_content() {
use crate::conversation::message::MessageContent;
use rmcp::model::{AnnotateAble, RawTextContent};

let mut image_message = Message::assistant().with_text("Look at");

image_message.content.push(MessageContent::Text(
RawTextContent {
text: " this image:".to_string(),
meta: None,
}
.no_annotation(),
));

image_message = image_message.with_image("", "");

let messages = vec![
Message::user().with_text("hello"),
image_message,
Message::user().with_text("thanks"),
];

let (fixed, issues) = run_verify(messages);

assert_eq!(fixed.len(), 3);
assert_has_issues_unordered!(fixed, issues, "Merged text content");
let fixed_msg = &fixed[1];

assert_eq!(fixed_msg.content.len(), 2);
if let MessageContent::Text(text_content) = &fixed_msg.content[0] {
assert_eq!(text_content.text, "Look at this image:");
} else {
panic!("Expected first item to be text content");
}

if let MessageContent::Image(_) = &fixed_msg.content[1] {
// Good
} else {
panic!("Expected second item to be an image");
}
}
}
Loading