Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::agents::subagent_execution_tool::tasks_manager::TasksManager;
use crate::agents::tool_route_manager::ToolRouteManager;
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
use crate::agents::types::SessionConfig;
use crate::agents::types::{FrontendTool, ToolResultReceiver};
use crate::agents::types::{FrontendTool, SharedProvider, ToolResultReceiver};
use crate::config::{get_enabled_extensions, Config};
use crate::context_mgmt::DEFAULT_COMPACTION_THRESHOLD;
use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation};
Expand Down Expand Up @@ -86,7 +86,8 @@ pub struct ToolCategorizeResult {

/// The main goose Agent
pub struct Agent {
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
pub(super) provider: SharedProvider,

pub extension_manager: Arc<ExtensionManager>,
pub(super) sub_recipe_manager: Mutex<SubRecipeManager>,
pub(super) tasks_manager: TasksManager,
Expand Down Expand Up @@ -159,10 +160,11 @@ impl Agent {
// Create channels with buffer size 32 (adjust if needed)
let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
let provider = Arc::new(Mutex::new(None));

Self {
provider: Mutex::new(None),
extension_manager: Arc::new(ExtensionManager::new()),
provider: provider.clone(),
extension_manager: Arc::new(ExtensionManager::new(provider.clone())),
sub_recipe_manager: Mutex::new(SubRecipeManager::new()),
tasks_manager: TasksManager::new(),
final_output_tool: Arc::new(Mutex::new(None)),
Expand Down
34 changes: 24 additions & 10 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rmcp::transport::{
TokioChildProcess,
};
use std::collections::HashMap;
use std::option::Option;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -29,6 +30,7 @@ use super::extension::{
ToolInfo, PLATFORM_EXTENSIONS,
};
use super::tool_execution::ToolCallResult;
use super::types::SharedProvider;
use crate::agents::extension::{Envs, ProcessExit};
use crate::agents::extension_malware_check;
use crate::agents::mcp_client::{McpClient, McpClientTrait};
Expand Down Expand Up @@ -91,6 +93,7 @@ impl Extension {
pub struct ExtensionManager {
extensions: Mutex<HashMap<String, Extension>>,
context: Mutex<PlatformExtensionContext>,
provider: SharedProvider,
}

/// A flattened representation of a resource used by the agent to prepare inference
Expand Down Expand Up @@ -171,13 +174,14 @@ pub fn get_parameter_names(tool: &Tool) -> Vec<String> {

impl Default for ExtensionManager {
fn default() -> Self {
Self::new()
Self::new(Arc::new(Mutex::new(None)))
}
}

async fn child_process_client(
mut command: Command,
timeout: &Option<u64>,
provider: SharedProvider,
) -> ExtensionResult<McpClient> {
#[cfg(unix)]
command.process_group(0);
Expand Down Expand Up @@ -205,6 +209,7 @@ async fn child_process_client(
let client_result = McpClient::connect(
transport,
Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)),
provider,
)
.await;

Expand Down Expand Up @@ -243,17 +248,23 @@ fn extract_auth_error(
}

impl ExtensionManager {
pub fn new() -> Self {
pub fn new(provider: SharedProvider) -> Self {
Self {
extensions: Mutex::new(HashMap::new()),
context: Mutex::new(PlatformExtensionContext {
session_id: None,
extension_manager: None,
tool_route_manager: None,
}),
provider,
}
}

/// Create a new ExtensionManager with no provider (useful for tests)
pub fn new_without_provider() -> Self {
Self::new(Arc::new(Mutex::new(None)))
}

pub async fn set_context(&self, context: PlatformExtensionContext) {
*self.context.lock().await = context;
}
Expand Down Expand Up @@ -348,6 +359,7 @@ impl ExtensionManager {
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
self.provider.clone(),
)
.await?,
)
Expand Down Expand Up @@ -388,6 +400,7 @@ impl ExtensionManager {
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
self.provider.clone(),
)
.await;
let client = if let Some(_auth_error) = extract_auth_error(&client_res) {
Expand All @@ -407,6 +420,7 @@ impl ExtensionManager {
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
self.provider.clone(),
)
.await?
} else {
Expand All @@ -430,7 +444,7 @@ impl ExtensionManager {
// Check for malicious packages before launching the process
extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?;

let client = child_process_client(command, timeout).await?;
let client = child_process_client(command, timeout, self.provider.clone()).await?;
Box::new(client)
}
ExtensionConfig::Builtin {
Expand Down Expand Up @@ -459,7 +473,7 @@ impl ExtensionManager {
let command = Command::new(cmd).configure(|command| {
command.arg("mcp").arg(name);
});
let client = child_process_client(command, timeout).await?;
let client = child_process_client(command, timeout, self.provider.clone()).await?;
Box::new(client)
}
ExtensionConfig::Platform { name, .. } => {
Expand Down Expand Up @@ -495,7 +509,7 @@ impl ExtensionManager {
command.arg("python").arg(file_path.to_str().unwrap());
});

let client = child_process_client(command, timeout).await?;
let client = child_process_client(command, timeout, self.provider.clone()).await?;

Box::new(client)
}
Expand Down Expand Up @@ -1252,7 +1266,7 @@ mod tests {

#[tokio::test]
async fn test_get_client_for_tool() {
let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

// Add some mock clients using the helper method
extension_manager
Expand Down Expand Up @@ -1312,7 +1326,7 @@ mod tests {
async fn test_dispatch_tool_call() {
// test that dispatch_tool_call parses out the sanitized name correctly, and extracts
// tool_names
let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

// Add some mock clients using the helper method
extension_manager
Expand Down Expand Up @@ -1429,7 +1443,7 @@ mod tests {

#[tokio::test]
async fn test_tool_availability_filtering() {
let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

// Only "available_tool" should be available to the LLM
let available_tools = vec!["available_tool".to_string()];
Expand Down Expand Up @@ -1457,7 +1471,7 @@ mod tests {

#[tokio::test]
async fn test_tool_availability_defaults_to_available() {
let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

extension_manager
.add_mock_extension_with_tools(
Expand All @@ -1482,7 +1496,7 @@ mod tests {

#[tokio::test]
async fn test_dispatch_unavailable_tool_returns_error() {
let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

let available_tools = vec!["available_tool".to_string()];

Expand Down
107 changes: 97 additions & 10 deletions crates/goose/src/agents/mcp_client.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
use rmcp::model::JsonObject;
use crate::agents::types::SharedProvider;
use rmcp::model::{Content, ErrorCode, JsonObject};
/// MCP client implementation for Goose
use rmcp::{
model::{
CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
CancelledNotificationMethod, CancelledNotificationParam, ClientCapabilities, ClientInfo,
ClientRequest, GetPromptRequest, GetPromptRequestParam, GetPromptResult, Implementation,
InitializeResult, ListPromptsRequest, ListPromptsResult, ListResourcesRequest,
ListResourcesResult, ListToolsRequest, ListToolsResult, LoggingMessageNotification,
ClientRequest, CreateMessageRequestParam, CreateMessageResult, GetPromptRequest,
GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult,
ListToolsRequest, ListToolsResult, LoggingMessageNotification,
LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification,
ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam,
ReadResourceResult, RequestId, ServerNotification, ServerResult,
ReadResourceResult, RequestId, Role, SamplingMessage, ServerNotification, ServerResult,
},
service::{
ClientInitializeError, PeerRequestOptions, RequestHandle, RunningService, ServiceRole,
ClientInitializeError, PeerRequestOptions, RequestContext, RequestHandle, RunningService,
ServiceRole,
},
transport::IntoTransport,
ClientHandler, Peer, RoleClient, ServiceError, ServiceExt,
ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt,
};
use serde_json::Value;
use std::{sync::Arc, time::Duration};
Expand Down Expand Up @@ -76,12 +79,17 @@ pub trait McpClientTrait: Send + Sync {

pub struct GooseClient {
notification_handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
provider: SharedProvider,
}

impl GooseClient {
pub fn new(handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>) -> Self {
pub fn new(
handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
provider: SharedProvider,
) -> Self {
GooseClient {
notification_handlers: handlers,
provider,
}
}
}
Expand Down Expand Up @@ -127,10 +135,88 @@ impl ClientHandler for GooseClient {
});
}

async fn create_message(
&self,
params: CreateMessageRequestParam,
_context: RequestContext<RoleClient>,
) -> Result<CreateMessageResult, ErrorData> {
let provider = self
.provider
.lock()
.await
.as_ref()
.ok_or(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Could not use provider",
None,
))?
.clone();

let provider_ready_messages: Vec<crate::conversation::message::Message> = params
.messages
.iter()
.map(|msg| {
let base = match msg.role {
Role::User => crate::conversation::message::Message::user(),
Role::Assistant => crate::conversation::message::Message::assistant(),
};

match msg.content.as_text() {
Some(text) => base.with_text(&text.text),
None => base.with_content(msg.content.clone().into()),
}
})
.collect();

let system_prompt = params
.system_prompt
.as_deref()
.unwrap_or("You are a general-purpose AI agent called goose");

let (response, usage) = provider
.complete(system_prompt, &provider_ready_messages, &[])
.await
.map_err(|e| {
ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Unexpected error while completing the prompt",
Some(Value::from(e.to_string())),
)
})?;

Ok(CreateMessageResult {
model: usage.model,
stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
message: SamplingMessage {
role: Role::Assistant,
// TODO(alexhancock): MCP sampling currently only supports one content on each SamplingMessage
// https://modelcontextprotocol.io/specification/draft/client/sampling#messages
// This doesn't mesh well with goose's approach which has Vec<MessageContent>
// There is a proposal to MCP which is agreed to go in the next version to have SamplingMessages support multiple content parts
// https://github.com/modelcontextprotocol/modelcontextprotocol/pull/198
// Until that is formalized, we can take the first message content from the provider and use it
content: if let Some(content) = response.content.first() {
match content {
crate::conversation::message::MessageContent::Text(text) => {
Content::text(&text.text)
}
crate::conversation::message::MessageContent::Image(img) => {
Content::image(&img.data, &img.mime_type)
}
// TODO(alexhancock) - Content::Audio? goose's messages don't currently have it
_ => Content::text(""),
}
} else {
Content::text("")
},
},
})
}

fn get_info(&self) -> ClientInfo {
ClientInfo {
protocol_version: ProtocolVersion::V_2025_03_26,
capabilities: ClientCapabilities::builder().build(),
capabilities: ClientCapabilities::builder().enable_sampling().build(),
client_info: Implementation {
name: "goose".to_string(),
version: std::env::var("GOOSE_MCP_CLIENT_VERSION")
Expand All @@ -155,6 +241,7 @@ impl McpClient {
pub async fn connect<T, E, A>(
transport: T,
timeout: std::time::Duration,
provider: SharedProvider,
) -> Result<Self, ClientInitializeError>
where
T: IntoTransport<RoleClient, E, A>,
Expand All @@ -163,7 +250,7 @@ impl McpClient {
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));

let client = GooseClient::new(notification_subscribers.clone());
let client = GooseClient::new(notification_subscribers.clone(), provider);
let client: rmcp::service::RunningService<rmcp::RoleClient, GooseClient> =
client.serve(transport).await?;
let server_info = client.peer_info().cloned();
Expand Down
4 changes: 4 additions & 0 deletions crates/goose/src/agents/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::mcp_utils::ToolResult;
use crate::providers::base::Provider;
use rmcp::model::{Content, Tool};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
Expand All @@ -9,6 +10,9 @@ use utoipa::ToSchema;
/// Type alias for the tool result channel receiver
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;

// We use double Arc here to allow easy provider swaps while sharing concurrent access
pub type SharedProvider = Arc<Mutex<Option<Arc<dyn Provider>>>>;

/// Default timeout for retry operations (5 minutes)
pub const DEFAULT_RETRY_TIMEOUT_SECONDS: u64 = 300;

Expand Down
3 changes: 1 addition & 2 deletions crates/goose/tests/mcp_integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ async fn test_replayed_session(
bundled: Some(false),
available_tools: vec![],
};

let extension_manager = ExtensionManager::new();
let extension_manager = ExtensionManager::new_without_provider();

#[allow(clippy::redundant_closure_call)]
let result = (async || -> Result<(), Box<dyn std::error::Error>> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}}
STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}}
STDERR: 2025-09-27T04:13:30.409389Z  INFO goose_mcp::mcp_server_runner: Starting MCP server
STDERR: at crates/goose-mcp/src/mcp_server_runner.rs:18
STDERR:
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/tests/mcp_replays/github-mcp-serverstdio
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}}
STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}}
STDERR: GitHub MCP Server running on stdio
STDOUT: {"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-03-26","capabilities":{"logging":{},"prompts":{},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"github-mcp-server","version":"version"}}}
STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"}
Expand Down
Loading