Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions crates/goose-cli/src/scenario_tests/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rmcp::{
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::mpsc::{self, Receiver};
use tokio_util::sync::CancellationToken;

pub struct MockClient {
tools: HashMap<String, Tool>,
Expand Down Expand Up @@ -43,6 +44,7 @@ impl McpClientTrait for MockClient {
async fn list_resources(
&self,
_next_cursor: Option<String>,
_cancel_token: CancellationToken,
) -> Result<ListResourcesResult, Error> {
Ok(ListResourcesResult {
resources: vec![],
Expand All @@ -54,11 +56,19 @@ impl McpClientTrait for MockClient {
todo!()
}

async fn read_resource(&self, _uri: &str) -> Result<ReadResourceResult, Error> {
async fn read_resource(
&self,
_uri: &str,
_cancel_token: CancellationToken,
) -> Result<ReadResourceResult, Error> {
Err(Error::UnexpectedResponse)
}

async fn list_tools(&self, _: Option<String>) -> Result<ListToolsResult, Error> {
async fn list_tools(
&self,
_: Option<String>,
_cancel_token: CancellationToken,
) -> Result<ListToolsResult, Error> {
let rmcp_tools: Vec<rmcp::model::Tool> = self
.tools
.values()
Expand All @@ -77,7 +87,12 @@ impl McpClientTrait for MockClient {
})
}

async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
async fn call_tool(
&self,
name: &str,
arguments: Value,
_cancel_token: CancellationToken,
) -> Result<CallToolResult, Error> {
if let Some(handler) = self.handlers.get(name) {
match handler(&arguments) {
Ok(content) => Ok(CallToolResult {
Expand All @@ -91,14 +106,23 @@ impl McpClientTrait for MockClient {
}
}

async fn list_prompts(&self, _next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
async fn list_prompts(
&self,
_next_cursor: Option<String>,
_cancel_token: CancellationToken,
) -> Result<ListPromptsResult, Error> {
Ok(ListPromptsResult {
prompts: vec![],
next_cursor: None,
})
}

async fn get_prompt(&self, _name: &str, _arguments: Value) -> Result<GetPromptResult, Error> {
async fn get_prompt(
&self,
_name: &str,
_arguments: Value,
_cancel_token: CancellationToken,
) -> Result<GetPromptResult, Error> {
Err(Error::UnexpectedResponse)
}

Expand Down
6 changes: 5 additions & 1 deletion crates/goose-cli/src/scenario_tests/scenario_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use goose::providers::{create, testprovider::TestProvider};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;

pub const SCENARIO_TESTS_DIR: &str = "src/scenario_tests";

Expand Down Expand Up @@ -205,7 +206,10 @@ where

let mut error = None;
for message in &messages {
if let Err(e) = session.process_message(message.clone()).await {
if let Err(e) = session
.process_message(message.clone(), CancellationToken::default())
.await
{
error = Some(e.to_string());
break;
}
Expand Down
32 changes: 22 additions & 10 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,12 @@ impl Session {
}

/// Process a single message and get the response
pub(crate) async fn process_message(&mut self, message: Message) -> Result<()> {
pub(crate) async fn process_message(
&mut self,
message: Message,
cancel_token: CancellationToken,
) -> Result<()> {
let cancel_token = cancel_token.clone();
let message_text = message.as_concat_text();

self.push_message(message);
Expand Down Expand Up @@ -405,7 +410,7 @@ impl Session {
);
}

self.process_agent_response(false).await?;
self.process_agent_response(false, cancel_token).await?;
Ok(())
}

Expand All @@ -414,7 +419,8 @@ impl Session {
// Process initial message if provided
if let Some(prompt) = prompt {
let msg = Message::user().with_text(&prompt);
self.process_message(msg).await?;
self.process_message(msg, CancellationToken::default())
.await?;
}

// Initialize the completion cache
Expand Down Expand Up @@ -514,7 +520,8 @@ impl Session {
}

output::show_thinking();
self.process_agent_response(true).await?;
self.process_agent_response(true, CancellationToken::default())
.await?;
output::hide_thinking();
}
RunMode::Plan => {
Expand Down Expand Up @@ -814,7 +821,8 @@ impl Session {
self.push_message(plan_message);
// act on the plan
output::show_thinking();
self.process_agent_response(true).await?;
self.process_agent_response(true, CancellationToken::default())
.await?;
output::hide_thinking();

// Reset run & goose mode
Expand Down Expand Up @@ -842,12 +850,15 @@ impl Session {
/// Process a single message and exit
pub async fn headless(&mut self, prompt: String) -> Result<()> {
let message = Message::user().with_text(&prompt);
self.process_message(message).await
self.process_message(message, CancellationToken::default())
.await
}

async fn process_agent_response(&mut self, interactive: bool) -> Result<()> {
// Messages will be auto-compacted in agent.reply() if needed
let cancel_token = CancellationToken::new();
async fn process_agent_response(
&mut self,
interactive: bool,
cancel_token: CancellationToken,
) -> Result<()> {
let cancel_token_clone = cancel_token.clone();

let session_config = self.session_file.as_ref().map(|s| {
Expand Down Expand Up @@ -1511,7 +1522,8 @@ impl Session {

if valid {
output::show_thinking();
self.process_agent_response(true).await?;
self.process_agent_response(true, CancellationToken::default())
.await?;
output::hide_thinking();
}
}
Expand Down
135 changes: 62 additions & 73 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,24 @@ enum MessageEvent {
request_id: String,
message: ServerNotification,
},
Ping,
}

async fn stream_event(
event: MessageEvent,
tx: &mpsc::Sender<String>,
) -> Result<(), mpsc::error::SendError<String>> {
cancel_token: &CancellationToken,
) {
let json = serde_json::to_string(&event).unwrap_or_else(|e| {
format!(
r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
e
)
});
tx.send(format!("data: {}\n\n", json)).await
if tx.send(format!("data: {}\n\n", json)).await.is_err() {
tracing::info!("client hung up");
cancel_token.cancel();
}
}

async fn reply_handler(
Expand Down Expand Up @@ -144,6 +149,7 @@ async fn reply_handler(
error: "No agent configured".to_string(),
},
&task_tx,
&cancel_token,
)
.await;
return;
Expand Down Expand Up @@ -173,11 +179,12 @@ async fn reply_handler(
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = stream_event(
stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&task_tx,
&cancel_token,
)
.await;
return;
Expand All @@ -194,88 +201,69 @@ async fn reply_handler(
error: format!("Failed to get session path: {}", e),
},
&task_tx,
&cancel_token,
)
.await;
return;
}
};
let saved_message_count = all_messages.len();

let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500));
loop {
tokio::select! {
_ = task_cancel.cancelled() => {
tracing::info!("Agent task cancelled");
_ = task_cancel.cancelled() => {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is mostly a whitespace change, so I'd suggest turning off whitespace changes for this diff view

annoyingly, rustfmt doens't format code in macros, so this being inside a tokio::select!, didn't get formatted before, and had a bunch of weird indentation

tracing::info!("Agent task cancelled");
break;
}
_ = heartbeat_interval.tick() => {
stream_event(MessageEvent::Ping, &tx, &cancel_token).await;
}
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(AgentEvent::Message(message)))) => {
push_message(&mut all_messages, message.clone());
stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await;
}
Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => {
// Replace the message history with the compacted messages
all_messages = new_messages;
// Note: We don't send this as a stream event since it's an internal operation
// The client will see the compaction notification message that was sent before this event
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await;
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx, &cancel_token).await;
}

Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
&cancel_token,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => {
if tx.is_closed() {
break;
}
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(AgentEvent::Message(message)))) => {
push_message(&mut all_messages, message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
}
Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => {
// Replace the message history with the compacted messages
all_messages = new_messages;
// Note: We don't send this as a stream event since it's an internal operation
// The client will see the compaction notification message that was sent before this event
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
tracing::error!("Error sending model change through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}

Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => {
if tx.is_closed() {
break;
}
continue;
}
}
}
continue;
}
}
}
}
}

if all_messages.len() > saved_message_count {
Expand All @@ -301,6 +289,7 @@ async fn reply_handler(
reason: "stop".to_string(),
},
&task_tx,
&cancel_token,
)
.await;
}));
Expand Down
Loading
Loading