Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 98 additions & 108 deletions crates/goose-cli/src/session/mod.rs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@ impl McpSpinners {
spinner.set_message(message.to_string());
}

pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
pub fn update(&mut self, token: &str, value: u32, total: Option<u32>, message: Option<&str>) {
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
if let Some(total) = total {
self.multi_bar.add(
ProgressBar::new((total * 100.0) as u64).with_style(
ProgressBar::new((total * 100) as u64).with_style(
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
.unwrap(),
),
Expand All @@ -812,7 +812,7 @@ impl McpSpinners {
self.multi_bar.add(ProgressBar::new_spinner())
}
});
bar.set_position((value * 100.0) as u64);
bar.set_position((value * 100) as u64);
if let Some(msg) = message {
bar.set_message(msg.to_string());
}
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-mcp/src/developer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ impl DeveloperRouter {
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"level": "info",
"data": {
"type": "shell",
"stream": "stdout",
Expand All @@ -698,6 +699,7 @@ impl DeveloperRouter {
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"level": "info",
"data": {
"type": "shell",
"stream": "stderr",
Expand Down
4 changes: 2 additions & 2 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use goose::{
session,
};
use mcp_core::ToolResult;
use rmcp::model::{Content, JsonRpcMessage};
use rmcp::model::{Content, ServerNotification};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
Expand Down Expand Up @@ -97,7 +97,7 @@ enum MessageEvent {
},
Notification {
request_id: String,
message: JsonRpcMessage,
message: ServerNotification,
},
}

Expand Down
11 changes: 5 additions & 6 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ use crate::tool_monitor::{ToolCall, ToolMonitor};
use crate::utils::is_token_cancelled;
use mcp_core::{ToolError, ToolResult};
use regex::Regex;
use rmcp::model::Tool;
use rmcp::model::{Content, GetPromptResult, JsonRpcMessage, Prompt};
use rmcp::model::{Content, GetPromptResult, Prompt, ServerNotification, Tool};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -83,7 +82,7 @@ pub struct Agent {
#[derive(Clone, Debug)]
pub enum AgentEvent {
Message(Message),
McpNotification((String, JsonRpcMessage)),
McpNotification((String, ServerNotification)),
ModelChange { model: String, mode: String },
}

Expand All @@ -94,19 +93,19 @@ impl Default for Agent {
}

pub enum ToolStreamItem<T> {
Message(JsonRpcMessage),
Message(ServerNotification),
Result(T),
}

pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;

// tool_stream combines a stream of JsonRpcMessages with a future representing the
// tool_stream combines a stream of ServerNotifications with a future representing the
// final result of the tool call. MCP notifications are not request-scoped, but
// this lets us capture all notifications emitted during the tool call for
// simpler consumption
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
S: Stream<Item = ServerNotification> + Send + Unpin + 'static,
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
{
Box::pin(async_stream::stream! {
Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ mod tests {
CallToolResult, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult,
ReadResourceResult,
};
use rmcp::model::{GetPromptResult, JsonRpcMessage};
use rmcp::model::{GetPromptResult, ServerNotification};
use serde_json::json;
use tokio::sync::mpsc;

Expand Down Expand Up @@ -891,7 +891,7 @@ mod tests {
Err(Error::NotInitialized)
}

async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
mpsc::channel(1).1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::agents::subagent_execution_tool::task_execution_tracker::{
use crate::agents::subagent_execution_tool::tasks::process_task;
use crate::agents::subagent_execution_tool::workers::spawn_worker;
use crate::agents::subagent_task_config::TaskConfig;
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tokio::sync::mpsc;
Expand All @@ -20,7 +20,7 @@ const DEFAULT_MAX_WORKERS: usize = 10;

pub async fn execute_single_task(
task: &Task,
notifier: mpsc::Sender<JsonRpcMessage>,
notifier: mpsc::Sender<ServerNotification>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {
Expand Down Expand Up @@ -56,7 +56,7 @@ pub async fn execute_single_task(

pub async fn execute_tasks_in_parallel(
tasks: Vec<Task>,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {
Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/agents/subagent_execution_tool/lib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use crate::agents::subagent_execution_tool::{
tasks_manager::TasksManager,
};
use crate::agents::subagent_task_config::TaskConfig;
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use serde_json::{json, Value};
use tokio::sync::mpsc::Sender;
use tokio_util::sync::CancellationToken;

pub async fn execute_tasks(
input: Value,
execution_mode: ExecutionMode,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
task_config: TaskConfig,
tasks_manager: &TasksManager,
cancellation_token: Option<CancellationToken>,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use mcp_core::ToolError;
use rmcp::model::{Content, Tool, ToolAnnotations};
use rmcp::model::{Content, ServerNotification, Tool, ToolAnnotations};
use serde_json::Value;

use crate::agents::subagent_task_config::TaskConfig;
Expand All @@ -8,7 +8,6 @@ use crate::agents::{
subagent_execution_tool::task_types::ExecutionMode,
subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult,
};
use rmcp::model::JsonRpcMessage;
use rmcp::object;
use tokio::sync::mpsc;
use tokio_stream;
Expand Down Expand Up @@ -67,7 +66,7 @@ pub async fn run_tasks(
tasks_manager: &TasksManager,
cancellation_token: Option<CancellationToken>,
) -> ToolCallResult {
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
let (notification_tx, notification_rx) = mpsc::channel::<ServerNotification>(100);

let tasks_manager_clone = tasks_manager.clone();
let result_future = async move {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification};
use rmcp::object;
use rmcp::model::{
LoggingLevel, LoggingMessageNotification, LoggingMessageNotificationMethod,
LoggingMessageNotificationParam, ServerNotification,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
Expand Down Expand Up @@ -52,7 +54,7 @@ fn format_task_metadata(task_info: &TaskInfo) -> String {
pub struct TaskExecutionTracker {
tasks: Arc<RwLock<HashMap<String, TaskInfo>>>,
last_refresh: Arc<RwLock<Instant>>,
notifier: mpsc::Sender<JsonRpcMessage>,
notifier: mpsc::Sender<ServerNotification>,
display_mode: DisplayMode,
cancellation_token: Option<CancellationToken>,
}
Expand All @@ -61,7 +63,7 @@ impl TaskExecutionTracker {
pub fn new(
tasks: Vec<Task>,
display_mode: DisplayMode,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
cancellation_token: Option<CancellationToken>,
) -> Self {
let task_map = tasks
Expand Down Expand Up @@ -97,7 +99,7 @@ impl TaskExecutionTracker {

fn log_notification_error(
&self,
error: &mpsc::error::TrySendError<JsonRpcMessage>,
error: &mpsc::error::TrySendError<ServerNotification>,
context: &str,
) {
if !self.is_cancelled() {
Expand All @@ -108,16 +110,17 @@ impl TaskExecutionTracker {
fn try_send_notification(&self, event: TaskExecutionNotificationEvent, context: &str) {
if let Err(e) = self
.notifier
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: JsonRpcVersion2_0,
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"data": event.to_notification_data()
}),
.try_send(ServerNotification::LoggingMessageNotification(
LoggingMessageNotification {
method: LoggingMessageNotificationMethod,
params: LoggingMessageNotificationParam {
data: event.to_notification_data(),
level: LoggingLevel::Info,
logger: None,
},
extensions: Default::default(),
},
}))
))
{
self.log_notification_error(&e, context);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/agents/tool_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;
use async_stream::try_stream;
use futures::stream::{self, BoxStream};
use futures::{Stream, StreamExt};
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;

Expand All @@ -19,7 +19,7 @@ use rmcp::model::Content;
// can be used to receive notifications from the tool.
pub struct ToolCallResult {
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
pub notification_stream: Option<Box<dyn Stream<Item = ServerNotification> + Send + Unpin>>,
}

impl From<ToolResult<Vec<Content>>> for ToolCallResult {
Expand Down
26 changes: 20 additions & 6 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use mcp_core::protocol::{
use rmcp::model::{
GetPromptResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest,
JsonRpcResponse, JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId,
ServerNotification,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
Expand Down Expand Up @@ -106,7 +107,7 @@ pub trait McpClientTrait: Send + Sync {

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;

async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification>;
}

/// The MCP client is the interface for MCP operations.
Expand All @@ -118,7 +119,7 @@ where
next_id_counter: AtomicU64, // Added for atomic ID generation
server_capabilities: Option<ServerCapabilities>,
server_info: Option<Implementation>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<ServerNotification>>>>,
}

impl<T> McpClient<T>
Expand All @@ -129,7 +130,7 @@ where
let service = McpService::new(transport.clone());
let service_ptr = service.clone();
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));
let subscribers_ptr = notification_subscribers.clone();

tokio::spawn(async move {
Expand All @@ -148,9 +149,22 @@ where
}) => {
service_ptr.respond(&id.to_string(), Ok(message)).await;
}
_ => {
JsonRpcMessage::Notification(JsonRpcNotification {
notification,
..
}) => {
let mut subs = subscribers_ptr.lock().await;
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably the only change of any significance -- here we'd just forward all "other" messages to notifications subscribers, without checking that they're necessarily notifications. Then on the other end, we'd ignore them because they aren't the types we're looking for. So the result should be the same, but worth noting.

if let Some(server_notification) = notification.into() {
subs.retain(|sub| {
sub.try_send(server_notification.clone()).is_ok()
});
}
}
_ => {
tracing::warn!(
"Received unexpected received message type: {:?}",
message
);
}
}
}
Expand Down Expand Up @@ -437,7 +451,7 @@ where
self.send_request("prompts/get", params).await
}

async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
let (tx, rx) = mpsc::channel(16);
self.notification_subscribers.lock().await.push(tx);
rx
Expand Down
18 changes: 11 additions & 7 deletions crates/mcp-client/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use tokio::sync::{oneshot, RwLock};
use tower::{timeout::Timeout, Service, ServiceBuilder};

use crate::transport::{Error, TransportHandle};
use crate::transport::{Error, TransportHandle, TransportMessageRecv};

/// A wrapper service that implements Tower's Service trait for MCP transport
#[derive(Clone)]
Expand All @@ -23,7 +23,7 @@ impl<T: TransportHandle> McpService<T> {
}
}

pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
self.pending_requests.respond(id, response).await
}

Expand All @@ -36,7 +36,7 @@ impl<T> Service<JsonRpcMessage> for McpService<T>
where
T: TransportHandle + Send + Sync + 'static,
{
type Response = JsonRpcMessage;
type Response = TransportMessageRecv;
type Error = Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

Expand All @@ -63,7 +63,7 @@ where
// Handle notifications without waiting for a response
transport.send(request).await?;
// Return a dummy response for notifications
let dummy_response: JsonRpcMessage =
let dummy_response: Self::Response =
JsonRpcMessage::Response(rmcp::model::JsonRpcResponse {
jsonrpc: rmcp::model::JsonRpcVersion2_0,
id: rmcp::model::RequestId::Number(0),
Expand Down Expand Up @@ -91,7 +91,7 @@ where

// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
requests: RwLock<HashMap<String, oneshot::Sender<Result<TransportMessageRecv, Error>>>>,
}

impl Default for PendingRequests {
Expand All @@ -107,11 +107,15 @@ impl PendingRequests {
}
}

pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
pub async fn insert(
&self,
id: String,
sender: oneshot::Sender<Result<TransportMessageRecv, Error>>,
) {
self.requests.write().await.insert(id, sender);
}

pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}
Expand Down
Loading
Loading