diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e6e0f7da8cd3..d5dbcfb13bca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,6 +106,11 @@ jobs: - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: Install Dependencies + run: | + sudo apt update -y + sudo apt install -y libdbus-1-dev libxcb1-dev + - name: Cache Cargo artifacts uses: Swatinem/rust-cache@v2 diff --git a/Cargo.lock b/Cargo.lock index a27bb55693f2..d2096eb70287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2680,6 +2680,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" +[[package]] +name = "fs-err" +version = "3.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf68cef89750956493a66a10f512b9e58d9db21f2a573c079c0bdf1207a54a7" +dependencies = [ + "autocfg", +] + [[package]] name = "fs2" version = "0.4.3" @@ -2964,7 +2973,6 @@ dependencies = [ name = "goose" version = "1.19.0" dependencies = [ - "agent-client-protocol-schema", "ahash", "anyhow", "async-stream", @@ -2989,6 +2997,7 @@ dependencies = [ "etcetera 0.11.0", "fs2", "futures", + "goose-mcp", "ignore", "include_dir", "indexmap 2.12.1", @@ -3014,7 +3023,6 @@ dependencies = [ "regex", "reqwest 0.12.28", "rmcp", - "sacp", "schemars 1.2.0", "serde", "serde_json", @@ -3051,6 +3059,31 @@ dependencies = [ "zip 0.6.6", ] +[[package]] +name = "goose-acp" +version = "1.19.0" +dependencies = [ + "anyhow", + "assert-json-diff", + "axum 0.8.8", + "fs-err", + "futures", + "goose", + "goose-mcp", + "regex", + "rmcp", + "sacp", + "serde_json", + "tempfile", + "test-case", + "tokio", + "tokio-util", + "tower-http 0.6.8", + "tracing", + "url", + "wiremock", +] + [[package]] name = "goose-bench" version = "1.19.0" @@ -3078,7 +3111,6 @@ dependencies = [ name = "goose-cli" version = "1.19.0" dependencies = [ - "agent-client-protocol-schema", "anstream", "anyhow", "async-trait", @@ -3094,20 +3126,16 @@ dependencies = [ "etcetera 0.11.0", "futures", "goose", + "goose-acp", "goose-bench", "goose-mcp", "http 1.4.0", "indicatif", - "is-terminal", - "jsonschema", - "nix 0.30.1", - "once_cell", "open", "rand 0.8.5", "regex", "rmcp", "rustyline", - "sacp", "serde", "serde_json", "serde_yaml", @@ -3123,7 +3151,6 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", - "url", "urlencoding", "uuid", "webbrowser", @@ -3239,6 +3266,7 @@ dependencies = [ "utoipa", "uuid", "winreg 0.55.0", + "wiremock", ] [[package]] @@ -6517,9 +6545,9 @@ checksum = "dd29631678d6fb0903b69223673e122c32e9ae559d0960a38d574695ebc0ea15" [[package]] name = "sacp" -version = "10.0.0" +version = "10.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47c1b52b3ee79933b19f2ce71945eaa17ef91ee68444e6716d05e335763af1a4" +checksum = "704f40d3c269b30229c34093b658ec80c4fac103281654b3965249c592dd6fa6" dependencies = [ "agent-client-protocol-schema", "anyhow", @@ -7796,9 +7824,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.48.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" dependencies = [ "bytes", "libc", @@ -7886,9 +7914,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -8083,6 +8111,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "async-compression", + "base64 0.22.1", "bitflags 2.10.0", "bytes", "futures-core", @@ -8090,13 +8119,19 @@ dependencies = [ "http 1.4.0", "http-body 1.0.1", "http-body-util", + "http-range-header", + "httpdate", "iri-string", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", "tokio", "tokio-util", "tower 0.5.2", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -8470,14 +8505,15 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.7" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ea49b40cb0d7..35c51f769021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,11 +16,19 @@ string_slice = "warn" [workspace.dependencies] rmcp = { version = "0.12.0", features = ["schemars", "auth"] } -sacp = "10.0.0" +anyhow = "1.0" +futures = "0.3" +regex = "1.12" +serde_json = "1.0" +tokio = { version = "1.49", features = ["full"] } +tracing = "0.1" webbrowser = "1.0" which = "8.0.0" etcetera = "0.11.0" ignore = "0.4.25" +env-lock = "1.0.1" +wiremock = "0.6" +serial_test = "3.2.0" # Patch for Windows cross-compilation issue with crunchy [patch.crates-io] diff --git a/crates/goose-acp/Cargo.toml b/crates/goose-acp/Cargo.toml new file mode 100644 index 000000000000..234b5ea5eaf7 --- /dev/null +++ b/crates/goose-acp/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "goose-acp" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description.workspace = true + +[lints] +workspace = true + +[dependencies] +goose = { path = "../goose" } +goose-mcp = { path = "../goose-mcp" } +rmcp = { workspace = true } +sacp = "10.1.0" +anyhow = { workspace = true } +tokio = { workspace = true } +tokio-util = { version = "0.7.15", features = ["compat", "rt"] } +tracing = { workspace = true } +url = "2.5" +serde_json = { workspace = true } +futures = { workspace = true } +regex = { workspace = true } +fs-err = "3" + +[dev-dependencies] +assert-json-diff = "2.0.2" +wiremock = { workspace = true } +tempfile = "3" +test-case = "3.3" +axum = "0.8" +tower-http = { version = "0.6", features = ["cors", "fs", "auth"] } +rmcp = { workspace = true, features = ["transport-streamable-http-server"] } diff --git a/crates/goose-acp/src/lib.rs b/crates/goose-acp/src/lib.rs new file mode 100644 index 000000000000..74f47ad347da --- /dev/null +++ b/crates/goose-acp/src/lib.rs @@ -0,0 +1 @@ +pub mod server; diff --git a/crates/goose-cli/src/commands/acp.rs b/crates/goose-acp/src/server.rs similarity index 83% rename from crates/goose-cli/src/commands/acp.rs rename to crates/goose-acp/src/server.rs index 77d9308a21e6..c3b73c561e80 100644 --- a/crates/goose-cli/src/commands/acp.rs +++ b/crates/goose-acp/src/server.rs @@ -1,13 +1,17 @@ use anyhow::Result; -use goose::agents::extension::{Envs, PlatformExtensionContext, PLATFORM_EXTENSIONS}; -use goose::agents::{Agent, ExtensionConfig, SessionConfig}; -use goose::config::{get_all_extensions, Config}; +use fs_err as fs; +use goose::agents::extension::{Envs, PLATFORM_EXTENSIONS}; +use goose::agents::{Agent, AgentConfig, ExtensionConfig, SessionConfig}; +use goose::config::paths::Paths; +use goose::config::permission::PermissionManager; +use goose::config::Config; use goose::conversation::message::{ActionRequiredData, Message, MessageContent}; use goose::conversation::Conversation; use goose::mcp_utils::ToolResult; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::{Permission, PermissionConfirmation}; use goose::providers::create; +use goose::scheduler_trait::unavailable_scheduler; use goose::session::session_manager::SessionType; use goose::session::SessionManager; use rmcp::model::{CallToolResult, RawContent, ResourceContents, Role}; @@ -16,18 +20,16 @@ use sacp::schema::{ CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource, EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse, LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest, - NewSessionResponse, PermissionOption, PermissionOptionId, PermissionOptionKind, - PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome, - RequestPermissionRequest, ResourceLink, SessionId, SessionNotification, SessionUpdate, - StopReason, TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, - ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind, + NewSessionResponse, PermissionOption, PermissionOptionKind, PromptCapabilities, PromptRequest, + PromptResponse, RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, SessionId, + SessionNotification, SessionUpdate, StopReason, TextContent, TextResourceContents, ToolCall, + ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, + ToolCallUpdateFields, ToolKind, }; use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx}; -use std::collections::{HashMap, HashSet}; -use std::fs; +use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::task::JoinSet; use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; @@ -39,11 +41,21 @@ struct GooseAcpSession { cancel_token: Option, } -struct GooseAcpAgent { +pub struct GooseAcpAgent { sessions: Arc>>, agent: Arc, } +/// Configuration for GooseAcpAgent (library mode - no global config) +pub struct GooseAcpConfig { + pub provider: Arc, + pub builtins: Vec, + pub work_dir: std::path::PathBuf, + pub data_dir: std::path::PathBuf, + pub config_dir: std::path::PathBuf, + pub goose_mode: goose::config::GooseMode, +} + fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result { match mcp_server { McpServer::Stdio(stdio) => Ok(ExtensionConfig::Stdio { @@ -259,15 +271,16 @@ async fn add_builtins(agent: &Agent, builtins: Vec) { available_tools: Vec::new(), } }; + match agent.add_extension(config).await { - Ok(_) => info!(extension = %builtin, "builtin extension loaded"), - Err(e) => warn!(extension = %builtin, error = %e, "builtin extension load failed"), + Ok(_) => info!(extension = %builtin, "extension loaded"), + Err(e) => warn!(extension = %builtin, error = %e, "extension load failed"), } } } impl GooseAcpAgent { - async fn new(builtins: Vec) -> Result { + pub async fn new(builtins: Vec) -> Result { let config = Config::global(); let provider_name: String = config @@ -288,65 +301,36 @@ impl GooseAcpAgent { fast_model: None, }; let provider = create(&provider_name, model_config).await?; + let goose_mode = config + .get_goose_mode() + .unwrap_or(goose::config::GooseMode::Approve); + + Self::with_config(GooseAcpConfig { + provider, + builtins, + work_dir: std::env::current_dir().unwrap_or_default(), + data_dir: Paths::data_dir(), + config_dir: Paths::config_dir(), + goose_mode, + }) + .await + } - let session = SessionManager::create_session( - std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), - "ACP Session".to_string(), - SessionType::Hidden, - ) - .await?; - - let agent = Agent::new(); - agent.update_provider(provider.clone(), &session.id).await?; + pub async fn with_config(config: GooseAcpConfig) -> Result { + let session_manager = Arc::new(SessionManager::new(config.data_dir)); + let permission_manager = Arc::new(PermissionManager::new(config.config_dir)); - let extensions_to_run: Vec<_> = get_all_extensions() - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect(); + let agent = Agent::with_config(AgentConfig::new( + Arc::clone(&session_manager), + permission_manager, + config.goose_mode, + unavailable_scheduler(), + )); + agent.set_provider(config.provider.clone()).await; let agent_ptr = Arc::new(agent); - // ACP loads the same default extensions as CLI - agent_ptr - .extension_manager - .set_context(PlatformExtensionContext { - session_id: Some(session.id.clone()), - extension_manager: Some(Arc::downgrade(&agent_ptr.extension_manager)), - }) - .await; - - let mut set = JoinSet::new(); - let mut waiting_on = HashSet::new(); - - for extension in extensions_to_run { - waiting_on.insert(extension.name()); - let agent_ptr_clone = agent_ptr.clone(); - set.spawn(async move { - ( - extension.name(), - agent_ptr_clone.add_extension(extension.clone()).await, - ) - }); - } - - while let Some(result) = set.join_next().await { - match result { - Ok((name, Ok(_))) => { - waiting_on.remove(&name); - info!(extension = %name, "extension loaded"); - } - Ok((name, Err(e))) => { - warn!(extension = %name, error = %e, "extension load failed"); - waiting_on.remove(&name); - } - Err(e) => { - error!(error = %e, "extension task error"); - } - } - } - - add_builtins(&agent_ptr, builtins).await; + add_builtins(&agent_ptr, config.builtins).await; Ok(Self { sessions: Arc::new(Mutex::new(HashMap::new())), @@ -407,7 +391,7 @@ impl GooseAcpAgent { cx.send_notification(SessionNotification::new( session_id.clone(), SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(&text.text), + TextContent::new(text.text.clone()), ))), ))?; } @@ -424,7 +408,7 @@ impl GooseAcpAgent { cx.send_notification(SessionNotification::new( session_id.clone(), SessionUpdate::AgentThoughtChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(&thinking.thinking), + TextContent::new(thinking.thinking.clone()), ))), ))?; } @@ -493,11 +477,10 @@ impl GooseAcpAgent { session: &mut GooseAcpSession, cx: &JrConnectionCx, ) -> Result<(), sacp::Error> { - // Determine if the tool call succeeded or failed - let status = if tool_response.tool_result.is_ok() { - ToolCallStatus::Completed - } else { - ToolCallStatus::Failed + let status = match &tool_response.tool_result { + Ok(result) if result.is_error == Some(true) => ToolCallStatus::Failed, + Ok(_) => ToolCallStatus::Completed, + Err(_) => ToolCallStatus::Failed, }; let content = build_tool_call_content(&tool_response.tool_result); @@ -559,12 +542,13 @@ impl GooseAcpAgent { .as_str() .unwrap() .to_string(); - PermissionOption::new(PermissionOptionId::from(id.clone()), id, kind) + PermissionOption::new(id.clone(), id, kind) } let options = vec![ option(PermissionOptionKind::AllowAlways), option(PermissionOptionKind::AllowOnce), option(PermissionOptionKind::RejectOnce), + option(PermissionOptionKind::RejectAlways), ]; let permission_request = @@ -607,13 +591,12 @@ fn outcome_to_confirmation(outcome: &RequestPermissionOutcome) -> PermissionConf RequestPermissionOutcome::Cancelled => Permission::Cancel, RequestPermissionOutcome::Selected(selected) => { match serde_json::from_value::(serde_json::Value::String( - selected.option_id.to_string(), + selected.option_id.0.to_string(), )) { Ok(PermissionOptionKind::AllowAlways) => Permission::AlwaysAllow, Ok(PermissionOptionKind::AllowOnce) => Permission::AllowOnce, - Ok(PermissionOptionKind::RejectOnce | PermissionOptionKind::RejectAlways) => { - Permission::DenyOnce - } + Ok(PermissionOptionKind::RejectOnce) => Permission::DenyOnce, + Ok(PermissionOptionKind::RejectAlways) => Permission::AlwaysDeny, Ok(_) => Permission::Cancel, // Handle any future permission kinds Err(_) => Permission::Cancel, } @@ -633,10 +616,10 @@ fn build_tool_call_content(tool_result: &ToolResult) -> Vec Some(ToolCallContent::Content(Content::new( - ContentBlock::Text(TextContent::new(&val.text)), + ContentBlock::Text(TextContent::new(val.text.clone())), ))), RawContent::Image(val) => Some(ToolCallContent::Content(Content::new( - ContentBlock::Image(ImageContent::new(&val.data, &val.mime_type)), + ContentBlock::Image(ImageContent::new(val.data.clone(), val.mime_type.clone())), ))), RawContent::Resource(val) => { let resource = match &val.resource { @@ -645,25 +628,19 @@ fn build_tool_call_content(tool_result: &ToolResult) -> Vec { - let mut r = TextResourceContents::new(text.clone(), uri.clone()); - if let Some(mt) = mime_type { - r = r.mime_type(mt.clone()); - } - EmbeddedResourceResource::TextResourceContents(r) - } + } => EmbeddedResourceResource::TextResourceContents( + TextResourceContents::new(text.clone(), uri.clone()) + .mime_type(mime_type.clone()), + ), ResourceContents::BlobResourceContents { mime_type, blob, uri, .. - } => { - let mut r = BlobResourceContents::new(blob.clone(), uri.clone()); - if let Some(mt) = mime_type { - r = r.mime_type(mt.clone()); - } - EmbeddedResourceResource::BlobResourceContents(r) - } + } => EmbeddedResourceResource::BlobResourceContents( + BlobResourceContents::new(blob.clone(), uri.clone()) + .mime_type(mime_type.clone()), + ), }; Some(ToolCallContent::Content(Content::new( ContentBlock::Resource(EmbeddedResource::new(resource)), @@ -709,45 +686,42 @@ impl GooseAcpAgent { ) -> Result { debug!(?args, "new session request"); - let goose_session = SessionManager::create_session( - std::env::current_dir().unwrap_or_default(), - "ACP Session".to_string(), // just an initial name - may be replaced by maybe_update_name - SessionType::User, - ) - .await - .map_err(|e| { - sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Failed to create session: {}", e), + let manager = self.agent.session_manager(); + let goose_session = manager + .create_session( + std::env::current_dir().unwrap_or_default(), + "ACP Session".to_string(), // just an initial name - may be replaced by maybe_update_name + SessionType::User, ) - })?; - - let session = GooseAcpSession { - messages: Conversation::new_unvalidated(Vec::new()), - tool_requests: HashMap::new(), - cancel_token: None, - }; - - let mut sessions = self.sessions.lock().await; - sessions.insert(goose_session.id.clone(), session); + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to create session: {}", e)) + })?; // Add MCP servers specified in the session request for mcp_server in args.mcp_servers { let config = match mcp_server_to_extension_config(mcp_server) { Ok(c) => c, Err(msg) => { - return Err(sacp::Error::new(sacp::ErrorCode::InvalidParams.into(), msg)); + return Err(sacp::Error::invalid_params().data(msg)); } }; let name = config.name().to_string(); if let Err(e) = self.agent.add_extension(config).await { - return Err(sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Failed to add MCP server '{}': {}", name, e), - )); + return Err(sacp::Error::internal_error() + .data(format!("Failed to add MCP server '{}': {}", name, e))); } } + let session = GooseAcpSession { + messages: Conversation::new_unvalidated(Vec::new()), + tool_requests: HashMap::new(), + cancel_token: None, + }; + + let mut sessions = self.sessions.lock().await; + sessions.insert(goose_session.id.clone(), session); + info!( session_id = %goose_session.id, session_type = "acp", @@ -766,31 +740,25 @@ impl GooseAcpAgent { let session_id = args.session_id.0.to_string(); - let goose_session = SessionManager::get_session(&session_id, true) - .await - .map_err(|e| { - sacp::Error::new( - sacp::ErrorCode::InvalidParams.into(), - format!("Failed to load session {}: {}", session_id, e), - ) - })?; + let manager = self.agent.session_manager(); + let goose_session = manager.get_session(&session_id, true).await.map_err(|e| { + sacp::Error::invalid_params() + .data(format!("Failed to load session {}: {}", session_id, e)) + })?; let conversation = goose_session.conversation.ok_or_else(|| { - sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Session {} has no conversation data", session_id), - ) + sacp::Error::internal_error() + .data(format!("Session {} has no conversation data", session_id)) })?; - SessionManager::update_session(&session_id) + manager + .update(&session_id) .working_dir(args.cwd.clone()) .apply() .await .map_err(|e| { - sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Failed to update session working directory: {}", e), - ) + sacp::Error::internal_error() + .data(format!("Failed to update session working directory: {}", e)) })?; let mut session = GooseAcpSession { @@ -809,8 +777,9 @@ impl GooseAcpAgent { for content_item in &message.content { match content_item { MessageContent::Text(text) => { - let chunk = - ContentChunk::new(ContentBlock::Text(TextContent::new(&text.text))); + let chunk = ContentChunk::new(ContentBlock::Text(TextContent::new( + text.text.clone(), + ))); let update = match message.role { Role::User => SessionUpdate::UserMessageChunk(chunk), Role::Assistant => SessionUpdate::AgentMessageChunk(chunk), @@ -837,7 +806,7 @@ impl GooseAcpAgent { cx.send_notification(SessionNotification::new( args.session_id.clone(), SessionUpdate::AgentThoughtChunk(ContentChunk::new( - ContentBlock::Text(TextContent::new(&thinking.thinking)), + ContentBlock::Text(TextContent::new(thinking.thinking.clone())), )), ))?; } @@ -871,10 +840,7 @@ impl GooseAcpAgent { { let mut sessions = self.sessions.lock().await; let session = sessions.get_mut(&session_id).ok_or_else(|| { - sacp::Error::new( - sacp::ErrorCode::InvalidParams.into(), - format!("Session not found: {}", session_id), - ) + sacp::Error::invalid_params().data(format!("Session not found: {}", session_id)) })?; session.cancel_token = Some(cancel_token.clone()); } @@ -893,10 +859,7 @@ impl GooseAcpAgent { .reply(user_message, session_config, Some(cancel_token.clone())) .await .map_err(|e| { - sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Error getting agent reply: {}", e), - ) + sacp::Error::internal_error().data(format!("Error getting agent reply: {}", e)) })?; use futures::StreamExt; @@ -913,10 +876,8 @@ impl GooseAcpAgent { Ok(goose::agents::AgentEvent::Message(message)) => { let mut sessions = self.sessions.lock().await; let session = sessions.get_mut(&session_id).ok_or_else(|| { - sacp::Error::new( - sacp::ErrorCode::InvalidParams.into(), - format!("Session not found: {}", session_id), - ) + sacp::Error::invalid_params() + .data(format!("Session not found: {}", session_id)) })?; session.messages.push(message.clone()); @@ -928,10 +889,8 @@ impl GooseAcpAgent { } Ok(_) => {} Err(e) => { - return Err(sacp::Error::new( - sacp::ErrorCode::InternalError.into(), - format!("Error in agent response stream: {}", e), - )); + return Err(sacp::Error::internal_error() + .data(format!("Error in agent response stream: {}", e))); } } } @@ -941,12 +900,11 @@ impl GooseAcpAgent { session.cancel_token = None; } - let stop_reason = if was_cancelled { + Ok(PromptResponse::new(if was_cancelled { StopReason::Cancelled } else { StopReason::EndTurn - }; - Ok(PromptResponse::new(stop_reason)) + })) } async fn on_cancel(&self, args: CancelNotification) -> Result<(), sacp::Error> { @@ -968,8 +926,8 @@ impl GooseAcpAgent { } } -struct GooseAcpHandler { - agent: Arc, +pub struct GooseAcpHandler { + pub agent: Arc, } impl JrMessageHandler for GooseAcpHandler { @@ -1041,48 +999,49 @@ impl JrMessageHandler for GooseAcpHandler { } } -pub async fn run_acp_agent(builtins: Vec) -> Result<()> { - info!("listening on stdio"); - - let outgoing = tokio::io::stdout().compat_write(); - let incoming = tokio::io::stdin().compat(); - - let agent = Arc::new(GooseAcpAgent::new(builtins).await?); +/// Serve ACP on a given transport (for in-process testing) +pub async fn serve(agent: Arc, read: R, write: W) -> Result<()> +where + R: futures::AsyncRead + Unpin + Send + 'static, + W: futures::AsyncWrite + Unpin + Send + 'static, +{ let handler = GooseAcpHandler { agent }; AgentToClient::builder() .name("goose-acp") .with_handler(handler) - .serve(ByteStreams::new(outgoing, incoming)) + .serve(ByteStreams::new(write, read)) .await?; Ok(()) } +pub async fn run(builtins: Vec) -> Result<()> { + info!("listening on stdio"); + + let outgoing = tokio::io::stdout().compat_write(); + let incoming = tokio::io::stdin().compat(); + + let agent = Arc::new(GooseAcpAgent::new(builtins).await?); + serve(agent, incoming, outgoing).await +} + #[cfg(test)] mod tests { use super::*; use sacp::schema::{ EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio, - ResourceLink, SelectedPermissionOutcome, + PermissionOptionId, ResourceLink, SelectedPermissionOutcome, }; use std::io::Write; use tempfile::NamedTempFile; use test_case::test_case; - use crate::commands::acp::{ - format_tool_name, mcp_server_to_extension_config, read_resource_link, - }; - use goose::agents::ExtensionConfig; - #[test_case( McpServer::Stdio( McpServerStdio::new("github", "/path/to/github-mcp-server") .args(vec!["stdio".into()]) - .env(vec![EnvVariable::new( - "GITHUB_PERSONAL_ACCESS_TOKEN", - "ghp_xxxxxxxxxxxx" - )]) + .env(vec![EnvVariable::new("GITHUB_PERSONAL_ACCESS_TOKEN", "ghp_xxxxxxxxxxxx")]) ), Ok(ExtensionConfig::Stdio { name: "github".into(), @@ -1195,27 +1154,27 @@ print(\"hello, world\") } #[test_case( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_once")), + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(PermissionOptionId::from("allow_once".to_string()))), PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AllowOnce }; "allow_once_maps_to_allow_once" )] #[test_case( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_always")), + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(PermissionOptionId::from("allow_always".to_string()))), PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AlwaysAllow }; "allow_always_maps_to_always_allow" )] #[test_case( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_once")), + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(PermissionOptionId::from("reject_once".to_string()))), PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce }; "reject_once_maps_to_deny_once" )] #[test_case( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_always")), - PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce }; - "reject_always_maps_to_deny_once" + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(PermissionOptionId::from("reject_always".to_string()))), + PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AlwaysDeny }; + "reject_always_maps_to_always_deny" )] #[test_case( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("unknown")), + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(PermissionOptionId::from("unknown".to_string()))), PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::Cancel }; "unknown_option_maps_to_cancel" )] diff --git a/crates/goose-acp/tests/common.rs b/crates/goose-acp/tests/common.rs new file mode 100644 index 000000000000..b2ab84911bb3 --- /dev/null +++ b/crates/goose-acp/tests/common.rs @@ -0,0 +1,127 @@ +use assert_json_diff::{assert_json_matches_no_panic, CompareMode, Config}; +use rmcp::transport::streamable_http_server::{ + session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService, +}; +use rmcp::{ + handler::server::router::tool::ToolRouter, model::*, tool, tool_handler, tool_router, + ErrorData as McpError, ServerHandler, +}; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; +use tokio::task::JoinHandle; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +pub const FAKE_CODE: &str = "test-uuid-12345-67890"; + +/// Mock OpenAI streaming endpoint. Exchanges are (pattern, response) pairs. +/// On mismatch, returns 417 of the diff in OpenAI error format. +pub async fn setup_mock_openai(exchanges: Vec<(String, &'static str)>) -> MockServer { + let mock_server = MockServer::start().await; + let queue: VecDeque<(String, &'static str)> = exchanges.into_iter().collect(); + let queue = Arc::new(Mutex::new(queue)); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with({ + let queue = queue.clone(); + move |req: &wiremock::Request| { + let body = String::from_utf8_lossy(&req.body); + + // Special case session rename request which doesn't happen in a predictable order. + if body.contains("Reply with only a description in four words or less") { + return ResponseTemplate::new(200) + .insert_header("content-type", "application/json") + .set_body_string(include_str!( + "./test_data/openai_session_description.json" + )); + } + + let (expected, response) = { + let mut q = queue.lock().unwrap(); + q.pop_front().unwrap_or_default() + }; + + if body.contains(&expected) && !expected.is_empty() { + return ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_string(response); + } + + // Coerce non-json to allow a uniform JSON diff error response. + let exp = serde_json::from_str(&expected) + .unwrap_or(serde_json::Value::String(expected.clone())); + let act = serde_json::from_str(&body) + .unwrap_or(serde_json::Value::String(body.to_string())); + let diff = + assert_json_matches_no_panic(&exp, &act, Config::new(CompareMode::Strict)) + .unwrap_err(); + ResponseTemplate::new(417) + .insert_header("content-type", "text/event-stream") + .set_body_json(serde_json::json!({"error": {"message": diff}})) + } + }) + .mount(&mock_server) + .await; + + mock_server +} + +#[derive(Clone)] +pub struct Lookup { + tool_router: ToolRouter, +} + +impl Default for Lookup { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] +impl Lookup { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + + #[tool(description = "Get the code")] + fn get_code(&self) -> Result { + Ok(CallToolResult::success(vec![Content::text(FAKE_CODE)])) + } +} + +#[tool_handler] +impl ServerHandler for Lookup { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2025_03_26, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "lookup".into(), + version: "1.0.0".into(), + ..Default::default() + }, + instructions: Some("Lookup server with get_code tool.".into()), + } + } +} + +pub async fn spawn_mcp_http_server() -> (String, JoinHandle<()>) { + let service = StreamableHttpService::new( + || Ok(Lookup::new()), + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let url = format!("http://{addr}/mcp"); + + let handle = tokio::spawn(async move { + axum::serve(listener, router).await.unwrap(); + }); + + (url, handle) +} diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs new file mode 100644 index 000000000000..160342c7cd8e --- /dev/null +++ b/crates/goose-acp/tests/server_test.rs @@ -0,0 +1,400 @@ +mod common; + +use common::{setup_mock_openai, spawn_mcp_http_server, FAKE_CODE}; +use fs_err as fs; +use goose::config::GooseMode; +use goose::model::ModelConfig; +use goose::providers::api_client::{ApiClient, AuthMethod}; +use goose::providers::openai::OpenAiProvider; +use goose_acp::server::{serve, GooseAcpAgent, GooseAcpConfig}; +use sacp::schema::{ + ContentBlock, ContentChunk, InitializeRequest, McpServer, McpServerHttp, NewSessionRequest, + PermissionOptionKind, PromptRequest, ProtocolVersion, RequestPermissionOutcome, + RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, + SessionNotification, SessionUpdate, StopReason, TextContent, ToolCallId, ToolCallStatus, + ToolCallUpdate, ToolCallUpdateFields, +}; +use sacp::{ClientToAgent, JrConnectionCx}; +use std::path::Path; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use test_case::test_case; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; +use wiremock::MockServer; + +#[tokio::test] +async fn test_acp_basic_completion() { + let temp_dir = tempfile::tempdir().unwrap(); + let prompt = "what is 1+1"; + let mock_server = setup_mock_openai(vec![( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_basic_response.txt"), + )]) + .await; + + run_acp_session( + &mock_server, + vec![], + &[], + temp_dir.path(), + GooseMode::Auto, + None, + |cx, session_id, updates| async move { + let response = cx + .send_request(PromptRequest::new( + session_id, + vec![ContentBlock::Text(TextContent::new(prompt))], + )) + .block_task() + .await + .unwrap(); + + assert_eq!(response.stop_reason, StopReason::EndTurn); + wait_for( + &updates, + &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( + TextContent::new("2"), + ))), + ) + .await; + }, + ) + .await; +} + +#[tokio::test] +async fn test_acp_with_mcp_http_server() { + let temp_dir = tempfile::tempdir().unwrap(); + let prompt = "Use the get_code tool and output only its result."; + let (mcp_url, _handle) = spawn_mcp_http_server().await; + + let mock_server = setup_mock_openai(vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("./test_data/openai_tool_result_response.txt"), + ), + ]) + .await; + + run_acp_session( + &mock_server, + vec![McpServer::Http(McpServerHttp::new("lookup", mcp_url))], + &[], + temp_dir.path(), + GooseMode::Auto, + None, + |cx, session_id, updates| async move { + let response = cx + .send_request(PromptRequest::new( + session_id, + vec![ContentBlock::Text(TextContent::new(prompt))], + )) + .block_task() + .await + .unwrap(); + + assert_eq!(response.stop_reason, StopReason::EndTurn); + wait_for( + &updates, + &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( + TextContent::new(FAKE_CODE), + ))), + ) + .await; + }, + ) + .await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_acp_with_builtin_and_mcp() { + let temp_dir = tempfile::tempdir().unwrap(); + let prompt = + "Search for get_code and text_editor tools. Use them to save the code to /tmp/result.txt."; + let (lookup_url, _lookup_handle) = spawn_mcp_http_server().await; + + let mock_server = setup_mock_openai(vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_builtin_search.txt"), + ), + ( + r#"lookup/get_code: Get the code"#.into(), + include_str!("./test_data/openai_builtin_read_modules.txt"), + ), + ( + r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), + include_str!("./test_data/openai_builtin_execute.txt"), + ), + ( + r#"Successfully wrote to /tmp/result.txt"#.into(), + include_str!("./test_data/openai_builtin_final.txt"), + ), + ]) + .await; + + run_acp_session( + &mock_server, + vec![McpServer::Http(McpServerHttp::new("lookup", lookup_url))], + &["code_execution", "developer"], + temp_dir.path(), + GooseMode::Auto, + None, + |cx, session_id, updates| async move { + let response = cx + .send_request(PromptRequest::new( + session_id, + vec![ContentBlock::Text(TextContent::new(prompt))], + )) + .block_task() + .await + .unwrap(); + + assert_eq!(response.stop_reason, StopReason::EndTurn); + wait_for( + &updates, + &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( + TextContent::new(FAKE_CODE), + ))), + ) + .await; + }, + ) + .await; +} + +async fn wait_for(updates: &Arc>>, expected: &SessionUpdate) { + let deadline = tokio::time::Instant::now() + Duration::from_millis(500); + let mut context = String::new(); + + loop { + let matched = { + let guard = updates.lock().unwrap(); + context.clear(); + + match expected { + SessionUpdate::AgentMessageChunk(chunk) => { + let expected_text = match &chunk.content { + ContentBlock::Text(t) => &t.text, + other => panic!("wait_for: unhandled content {:?}", other), + }; + for n in guard.iter() { + if let SessionUpdate::AgentMessageChunk(c) = &n.update { + if let ContentBlock::Text(t) = &c.content { + if t.text.is_empty() { + context.clear(); + } else { + context.push_str(&t.text); + } + } + } + } + context.contains(expected_text) + } + SessionUpdate::ToolCallUpdate(expected_update) => { + for n in guard.iter() { + if let SessionUpdate::ToolCallUpdate(u) = &n.update { + context.push_str(&format!("{:?}\n", u)); + if u.fields.status == expected_update.fields.status { + return; + } + } + } + false + } + other => panic!("wait_for: unhandled update {:?}", other), + } + }; + + if matched { + return; + } + if tokio::time::Instant::now() > deadline { + panic!("Timeout waiting for {:?}\n\n{}", expected, context); + } + tokio::task::yield_now().await; + } +} + +async fn spawn_server_in_process( + mock_server: &MockServer, + builtins: &[&str], + data_root: &Path, + goose_mode: GooseMode, +) -> ( + tokio::io::DuplexStream, + tokio::io::DuplexStream, + tokio::task::JoinHandle<()>, +) { + let api_client = ApiClient::new( + mock_server.uri(), + AuthMethod::BearerToken("test-key".to_string()), + ) + .unwrap(); + let model_config = ModelConfig::new("gpt-5-nano").unwrap(); + let provider = OpenAiProvider::new(api_client, model_config); + + let config = GooseAcpConfig { + provider: Arc::new(provider), + builtins: builtins.iter().map(|s| s.to_string()).collect(), + work_dir: data_root.to_path_buf(), + data_dir: data_root.to_path_buf(), + config_dir: data_root.to_path_buf(), + goose_mode, + }; + + let (client_read, server_write) = tokio::io::duplex(64 * 1024); + let (server_read, client_write) = tokio::io::duplex(64 * 1024); + + let agent = Arc::new(GooseAcpAgent::with_config(config).await.unwrap()); + let handle = tokio::spawn(async move { + if let Err(e) = serve(agent, server_read.compat(), server_write.compat_write()).await { + tracing::error!("ACP server error: {e}"); + } + }); + + (client_read, client_write, handle) +} + +async fn run_acp_session( + mock_server: &MockServer, + mcp_servers: Vec, + builtins: &[&str], + data_root: &Path, + mode: GooseMode, + select: Option, + test_fn: F, +) where + F: FnOnce( + JrConnectionCx, + sacp::schema::SessionId, + Arc>>, + ) -> Fut, + Fut: std::future::Future, +{ + let (client_read, client_write, _handle) = + spawn_server_in_process(mock_server, builtins, data_root, mode).await; + let work_dir = tempfile::tempdir().unwrap(); + let updates = Arc::new(Mutex::new(Vec::new())); + + let transport = sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); + + ClientToAgent::builder() + .on_receive_notification( + { + let updates = updates.clone(); + async move |notification: SessionNotification, _cx| { + updates.lock().unwrap().push(notification); + Ok(()) + } + }, + sacp::on_receive_notification!(), + ) + .on_receive_request( + async move |req: RequestPermissionRequest, request_cx, _connection_cx| { + let response = match select { + Some(kind) => { + let id = req + .options + .iter() + .find(|o| o.kind == kind) + .unwrap() + .option_id + .clone(); + RequestPermissionResponse::new(RequestPermissionOutcome::Selected( + SelectedPermissionOutcome::new(id), + )) + } + None => RequestPermissionResponse::new(RequestPermissionOutcome::Cancelled), + }; + request_cx.respond(response) + }, + sacp::on_receive_request!(), + ) + .connect_to(transport) + .unwrap() + .run_until({ + let updates = updates.clone(); + move |cx: JrConnectionCx| async move { + cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) + .block_task() + .await + .unwrap(); + + let session = cx + .send_request(NewSessionRequest::new(work_dir.path()).mcp_servers(mcp_servers)) + .block_task() + .await + .unwrap(); + + test_fn(cx.clone(), session.session_id, updates).await; + Ok(()) + } + }) + .await + .unwrap(); +} + +#[test_case(Some(PermissionOptionKind::AllowAlways), ToolCallStatus::Completed, "user:\n always_allow:\n - lookup__get_code\n ask_before: []\n never_allow: []\n"; "allow_always")] +#[test_case(Some(PermissionOptionKind::AllowOnce), ToolCallStatus::Completed, ""; "allow_once")] +#[test_case(Some(PermissionOptionKind::RejectAlways), ToolCallStatus::Failed, "user:\n always_allow: []\n ask_before: []\n never_allow:\n - lookup__get_code\n"; "reject_always")] +#[test_case(Some(PermissionOptionKind::RejectOnce), ToolCallStatus::Failed, ""; "reject_once")] +#[test_case(None, ToolCallStatus::Failed, ""; "cancelled")] +#[tokio::test] +async fn test_permission_persistence( + kind: Option, + expected_status: ToolCallStatus, + expected_yaml: &str, +) { + let temp_dir = tempfile::tempdir().unwrap(); + let prompt = "Use the get_code tool and output only its result."; + let (mcp_url, _handle) = spawn_mcp_http_server().await; + + let mock_server = setup_mock_openai(vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("./test_data/openai_tool_result_response.txt"), + ), + ]) + .await; + + run_acp_session( + &mock_server, + vec![McpServer::Http(McpServerHttp::new("lookup", mcp_url))], + &[], + temp_dir.path(), + GooseMode::Approve, + kind, + |cx, session_id, updates| async move { + cx.send_request(PromptRequest::new( + session_id, + vec![ContentBlock::Text(TextContent::new(prompt))], + )) + .block_task() + .await + .unwrap(); + wait_for( + &updates, + &SessionUpdate::ToolCallUpdate(ToolCallUpdate::new( + ToolCallId::new(""), + ToolCallUpdateFields::new().status(Some(expected_status)), + )), + ) + .await; + }, + ) + .await; + + assert_eq!( + fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), + expected_yaml + ); +} diff --git a/crates/goose/tests/test_data/openai_basic_response.txt b/crates/goose-acp/tests/test_data/openai_basic_response.txt similarity index 100% rename from crates/goose/tests/test_data/openai_basic_response.txt rename to crates/goose-acp/tests/test_data/openai_basic_response.txt diff --git a/crates/goose/tests/test_data/openai_builtin_execute.txt b/crates/goose-acp/tests/test_data/openai_builtin_execute.txt similarity index 100% rename from crates/goose/tests/test_data/openai_builtin_execute.txt rename to crates/goose-acp/tests/test_data/openai_builtin_execute.txt diff --git a/crates/goose/tests/test_data/openai_builtin_final.txt b/crates/goose-acp/tests/test_data/openai_builtin_final.txt similarity index 100% rename from crates/goose/tests/test_data/openai_builtin_final.txt rename to crates/goose-acp/tests/test_data/openai_builtin_final.txt diff --git a/crates/goose/tests/test_data/openai_builtin_read_modules.txt b/crates/goose-acp/tests/test_data/openai_builtin_read_modules.txt similarity index 100% rename from crates/goose/tests/test_data/openai_builtin_read_modules.txt rename to crates/goose-acp/tests/test_data/openai_builtin_read_modules.txt diff --git a/crates/goose/tests/test_data/openai_builtin_search.txt b/crates/goose-acp/tests/test_data/openai_builtin_search.txt similarity index 100% rename from crates/goose/tests/test_data/openai_builtin_search.txt rename to crates/goose-acp/tests/test_data/openai_builtin_search.txt diff --git a/crates/goose/tests/test_data/openai_session_description.json b/crates/goose-acp/tests/test_data/openai_session_description.json similarity index 100% rename from crates/goose/tests/test_data/openai_session_description.json rename to crates/goose-acp/tests/test_data/openai_session_description.json diff --git a/crates/goose/tests/test_data/openai_tool_call_response.txt b/crates/goose-acp/tests/test_data/openai_tool_call_response.txt similarity index 100% rename from crates/goose/tests/test_data/openai_tool_call_response.txt rename to crates/goose-acp/tests/test_data/openai_tool_call_response.txt diff --git a/crates/goose/tests/test_data/openai_tool_result_response.txt b/crates/goose-acp/tests/test_data/openai_tool_result_response.txt similarity index 100% rename from crates/goose/tests/test_data/openai_tool_result_response.txt rename to crates/goose-acp/tests/test_data/openai_tool_result_response.txt diff --git a/crates/goose-bench/Cargo.toml b/crates/goose-bench/Cargo.toml index cabba6121a22..76a6949587e8 100644 --- a/crates/goose-bench/Cargo.toml +++ b/crates/goose-bench/Cargo.toml @@ -11,21 +11,21 @@ description.workspace = true workspace = true [dependencies] -anyhow = "1.0" +anyhow = { workspace = true } paste = "1.0" ctor = "0.2.7" goose = { path = "../goose" } rmcp = { workspace = true } async-trait = "0.1.89" chrono = { version = "0.4", features = ["serde"] } -serde_json = "1.0" +serde_json = { workspace = true } serde = { version = "1.0", features = ["derive"] } -tracing = "0.1" +tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["registry"] } -tokio = { version = "1.43", features = ["full"] } +tokio = { workspace = true } include_dir = "0.7.4" once_cell = "1.19" -regex = "1.11.1" +regex = { workspace = true } dotenvy = "0.15.7" [target.'cfg(target_os = "windows")'.dependencies] diff --git a/crates/goose-bench/src/bench_session.rs b/crates/goose-bench/src/bench_session.rs index 972ad99f832e..4b9a21726184 100644 --- a/crates/goose-bench/src/bench_session.rs +++ b/crates/goose-bench/src/bench_session.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use goose::conversation::Conversation; +use goose::session::session_manager::Session; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -19,7 +20,7 @@ pub trait BenchBaseSession: Send + Sync { async fn headless(&mut self, message: String) -> anyhow::Result<()>; fn message_history(&self) -> Conversation; fn get_total_token_usage(&self) -> anyhow::Result>; - fn get_session_id(&self) -> anyhow::Result; + async fn get_session(&self) -> anyhow::Result; } // struct for managing agent-session-access. to be passed to evals for benchmarking pub struct BenchAgent { @@ -52,7 +53,7 @@ impl BenchAgent { self.session.get_total_token_usage().ok().flatten() } - pub(crate) fn get_session_id(&self) -> anyhow::Result { - self.session.get_session_id() + pub(crate) async fn get_session(&self) -> anyhow::Result { + self.session.get_session().await } } diff --git a/crates/goose-bench/src/runners/eval_runner.rs b/crates/goose-bench/src/runners/eval_runner.rs index 88eed45e63f8..e9966bb4fba4 100644 --- a/crates/goose-bench/src/runners/eval_runner.rs +++ b/crates/goose-bench/src/runners/eval_runner.rs @@ -5,7 +5,6 @@ use crate::eval_suites::{EvaluationSuite, ExtensionRequirements}; use crate::reporting::EvaluationResult; use crate::utilities::await_process_exits; use anyhow::{bail, Context, Result}; -use goose::session::SessionManager; use std::env; use std::fs; use std::future::Future; @@ -156,8 +155,7 @@ impl EvalRunner { .canonicalize() .context("Failed to canonicalize current directory path")?; - let session_id = agent.get_session_id()?.to_string(); - let session = SessionManager::get_session(&session_id, true).await?; + let session = agent.get_session().await?; let session_json = serde_json::to_string_pretty(&session) .context("Failed to serialize session to JSON")?; diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index a043894f3765..e7a45c8e77cd 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -16,38 +16,34 @@ path = "src/main.rs" [dependencies] goose = { path = "../goose" } +goose-acp = { path = "../goose-acp" } goose-bench = { path = "../goose-bench" } goose-mcp = { path = "../goose-mcp" } rmcp = { workspace = true } -sacp = { workspace = true } -agent-client-protocol-schema = "0.10.5" clap = { version = "4.4", features = ["derive"] } cliclack = "0.3.5" console = "0.16.1" uuid = { version = "1.11", features = ["v4"] } dotenvy = "0.15.7" bat = "0.25.0" -anyhow = "1.0" -serde_json = "1.0" -jsonschema = "0.30.0" -tokio = { version = "1.43", features = ["full"] } -futures = "0.3" +anyhow = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +futures = { workspace = true } serde = { version = "1.0", features = ["derive"] } # For serialization serde_yaml = "0.9" tempfile = "3" etcetera = { workspace = true } rand = "0.8.5" rustyline = "15.0.0" -tracing = "0.1" +tracing = { workspace = true } chrono = "0.4" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" -once_cell = "1.20.2" shlex = "1.3.0" async-trait = "0.1.89" base64 = "0.22.1" regex = "1.11.1" -nix = { version = "0.30.1", features = ["process", "signal"] } tar = "0.4" # Web server dependencies axum = { version = "0.8.1", features = ["ws", "macros"] } @@ -56,9 +52,7 @@ http = "1.0" webbrowser = {workspace = true} indicatif = "0.18.1" tokio-util = { version = "0.7.15", features = ["compat", "rt"] } -is-terminal = "0.4.16" anstream = "0.6.18" -url = "2.5.7" open = "5.3.2" urlencoding = "2.1" clap_complete = "4.5.62" @@ -74,5 +68,5 @@ disable-update = [] tempfile = "3" temp-env = { version = "0.3.6", features = ["async_closure"] } test-case = "3.3" -tokio = { version = "1.43", features = ["rt", "macros"] } -serial_test = "3.2.0" +tokio = { workspace = true } +serial_test = { workspace = true } diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 643f9277ff8f..92408d743431 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -7,7 +7,6 @@ use goose_mcp::{ AutoVisualiserRouter, ComputerControllerServer, DeveloperServer, MemoryServer, TutorialServer, }; -use crate::commands::acp::run_acp_agent; use crate::commands::bench::agent_generator; use crate::commands::configure::handle_configure; use crate::commands::info::handle_info; @@ -319,21 +318,24 @@ async fn get_or_create_session_id( return Ok(None); } + let session_manager = SessionManager::instance(); + let Some(id) = identifier else { return if resume { - let sessions = SessionManager::list_sessions().await?; + let sessions = session_manager.list_sessions().await?; let session_id = sessions .first() .map(|s| s.id.clone()) .ok_or_else(|| anyhow::anyhow!("No session found to resume"))?; Ok(Some(session_id)) } else { - let session = SessionManager::create_session( - std::env::current_dir()?, - "CLI Session".to_string(), - SessionType::User, - ) - .await?; + let session = session_manager + .create_session( + std::env::current_dir()?, + "CLI Session".to_string(), + SessionType::User, + ) + .await?; Ok(Some(session.id)) }; }; @@ -342,7 +344,7 @@ async fn get_or_create_session_id( Ok(Some(session_id)) } else if let Some(name) = id.name { if resume { - let sessions = SessionManager::list_sessions().await?; + let sessions = session_manager.list_sessions().await?; let session_id = sessions .into_iter() .find(|s| s.name == name || s.id == name) @@ -350,14 +352,12 @@ async fn get_or_create_session_id( .ok_or_else(|| anyhow::anyhow!("No session found with name '{}'", name))?; Ok(Some(session_id)) } else { - let session = SessionManager::create_session( - std::env::current_dir()?, - name.clone(), - SessionType::User, - ) - .await?; + let session = session_manager + .create_session(std::env::current_dir()?, name.clone(), SessionType::User) + .await?; - SessionManager::update_session(&session.id) + session_manager + .update(&session.id) .user_provided_name(name) .apply() .await?; @@ -372,12 +372,13 @@ async fn get_or_create_session_id( .ok_or_else(|| anyhow::anyhow!("Could not extract session ID from path: {:?}", path))?; Ok(Some(session_id)) } else { - let session = SessionManager::create_session( - std::env::current_dir()?, - "CLI Session".to_string(), - SessionType::User, - ) - .await?; + let session = session_manager + .create_session( + std::env::current_dir()?, + "CLI Session".to_string(), + SessionType::User, + ) + .await?; Ok(Some(session.id)) } } @@ -386,7 +387,8 @@ async fn lookup_session_id(identifier: Identifier) -> Result { if let Some(session_id) = identifier.session_id { Ok(session_id) } else if let Some(name) = identifier.name { - let sessions = SessionManager::list_sessions().await?; + let session_manager = SessionManager::instance(); + let sessions = session_manager.list_sessions().await?; sessions .into_iter() .find(|s| s.name == name || s.id == name) @@ -997,10 +999,15 @@ async fn handle_session_subcommand(command: SessionCommand) -> Result<()> { output, format, } => { + let session_manager = SessionManager::instance(); let session_identifier = if let Some(id) = identifier { lookup_session_id(id).await? } else { - match crate::commands::session::prompt_interactive_session_selection().await { + match crate::commands::session::prompt_interactive_session_selection( + &session_manager, + ) + .await + { Ok(id) => id, Err(e) => { eprintln!("Error: {}", e); @@ -1012,10 +1019,15 @@ async fn handle_session_subcommand(command: SessionCommand) -> Result<()> { .await?; } SessionCommand::Diagnostics { identifier, output } => { + let session_manager = SessionManager::instance(); let session_id = if let Some(id) = identifier { lookup_session_id(id).await? } else { - match crate::commands::session::prompt_interactive_session_selection().await { + match crate::commands::session::prompt_interactive_session_selection( + &session_manager, + ) + .await + { Ok(id) => id, Err(e) => { eprintln!("Error: {}", e); @@ -1447,7 +1459,7 @@ pub async fn cli() -> anyhow::Result<()> { Some(Command::Configure {}) => handle_configure().await, Some(Command::Info { verbose }) => handle_info(verbose), Some(Command::Mcp { server }) => handle_mcp_command(server).await, - Some(Command::Acp { builtins }) => run_acp_agent(builtins).await, + Some(Command::Acp { builtins }) => goose_acp::server::run(builtins).await, Some(Command::Session { command: Some(cmd), .. }) => handle_session_subcommand(cmd).await, diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs index 10bd42539d29..7d4522ff118f 100644 --- a/crates/goose-cli/src/commands/bench.rs +++ b/crates/goose-cli/src/commands/bench.rs @@ -3,6 +3,7 @@ use crate::session::SessionBuilderConfig; use crate::{logging, CliSession}; use async_trait::async_trait; use goose::conversation::Conversation; +use goose::session::session_manager::Session; use goose_bench::bench_session::{BenchAgent, BenchBaseSession}; use goose_bench::eval_suites::ExtensionRequirements; use std::sync::Arc; @@ -25,8 +26,8 @@ impl BenchBaseSession for CliSession { }) } - fn get_session_id(&self) -> anyhow::Result { - Ok(self.session_id().to_string()) + async fn get_session(&self) -> anyhow::Result { + self.get_session().await } } pub async fn agent_generator( diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index cf26f3acb6ab..117d9d4b99ce 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -21,7 +21,7 @@ use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::provider_test::test_provider_configuration; use goose::providers::{create, providers, retry_operation, RetryConfig}; -use goose::session::{SessionManager, SessionType}; +use goose::session::SessionType; use serde_json::Value; use std::collections::HashMap; @@ -1018,8 +1018,7 @@ pub fn remove_extension_dialog() -> anyhow::Result<()> { for name in selected { remove_extension(&name_to_key(name)); - let mut permission_manager = PermissionManager::default(); - permission_manager.remove_extension(&name_to_key(name)); + PermissionManager::instance().remove_extension(&name_to_key(name)); cliclack::outro(format!("Removed {} extension", style(name).green()))?; } @@ -1294,15 +1293,18 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { .expect("No model configured. Please set model first"); let model_config = ModelConfig::new(&model)?; - let session = SessionManager::create_session( - std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), - "Tool Permission Configuration".to_string(), - SessionType::Hidden, - ) - .await?; - let agent = Agent::new(); let new_provider = create(&provider_name, model_config).await?; + + let session = agent + .session_manager() + .create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Tool Permission Configuration".to_string(), + SessionType::Hidden, + ) + .await?; + agent.update_provider(new_provider, &session.id).await?; if let Some(config) = get_extension_by_name(&selected_extension_name) { agent @@ -1324,9 +1326,9 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { return Ok(()); } - let mut permission_manager = PermissionManager::default(); + let permission_manager = PermissionManager::instance(); let selected_tools = agent - .list_tools(Some(selected_extension_name.clone())) + .list_tools(&session.id, Some(selected_extension_name.clone())) .await .into_iter() .map(|tool| { diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index e0d54e96780b..c6511ad7525c 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,3 @@ -pub mod acp; pub mod bench; pub mod configure; pub mod info; diff --git a/crates/goose-cli/src/commands/schedule.rs b/crates/goose-cli/src/commands/schedule.rs index d67ffcc6bcc2..c1c0b74a703e 100644 --- a/crates/goose-cli/src/commands/schedule.rs +++ b/crates/goose-cli/src/commands/schedule.rs @@ -3,7 +3,9 @@ use goose::scheduler::{ get_default_scheduled_recipes_dir, get_default_scheduler_storage_path, ScheduledJob, Scheduler, SchedulerError, }; +use goose::session::SessionManager; use std::path::Path; +use std::sync::Arc; fn validate_cron_expression(cron: &str) -> Result<()> { // Basic validation and helpful suggestions @@ -90,7 +92,8 @@ pub async fn handle_schedule_add( let scheduler_storage_path = get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?; - let scheduler = Scheduler::new(scheduler_storage_path) + let session_manager = Arc::new(SessionManager::instance()); + let scheduler = Scheduler::new(scheduler_storage_path, session_manager) .await .context("Failed to initialize scheduler")?; @@ -136,7 +139,8 @@ pub async fn handle_schedule_add( pub async fn handle_schedule_list() -> Result<()> { let scheduler_storage_path = get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?; - let scheduler = Scheduler::new(scheduler_storage_path) + let session_manager = Arc::new(SessionManager::instance()); + let scheduler = Scheduler::new(scheduler_storage_path, session_manager) .await .context("Failed to initialize scheduler")?; @@ -171,7 +175,8 @@ pub async fn handle_schedule_list() -> Result<()> { pub async fn handle_schedule_remove(schedule_id: String) -> Result<()> { let scheduler_storage_path = get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?; - let scheduler = Scheduler::new(scheduler_storage_path) + let session_manager = Arc::new(SessionManager::instance()); + let scheduler = Scheduler::new(scheduler_storage_path, session_manager) .await .context("Failed to initialize scheduler")?; @@ -198,7 +203,8 @@ pub async fn handle_schedule_remove(schedule_id: String) -> Result<()> { pub async fn handle_schedule_sessions(schedule_id: String, limit: Option) -> Result<()> { let scheduler_storage_path = get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?; - let scheduler = Scheduler::new(scheduler_storage_path) + let session_manager = Arc::new(SessionManager::instance()); + let scheduler = Scheduler::new(scheduler_storage_path, session_manager) .await .context("Failed to initialize scheduler")?; @@ -234,7 +240,8 @@ pub async fn handle_schedule_sessions(schedule_id: String, limit: Option) pub async fn handle_schedule_run_now(schedule_id: String) -> Result<()> { let scheduler_storage_path = get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?; - let scheduler = Scheduler::new(scheduler_storage_path) + let session_manager = Arc::new(SessionManager::instance()); + let scheduler = Scheduler::new(scheduler_storage_path, session_manager) .await .context("Failed to initialize scheduler")?; diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index a145f01ffd33..24885848bd7f 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -11,7 +11,7 @@ use std::path::PathBuf; const TRUNCATED_DESC_LENGTH: usize = 60; -pub async fn remove_sessions(sessions: Vec) -> Result<()> { +async fn remove_sessions(session_manager: &SessionManager, sessions: Vec) -> Result<()> { println!("The following sessions will be removed:"); for session in &sessions { println!("- {} {}", session.id, session.name); @@ -23,7 +23,7 @@ pub async fn remove_sessions(sessions: Vec) -> Result<()> { if should_delete { for session in sessions { - SessionManager::delete_session(&session.id).await?; + session_manager.delete_session(&session.id).await?; println!("Session `{}` removed.", session.id); } } else { @@ -76,7 +76,8 @@ pub async fn handle_session_remove( name: Option, regex_string: Option, ) -> Result<()> { - let all_sessions = match SessionManager::list_sessions().await { + let session_manager = SessionManager::instance(); + let all_sessions = match session_manager.list_sessions().await { Ok(sessions) => sessions, Err(e) => { tracing::error!("Failed to retrieve sessions: {:?}", e); @@ -125,7 +126,7 @@ pub async fn handle_session_remove( return Ok(()); } - remove_sessions(matched_sessions).await + remove_sessions(&session_manager, matched_sessions).await } pub async fn handle_session_list( @@ -134,7 +135,8 @@ pub async fn handle_session_list( working_dir: Option, limit: Option, ) -> Result<()> { - let mut sessions = SessionManager::list_sessions().await?; + let session_manager = SessionManager::instance(); + let mut sessions = session_manager.list_sessions().await?; if let Some(ref pat) = working_dir { let pat_lower = pat.to_string_lossy().to_lowercase(); @@ -181,7 +183,8 @@ pub async fn handle_session_export( output_path: Option, format: String, ) -> Result<()> { - let session = match SessionManager::get_session(&session_id, true).await { + let session_manager = SessionManager::instance(); + let session = match session_manager.get_session(&session_id, true).await { Ok(session) => session, Err(e) => { return Err(anyhow::anyhow!( @@ -222,12 +225,15 @@ pub async fn handle_diagnostics(session_id: &str, output_path: Option) session_id ); - let diagnostics_data = generate_diagnostics(session_id).await.with_context(|| { - format!( - "Failed to write to generate diagnostics bundle for session '{}'", - session_id - ) - })?; + let session_manager = SessionManager::instance(); + let diagnostics_data = generate_diagnostics(&session_manager, session_id) + .await + .with_context(|| { + format!( + "Failed to write to generate diagnostics bundle for session '{}'", + session_id + ) + })?; let output_file = if let Some(path) = output_path { path.clone() @@ -319,8 +325,10 @@ fn export_session_to_markdown( /// Prompt the user to interactively select a session /// /// Shows a list of available sessions and lets the user select one -pub async fn prompt_interactive_session_selection() -> Result { - let sessions = SessionManager::list_sessions().await?; +pub async fn prompt_interactive_session_selection( + session_manager: &SessionManager, +) -> Result { + let sessions = session_manager.list_sessions().await?; if sessions.is_empty() { return Err(anyhow::anyhow!("No sessions found")); diff --git a/crates/goose-cli/src/commands/term.rs b/crates/goose-cli/src/commands/term.rs index c9bdfc859099..16438e3b4b28 100644 --- a/crates/goose-cli/src/commands/term.rs +++ b/crates/goose-cli/src/commands/term.rs @@ -1,8 +1,7 @@ use anyhow::{anyhow, Result}; use chrono; use goose::conversation::message::{Message, MessageContent, MessageMetadata}; -use goose::session::SessionManager; -use goose::session::SessionType; +use goose::session::{SessionManager, SessionType}; use rmcp::model::Role; use crate::session::{build_session, SessionBuilderConfig}; @@ -119,10 +118,13 @@ pub async fn handle_term_init( with_command_not_found: bool, ) -> Result<()> { let config = shell.config(); + let session_manager = SessionManager::instance(); let working_dir = std::env::current_dir()?; let named_session = if let Some(ref name) = name { - let sessions = SessionManager::list_sessions_by_types(&[SessionType::Terminal]).await?; + let sessions = session_manager + .list_sessions_by_types(&[SessionType::Terminal]) + .await?; sessions.into_iter().find(|s| s.name == *name) } else { None @@ -131,15 +133,17 @@ pub async fn handle_term_init( let session = match named_session { Some(s) => s, None => { - let session = SessionManager::create_session( - working_dir, - "Goose Term Session".to_string(), - SessionType::Terminal, - ) - .await?; + let session = session_manager + .create_session( + working_dir, + "Goose Term Session".to_string(), + SessionType::Terminal, + ) + .await?; if let Some(name) = name { - SessionManager::update_session(&session.id) + session_manager + .update(&session.id) .user_provided_name(name) .apply() .await?; @@ -184,7 +188,8 @@ pub async fn handle_term_log(command: String) -> Result<()> { ) .with_metadata(MessageMetadata::user_only()); - SessionManager::add_message(&session_id, &message).await?; + let session_manager = SessionManager::instance(); + session_manager.add_message(&session_id, &message).await?; Ok(()) } @@ -201,13 +206,15 @@ pub async fn handle_term_run(prompt: Vec) -> Result<()> { })?; let working_dir = std::env::current_dir()?; + let session_manager = SessionManager::instance(); - SessionManager::update_session(&session_id) + session_manager + .update(&session_id) .working_dir(working_dir) .apply() .await?; - let session = SessionManager::get_session(&session_id, true).await?; + let session = session_manager.get_session(&session_id, true).await?; let user_messages_after_last_assistant: Vec<&Message> = if let Some(conv) = &session.conversation { conv.messages() @@ -220,7 +227,9 @@ pub async fn handle_term_run(prompt: Vec) -> Result<()> { }; if let Some(oldest_user) = user_messages_after_last_assistant.last() { - SessionManager::truncate_conversation(&session_id, oldest_user.created).await?; + session_manager + .truncate_conversation(&session_id, oldest_user.created) + .await?; } let prompt_with_context = if user_messages_after_last_assistant.is_empty() { @@ -260,7 +269,8 @@ pub async fn handle_term_info() -> Result<()> { Err(_) => return Ok(()), }; - let session = SessionManager::get_session(&session_id, false).await.ok(); + let session_manager = SessionManager::instance(); + let session = session_manager.get_session(&session_id, false).await.ok(); let total_tokens = session.as_ref().and_then(|s| s.total_tokens).unwrap_or(0) as usize; let config = goose::config::Config::global(); diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index c6aa47f6b82e..7aece83d7685 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -16,7 +16,6 @@ use futures::{sink::SinkExt, stream::StreamExt}; use goose::agents::{Agent, AgentEvent}; use goose::conversation::message::Message as GooseMessage; use goose::session::session_manager::SessionType; -use goose::session::SessionManager; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::{net::SocketAddr, sync::Arc}; @@ -153,15 +152,18 @@ pub async fn handle_web( let model_config = goose::model::ModelConfig::new(&model)?; - let init_session = SessionManager::create_session( - std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), - "Web Agent Initialization".to_string(), - SessionType::Hidden, - ) - .await?; - let agent = Agent::new(); let provider = goose::providers::create(&provider_name, model_config).await?; + + let init_session = agent + .session_manager() + .create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Web Agent Initialization".to_string(), + SessionType::Hidden, + ) + .await?; + agent.update_provider(provider, &init_session.id).await?; let enabled_configs = goose::config::get_enabled_extensions(); @@ -240,14 +242,20 @@ pub async fn handle_web( Ok(()) } -async fn serve_index(uri: Uri) -> Result { - let session = SessionManager::create_session( - std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), - "Web session".to_string(), - SessionType::User, - ) - .await - .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?; +async fn serve_index( + State(state): State, + uri: Uri, +) -> Result { + let session = state + .agent + .session_manager() + .create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Web session".to_string(), + SessionType::User, + ) + .await + .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?; let redirect_url = if let Some(query) = uri.query() { format!("/session/{}?{}", session.id, query) @@ -307,8 +315,8 @@ async fn health_check() -> Json { })) } -async fn list_sessions() -> Json { - match SessionManager::list_sessions().await { +async fn list_sessions(State(state): State) -> Json { + match state.agent.session_manager().list_sessions().await { Ok(sessions) => { let mut session_info = Vec::new(); @@ -331,9 +339,15 @@ async fn list_sessions() -> Json { } } async fn get_session( + State(state): State, axum::extract::Path(session_id): axum::extract::Path, ) -> Json { - match SessionManager::get_session(&session_id, true).await { + match state + .agent + .session_manager() + .get_session(&session_id, true) + .await + { Ok(session) => Json(serde_json::json!({ "metadata": session, "messages": session.conversation.unwrap_or_default().messages() @@ -505,7 +519,10 @@ async fn process_message_streaming( return Ok(()); } - let session = SessionManager::get_session(&session_id, true).await?; + let session = agent + .session_manager() + .get_session(&session_id, true) + .await?; let mut messages = session.conversation.unwrap_or_default(); messages.push(user_message.clone()); diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index 1314f251e76a..54a1917a39e3 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -94,6 +94,7 @@ impl McpClientTrait for MockClient { &self, name: &str, arguments: Option>, + _session_id: &str, _cancel_token: CancellationToken, ) -> Result { if let Some(handler) = self.handlers.get(name) { diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 5dae5c2b62db..db4fe9f22042 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -6,14 +6,18 @@ use crate::scenario_tests::mock_client::weather_client; use crate::scenario_tests::provider_configs::{get_provider_configs, ProviderConfig}; use crate::session::CliSession; use anyhow::Result; -use goose::agents::Agent; +use goose::agents::{Agent, AgentConfig}; +use goose::config::permission::PermissionManager; +use goose::config::GooseMode; use goose::model::ModelConfig; use goose::providers::{create, testprovider::TestProvider}; +use goose::scheduler_trait::unavailable_scheduler; use goose::session::session_manager::SessionType; use goose::session::SessionManager; use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::sync::Arc; +use tempfile::TempDir; use tokio_util::sync::CancellationToken; pub const SCENARIO_TESTS_DIR: &str = "src/scenario_tests"; @@ -198,7 +202,16 @@ where let mock_client = weather_client(); - let agent = Agent::new(); + let temp_dir = TempDir::new()?; + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let permission_manager = Arc::new(PermissionManager::new(temp_dir.path().to_path_buf())); + let agent_config = AgentConfig::new( + session_manager, + permission_manager, + GooseMode::Auto, + unavailable_scheduler(), + ); + let agent = Agent::with_config(agent_config); agent .extension_manager .add_client( @@ -217,12 +230,14 @@ where ) .await; - let session = SessionManager::create_session( - PathBuf::default(), - "scenario-runner".to_string(), - SessionType::Hidden, - ) - .await?; + let session = agent + .session_manager() + .create_session( + PathBuf::default(), + "scenario-runner".to_string(), + SessionType::Hidden, + ) + .await?; agent .update_provider( diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 11e15f7717b0..5a4eeb32df66 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -10,9 +10,7 @@ use goose::config::{ use goose::providers::create; use goose::recipe::{Response, SubRecipe}; -use goose::agents::extension::PlatformExtensionContext; use goose::session::session_manager::SessionType; -use goose::session::SessionManager; use goose::session::{EnabledExtensionsState, ExtensionState}; use rustyline::EditMode; use std::collections::HashSet; @@ -147,12 +145,14 @@ async fn offer_extension_debugging_help( // Create a minimal agent for debugging let debug_agent = Agent::new(); - let session = SessionManager::create_session( - std::env::current_dir()?, - "CLI Session".to_string(), - SessionType::Hidden, - ) - .await?; + let session = debug_agent + .session_manager() + .create_session( + std::env::current_dir()?, + "CLI Session".to_string(), + SessionType::Hidden, + ) + .await?; debug_agent.update_provider(provider, &session.id).await?; @@ -241,10 +241,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { goose::posthog::set_session_context("cli", session_config.resume); let config = Config::global(); + let agent: Agent = Agent::new(); + let session_manager = agent.session_manager(); let (saved_provider, saved_model_config) = if session_config.resume { if let Some(ref session_id) = session_config.session_id { - match SessionManager::get_session(session_id, false).await { + match session_manager.get_session(session_id, false).await { Ok(session_data) => (session_data.provider_name, session_data.model_config), Err(_) => (None, None), } @@ -299,8 +301,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { .with_temperature(temperature) }; - let agent: Agent = Agent::new(); - agent .apply_recipe_components( session_config.sub_recipes, @@ -337,17 +337,14 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { let session_id: String = if session_config.no_session { let working_dir = std::env::current_dir().expect("Could not get working directory"); - let session = SessionManager::create_session( - working_dir, - "CLI Session".to_string(), - SessionType::Hidden, - ) - .await - .expect("Could not create session"); + let session = session_manager + .create_session(working_dir, "CLI Session".to_string(), SessionType::Hidden) + .await + .expect("Could not create session"); session.id } else if session_config.resume { if let Some(session_id) = session_config.session_id { - match SessionManager::get_session(&session_id, false).await { + match session_manager.get_session(&session_id, false).await { Ok(_) => session_id, Err(_) => { output::render_error(&format!( @@ -358,7 +355,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } } else { - match SessionManager::list_sessions().await { + match session_manager.list_sessions().await { Ok(sessions) if !sessions.is_empty() => sessions[0].id.clone(), _ => { output::render_error("Cannot resume - no previous sessions found"); @@ -378,16 +375,10 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { process::exit(1); }); - agent - .extension_manager - .set_context(PlatformExtensionContext { - session_id: Some(session_id.clone()), - extension_manager: Some(Arc::downgrade(&agent.extension_manager)), - }) - .await; - if session_config.resume { - let session = SessionManager::get_session(&session_id, false) + let session = agent + .session_manager() + .get_session(&session_id, false) .await .unwrap_or_else(|e| { output::render_error(&format!("Failed to read session metadata: {}", e)); @@ -428,7 +419,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override { extensions.into_iter().collect() } else if session_config.resume { - match SessionManager::get_session(&session_id, false).await { + match agent + .session_manager() + .get_session(&session_id, false) + .await + { Ok(session_data) => { if let Some(saved_state) = EnabledExtensionsState::from_extension_data(&session_data.extension_data) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 8f63894a104c..74241fb99a17 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -33,7 +33,6 @@ use goose::agents::extension::{Envs, ExtensionConfig, PLATFORM_EXTENSIONS}; use goose::agents::types::RetryConfig; use goose::agents::{Agent, SessionConfig, COMPACT_TRIGGERS}; use goose::config::{Config, GooseMode}; -use goose::session::SessionManager; use input::InputResult; use rmcp::model::PromptMessage; use rmcp::model::ServerNotification; @@ -229,7 +228,9 @@ impl CliSession { retry_config: Option, output_format: String, ) -> Self { - let messages = SessionManager::get_session(&session_id, true) + let messages = agent + .session_manager() + .get_session(&session_id, true) .await .map(|session| session.conversation.unwrap_or_default()) .unwrap(); @@ -668,14 +669,20 @@ impl CliSession { } async fn handle_clear(&mut self) -> Result<()> { - if let Err(e) = - SessionManager::replace_conversation(&self.session_id, &Conversation::default()).await + if let Err(e) = self + .agent + .session_manager() + .replace_conversation(&self.session_id, &Conversation::default()) + .await { output::render_error(&format!("Failed to clear session: {}", e)); return Ok(()); } - if let Err(e) = SessionManager::update_session(&self.session_id) + if let Err(e) = self + .agent + .session_manager() + .update(&self.session_id) .total_tokens(Some(0)) .input_tokens(Some(0)) .output_tokens(Some(0)) @@ -1270,7 +1277,12 @@ impl CliSession { // Output based on format if is_json_mode { - let metadata = match SessionManager::get_session(&self.session_id, false).await { + let metadata = match self + .agent + .session_manager() + .get_session(&self.session_id, false) + .await + { Ok(session) => JsonMetadata { total_tokens: session.total_tokens, status: "completed".to_string(), @@ -1288,7 +1300,10 @@ impl CliSession { println!("{}", serde_json::to_string_pretty(&json_output)?); } else if is_stream_json_mode { - let total_tokens = SessionManager::get_session(&self.session_id, false) + let total_tokens = self + .agent + .session_manager() + .get_session(&self.session_id, false) .await .ok() .and_then(|s| s.total_tokens); @@ -1458,7 +1473,10 @@ impl CliSession { } pub async fn get_session(&self) -> Result { - SessionManager::get_session(&self.session_id, false).await + self.agent + .session_manager() + .get_session(&self.session_id, false) + .await } // Get the session's total token usage diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 2e6a87b33ff3..0e0a8c0d3a7c 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -12,17 +12,17 @@ workspace = true [dependencies] rmcp = { workspace = true, features = ["server", "client", "transport-io", "macros"] } -anyhow = "1.0.94" -tokio = { version = "1", features = ["full"] } +anyhow = { workspace = true } +tokio = { workspace = true } tokio-stream = { version = "0.1", features = ["io-util"] } -tracing = "0.1" +tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" url = "2.5" base64 = "0.21" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { workspace = true } schemars = "1.0" lazy_static = "1.5" shellexpand = "3.1.0" @@ -39,7 +39,7 @@ tempfile = "3.8" include_dir = "0.7.4" webbrowser = {workspace = true} http-body-util = "0.1.2" -regex = "1.11.1" +regex = { workspace = true } once_cell = "1.20.2" ignore = { workspace = true } lopdf = "0.36.0" @@ -78,12 +78,11 @@ mpatch = "=0.2.0" tokio-util = "0.7.16" clap = { version = "4", features = ["derive"] } - [dev-dependencies] -serial_test = "3.0.0" sysinfo = "0.32.1" temp-env = "0.3.6" colored = "2" +serial_test = { workspace = true } [features] utoipa = ["dep:utoipa"] diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index a096a1cea3df..b940ec4a173e 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -1,5 +1,7 @@ use etcetera::AppStrategyArgs; use once_cell::sync::Lazy; +use rmcp::{ServerHandler, ServiceExt}; +use std::collections::HashMap; pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { top_level_domain: "Block".to_string(), @@ -19,3 +21,52 @@ pub use computercontroller::ComputerControllerServer; pub use developer::rmcp_developer::DeveloperServer; pub use memory::MemoryServer; pub use tutorial::TutorialServer; + +pub type SpawnServerFn = fn(tokio::io::DuplexStream, tokio::io::DuplexStream); + +pub struct BuiltinDef { + pub name: &'static str, + pub spawn_server: SpawnServerFn, +} + +fn spawn_and_serve( + name: &'static str, + server: S, + transport: (tokio::io::DuplexStream, tokio::io::DuplexStream), +) where + S: ServerHandler + Send + 'static, +{ + tokio::spawn(async move { + match server.serve(transport).await { + Ok(running) => { + let _ = running.waiting().await; + } + Err(e) => tracing::error!(builtin = name, error = %e, "server error"), + } + }); +} + +macro_rules! builtin { + ($name:ident, $server_ty:ty) => {{ + fn spawn(r: tokio::io::DuplexStream, w: tokio::io::DuplexStream) { + spawn_and_serve(stringify!($name), <$server_ty>::new(), (r, w)); + } + ( + stringify!($name), + BuiltinDef { + name: stringify!($name), + spawn_server: spawn, + }, + ) + }}; +} + +pub static BUILTIN_EXTENSIONS: Lazy> = Lazy::new(|| { + HashMap::from([ + builtin!(developer, DeveloperServer), + builtin!(autovisualiser, AutoVisualiserRouter), + builtin!(computercontroller, ComputerControllerServer), + builtin!(memory, MemoryServer), + builtin!(tutorial, TutorialServer), + ]) +}); diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index f990b3882e7a..549601e124da 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -16,17 +16,17 @@ goose-mcp = { path = "../goose-mcp" } rmcp = { workspace = true } schemars = "1.0" axum = { version = "0.8.1", features = ["ws", "macros"] } -tokio = { version = "1.43", features = ["full"] } +tokio = { workspace = true } chrono = "0.4" tower-http = { version = "0.5", features = ["cors"] } serde = { version = "1.0", features = ["derive"] } -serde_json = { version = "1.0", features = ["preserve_order"] } -futures = "0.3" -tracing = "0.1" +serde_json = { workspace = true, features = ["preserve_order"] } +futures = { workspace = true } +tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" tokio-stream = "0.1" -anyhow = "1.0" +anyhow = { workspace = true } bytes = "1.5" http = "1.0" base64 = "0.21" @@ -62,4 +62,5 @@ path = "src/bin/generate_schema.rs" tower = "0.5" async-trait = "0.1.89" tempfile = "3.15.0" -env-lock = "1.0.1" +env-lock = { workspace = true } +wiremock = { workspace = true } diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 70b871fa7048..30f1c11d7c9f 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -10,7 +10,6 @@ use axum::{ routing::{get, post}, Json, Router, }; -use goose::config::PermissionManager; use base64::Engine; use goose::agents::ExtensionConfig; @@ -21,7 +20,7 @@ use goose::providers::create; use goose::recipe::Recipe; use goose::recipe_deeplink; use goose::session::session_manager::SessionType; -use goose::session::{Session, SessionManager}; +use goose::session::Session; use goose::{ agents::{extension::ToolInfo, extension_manager::get_parameter_names}, config::permission::PermissionLevel, @@ -179,20 +178,23 @@ async fn start_agent( let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1; let name = format!("New session {}", counter); - let mut session = - SessionManager::create_session(PathBuf::from(&working_dir), name, SessionType::User) - .await - .map_err(|err| { - error!("Failed to create session: {}", err); - goose::posthog::emit_error("session_create_failed", &err.to_string()); - ErrorResponse { - message: format!("Failed to create session: {}", err), - status: StatusCode::BAD_REQUEST, - } - })?; + let manager = state.session_manager(); + + let mut session = manager + .create_session(PathBuf::from(&working_dir), name, SessionType::User) + .await + .map_err(|err| { + error!("Failed to create session: {}", err); + goose::posthog::emit_error("session_create_failed", &err.to_string()); + ErrorResponse { + message: format!("Failed to create session: {}", err), + status: StatusCode::BAD_REQUEST, + } + })?; if let Some(recipe) = original_recipe { - SessionManager::update_session(&session.id) + manager + .update(&session.id) .recipe(Some(recipe)) .apply() .await @@ -204,7 +206,8 @@ async fn start_agent( } })?; - session = SessionManager::get_session(&session.id, false) + session = manager + .get_session(&session.id, false) .await .map_err(|err| { error!("Failed to get updated session: {}", err); @@ -235,7 +238,9 @@ async fn resume_agent( ) -> Result, ErrorResponse> { goose::posthog::set_session_context("desktop", true); - let session = SessionManager::get_session(&payload.session_id, true) + let session = state + .session_manager() + .get_session(&payload.session_id, true) .await .map_err(|err| { error!("Failed to resume session {}: {}", payload.session_id, err); @@ -353,7 +358,9 @@ async fn update_from_session( message: format!("Failed to get agent: {}", status), status, })?; - let session = SessionManager::get_session(&payload.session_id, false) + let session = state + .session_manager() + .get_session(&payload.session_id, false) .await .map_err(|err| ErrorResponse { message: format!("Failed to get session: {}", err), @@ -411,11 +418,12 @@ async fn get_tools( ) -> Result>, StatusCode> { let config = Config::global(); let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); - let agent = state.get_agent_for_route(query.session_id).await?; - let permission_manager = PermissionManager::default(); + let session_id = query.session_id; + let agent = state.get_agent_for_route(session_id.clone()).await?; + let permission_manager = agent.permission_manager(); let mut tools: Vec = agent - .list_tools(query.extension_name) + .list_tools(&session_id, query.extension_name) .await .into_iter() .map(|tool| { @@ -681,7 +689,7 @@ async fn call_tool( let tool_result = agent .extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call(&payload.session_id, tool_call, CancellationToken::default()) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; diff --git a/crates/goose-server/src/routes/audio.rs b/crates/goose-server/src/routes/audio.rs index 73434b0abf4c..cfd61d44205b 100644 --- a/crates/goose-server/src/routes/audio.rs +++ b/crates/goose-server/src/routes/audio.rs @@ -392,14 +392,25 @@ mod tests { use super::*; use axum::{body::Body, http::Request}; use tower::ServiceExt; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[tokio::test(flavor = "multi_thread")] async fn test_transcribe_endpoint_requires_auth() { - let _guard = env_lock::lock_env([("OPENAI_API_KEY", Some("fake-openai-no-keyring"))]); + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/audio/transcriptions")) + .respond_with(ResponseTemplate::new(401)) + .mount(&mock_server) + .await; + + let _guard = env_lock::lock_env([ + ("OPENAI_API_KEY", Some("fake-key")), + ("OPENAI_HOST", Some(mock_server.uri().as_str())), + ]); let state = AppState::new().await.unwrap(); let app = routes(state); - // Test without auth header let request = Request::builder() .uri("/audio/transcribe") .method("POST") @@ -414,10 +425,7 @@ mod tests { .unwrap(); let response = app.oneshot(request).await.unwrap(); - assert!( - response.status() == StatusCode::PRECONDITION_FAILED - || response.status() == StatusCode::UNAUTHORIZED - ); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test(flavor = "multi_thread")] diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 409bf4d0d97a..d1d5a18902a4 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -557,7 +557,7 @@ pub async fn init_config() -> Result, StatusCode> { pub async fn upsert_permissions( Json(query): Json, ) -> Result, StatusCode> { - let mut permission_manager = goose::config::PermissionManager::default(); + let permission_manager = goose::config::PermissionManager::instance(); for tool_permission in &query.tool_permissions { permission_manager.update_user_permission( diff --git a/crates/goose-server/src/routes/mod.rs b/crates/goose-server/src/routes/mod.rs index 03039e98df15..aa9392b92fda 100644 --- a/crates/goose-server/src/routes/mod.rs +++ b/crates/goose-server/src/routes/mod.rs @@ -23,7 +23,7 @@ use axum::Router; // Function to configure all routes pub fn configure(state: Arc, secret_key: String) -> Router { Router::new() - .merge(status::routes()) + .merge(status::routes(state.clone())) .merge(reply::routes(state.clone())) .merge(action_required::routes(state.clone())) .merge(agent::routes(state.clone())) diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index 566b6e037190..34a4f0e73dc6 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -9,7 +9,6 @@ use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; use goose::recipe::local_recipes; use goose::recipe::validate_recipe::validate_recipe_template_from_content; use goose::recipe::Recipe; -use goose::session::SessionManager; use goose::{recipe_deeplink, slash_commands}; use serde::{Deserialize, Serialize}; @@ -168,7 +167,11 @@ async fn create_recipe( request.session_id ); - let session = match SessionManager::get_session(&request.session_id, true).await { + let session = match state + .session_manager() + .get_session(&request.session_id, true) + .await + { Ok(session) => session, Err(e) => { tracing::error!("Failed to get session: {}", e); diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index eced7bc7bb8c..39a6e16bb1e2 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -146,8 +146,9 @@ pub enum MessageEvent { Ping, } -async fn get_token_state(session_id: &str) -> TokenState { - SessionManager::get_session(session_id, false) +async fn get_token_state(session_manager: &SessionManager, session_id: &str) -> TokenState { + session_manager + .get_session(session_id, false) .await .map(|session| TokenState { input_tokens: session.input_tokens.unwrap_or(0), @@ -258,7 +259,7 @@ pub async fn reply( } }; - let session = match SessionManager::get_session(&session_id, true).await { + let session = match state.session_manager().get_session(&session_id, true).await { Ok(metadata) => metadata, Err(e) => { tracing::error!("Failed to read session for {}: {}", session_id, e); @@ -284,7 +285,11 @@ pub async fn reply( let mut all_messages = match conversation_so_far { Some(history) => { let conv = Conversation::new_unvalidated(history); - if let Err(e) = SessionManager::replace_conversation(&session_id, &conv).await { + if let Err(e) = state + .session_manager() + .replace_conversation(&session_id, &conv) + .await + { tracing::warn!( "Failed to replace session conversation for {}: {}", session_id, @@ -339,7 +344,7 @@ pub async fn reply( all_messages.push(message.clone()); - let token_state = get_token_state(&session_id).await; + let token_state = get_token_state(state.session_manager(), &session_id).await; stream_event(MessageEvent::Message { message, token_state }, &tx, &cancel_token).await; } @@ -385,7 +390,7 @@ pub async fn reply( let session_duration = session_start.elapsed(); - if let Ok(session) = SessionManager::get_session(&session_id, true).await { + if let Ok(session) = state.session_manager().get_session(&session_id, true).await { let total_tokens = session.total_tokens.unwrap_or(0); tracing::info!( counter.goose.session_completions = 1, @@ -433,7 +438,7 @@ pub async fn reply( ); } - let final_token_state = get_token_state(&session_id).await; + let final_token_state = get_token_state(state.session_manager(), &session_id).await; let _ = stream_event( MessageEvent::Finish { diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index 09dc95b1f3a7..1606724df755 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -11,7 +11,7 @@ use axum::{ }; use goose::recipe::Recipe; use goose::session::session_manager::SessionInsights; -use goose::session::{Session, SessionManager}; +use goose::session::Session; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -89,8 +89,12 @@ const MAX_NAME_LENGTH: usize = 200; ), tag = "Session Management" )] -async fn list_sessions() -> Result, StatusCode> { - let sessions = SessionManager::list_sessions() +async fn list_sessions( + State(state): State>, +) -> Result, StatusCode> { + let sessions = state + .session_manager() + .list_sessions() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -114,8 +118,13 @@ async fn list_sessions() -> Result, StatusCode> { ), tag = "Session Management" )] -async fn get_session(Path(session_id): Path) -> Result, StatusCode> { - let session = SessionManager::get_session(&session_id, true) +async fn get_session( + State(state): State>, + Path(session_id): Path, +) -> Result, StatusCode> { + let session = state + .session_manager() + .get_session(&session_id, true) .await .map_err(|_| StatusCode::NOT_FOUND)?; @@ -134,8 +143,12 @@ async fn get_session(Path(session_id): Path) -> Result, St ), tag = "Session Management" )] -async fn get_session_insights() -> Result, StatusCode> { - let insights = SessionManager::get_insights() +async fn get_session_insights( + State(state): State>, +) -> Result, StatusCode> { + let insights = state + .session_manager() + .get_insights() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(insights)) @@ -161,6 +174,7 @@ async fn get_session_insights() -> Result, StatusCode> { tag = "Session Management" )] async fn update_session_name( + State(state): State>, Path(session_id): Path, Json(request): Json, ) -> Result { @@ -172,7 +186,9 @@ async fn update_session_name( return Err(StatusCode::BAD_REQUEST); } - SessionManager::update_session(&session_id) + state + .session_manager() + .update(&session_id) .user_provided_name(name.to_string()) .apply() .await @@ -205,7 +221,9 @@ async fn update_session_user_recipe_values( Path(session_id): Path, Json(request): Json, ) -> Result, ErrorResponse> { - SessionManager::update_session(&session_id) + state + .session_manager() + .update(&session_id) .user_recipe_values(Some(request.user_recipe_values)) .apply() .await @@ -214,7 +232,9 @@ async fn update_session_user_recipe_values( status: StatusCode::INTERNAL_SERVER_ERROR, })?; - let session = SessionManager::get_session(&session_id, false) + let session = state + .session_manager() + .get_session(&session_id, false) .await .map_err(|err| ErrorResponse { message: err.to_string(), @@ -268,8 +288,13 @@ async fn update_session_user_recipe_values( ), tag = "Session Management" )] -async fn delete_session(Path(session_id): Path) -> Result { - SessionManager::delete_session(&session_id) +async fn delete_session( + State(state): State>, + Path(session_id): Path, +) -> Result { + state + .session_manager() + .delete_session(&session_id) .await .map_err(|e| { if e.to_string().contains("not found") { @@ -299,8 +324,13 @@ async fn delete_session(Path(session_id): Path) -> Result) -> Result, StatusCode> { - let exported = SessionManager::export_session(&session_id) +async fn export_session( + State(state): State>, + Path(session_id): Path, +) -> Result, StatusCode> { + let exported = state + .session_manager() + .export_session(&session_id) .await .map_err(|_| StatusCode::NOT_FOUND)?; @@ -323,9 +353,12 @@ async fn export_session(Path(session_id): Path) -> Result, tag = "Session Management" )] async fn import_session( + State(state): State>, Json(request): Json, ) -> Result, StatusCode> { - let session = SessionManager::import_session(&request.json) + let session = state + .session_manager() + .import_session(&request.json) .await .map_err(|_| StatusCode::BAD_REQUEST)?; @@ -352,12 +385,15 @@ async fn import_session( tag = "Session Management" )] async fn edit_message( + State(state): State>, Path(session_id): Path, Json(request): Json, ) -> Result, StatusCode> { + let manager = state.session_manager(); match request.edit_type { EditType::Fork => { - let new_session = SessionManager::copy_session(&session_id, "(edited)".to_string()) + let new_session = manager + .copy_session(&session_id, "(edited)".to_string()) .await .map_err(|e| { tracing::error!("Failed to copy session: {}", e); @@ -365,7 +401,8 @@ async fn edit_message( StatusCode::INTERNAL_SERVER_ERROR })?; - SessionManager::truncate_conversation(&new_session.id, request.timestamp) + manager + .truncate_conversation(&new_session.id, request.timestamp) .await .map_err(|e| { tracing::error!("Failed to truncate conversation: {}", e); @@ -378,7 +415,8 @@ async fn edit_message( })) } EditType::Edit => { - SessionManager::truncate_conversation(&session_id, request.timestamp) + manager + .truncate_conversation(&session_id, request.timestamp) .await .map_err(|e| { tracing::error!("Failed to truncate conversation: {}", e); diff --git a/crates/goose-server/src/routes/status.rs b/crates/goose-server/src/routes/status.rs index 110d4c7ad87c..306c02f82ebf 100644 --- a/crates/goose-server/src/routes/status.rs +++ b/crates/goose-server/src/routes/status.rs @@ -1,8 +1,12 @@ use axum::body::Body; +use axum::extract::State; use axum::http::HeaderValue; use axum::response::IntoResponse; use axum::{extract::Path, http::StatusCode, routing::get, Router}; use goose::session::generate_diagnostics; +use std::sync::Arc; + +use crate::state::AppState; #[utoipa::path(get, path = "/status", responses( @@ -19,8 +23,11 @@ async fn status() -> String { (status = 500, description = "Failed to generate diagnostics"), ) )] -async fn diagnostics(Path(session_id): Path) -> impl IntoResponse { - match generate_diagnostics(&session_id).await { +async fn diagnostics( + State(state): State>, + Path(session_id): Path, +) -> impl IntoResponse { + match generate_diagnostics(state.session_manager(), &session_id).await { Ok(zip_data) => { let filename = format!("attachment; filename=\"diagnostics_{}.zip\"", session_id); let headers = [ @@ -39,8 +46,9 @@ async fn diagnostics(Path(session_id): Path) -> impl IntoResponse { Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), } } -pub fn routes() -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/status", get(status)) .route("/diagnostics/{session_id}", get(diagnostics)) + .with_state(state) } diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 4a9c582e39e2..558c40c84b68 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,6 +1,7 @@ use axum::http::StatusCode; use goose::execution::manager::AgentManager; use goose::scheduler_trait::SchedulerTrait; +use goose::session::SessionManager; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::atomic::AtomicUsize; @@ -37,6 +38,10 @@ impl AppState { self.agent_manager.scheduler() } + pub fn session_manager(&self) -> &SessionManager { + self.agent_manager.session_manager() + } + pub async fn set_recipe_file_hash_map(&self, hash_map: HashMap) { let mut map = self.recipe_file_hash_map.lock().await; *map = hash_map; diff --git a/crates/goose-test/Cargo.toml b/crates/goose-test/Cargo.toml index f42dbfd5d5a7..feb1755c37f8 100644 --- a/crates/goose-test/Cargo.toml +++ b/crates/goose-test/Cargo.toml @@ -16,4 +16,4 @@ path = "src/bin/capture.rs" [dependencies] clap = { version = "4.5.44", features = ["derive"] } -serde_json = "1.0.142" +serde_json = { workspace = true } diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 9a2f93eb81e3..e9eb2ef8786d 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -11,7 +11,7 @@ description.workspace = true workspace = true [build-dependencies] -tokio = { version = "1.43", features = ["full"] } +tokio = { workspace = true } reqwest = { version = "0.12.9", features = ["json", "rustls-tls-native-roots"], default-features = false } [dependencies] @@ -23,9 +23,9 @@ rmcp = { workspace = true, features = [ "transport-streamable-http-client", "transport-streamable-http-client-reqwest", ] } -anyhow = "1.0" +anyhow = { workspace = true } thiserror = "1.0" -futures = "0.3" +futures = { workspace = true } dirs = "5.0" reqwest = { version = "0.12.9", features = [ "rustls-tls-native-roots", @@ -40,13 +40,13 @@ reqwest = { version = "0.12.9", features = [ "stream", "blocking" ], default-features = false } -tokio = { version = "1.43", features = ["full"] } +tokio = { workspace = true } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { workspace = true } serde_urlencoded = "0.7" jsonschema = "0.30.0" uuid = { version = "1.0", features = ["v4"] } -regex = "1.11.1" +regex = { workspace = true } async-trait = "0.1.89" async-stream = "0.3" minijinja = { version = "2.12.0", features = ["loader"] } @@ -62,7 +62,7 @@ url = "2.5" axum = "0.8.1" webbrowser = {workspace = true} lazy_static = "1.5.0" -tracing = "0.1" +tracing = { workspace = true } tracing-subscriber = "0.3" tracing-opentelemetry = "0.28" opentelemetry = "0.27" @@ -97,8 +97,9 @@ tokio-stream = "0.1.17" tempfile = "3.15.0" dashmap = "6.1" ahash = "0.8" -tokio-util = "0.7.15" +tokio-util = { version = "0.7.15", features = ["compat"] } unicode-normalization = "0.1" +goose-mcp = { path = "../goose-mcp" } zip = "0.6" sys-info = "0.9" @@ -119,19 +120,17 @@ unbinder = "0.1.7" winapi = { version = "0.3", features = ["wincred"] } [dev-dependencies] -sacp = { workspace = true } -agent-client-protocol-schema = "0.10.5" criterion = "0.5" -serial_test = "3.2.0" +serial_test = { workspace = true } mockall = "0.13.1" -wiremock = "0.6.0" -tokio = { version = "1.43", features = ["full"] } +wiremock = { workspace = true } +tokio = { workspace = true } tokio-util = { version = "0.7.15", features = ["compat"] } temp-env = "0.3.6" dotenvy = "0.15.7" ctor = "0.2.9" test-case = "3.3" -env-lock = "1.0.1" +env-lock = { workspace = true } rmcp = { workspace = true, features = ["transport-streamable-http-server"] } [[example]] diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index 4e4bb5795903..9a2034494d36 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -6,7 +6,6 @@ use goose::conversation::message::Message; use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::session::session_manager::SessionType; -use goose::session::SessionManager; use std::path::PathBuf; #[tokio::main] @@ -17,12 +16,14 @@ async fn main() -> anyhow::Result<()> { let agent = Agent::new(); - let session = SessionManager::create_session( - PathBuf::default(), - "max-turn-test".to_string(), - SessionType::Hidden, - ) - .await?; + let session = agent + .session_manager() + .create_session( + PathBuf::default(), + "max-turn-test".to_string(), + SessionType::Hidden, + ) + .await?; let _ = agent.update_provider(provider, &session.id).await; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1fb2165aca50..a310d4f9a63a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -25,6 +25,7 @@ use crate::agents::subagent_tool::{ }; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, SharedProvider, ToolResultReceiver}; +use crate::config::permission::PermissionManager; use crate::config::{get_enabled_extensions, Config, GooseMode}; use crate::context_mgmt::{ check_if_compaction_needed, compact_messages, DEFAULT_COMPACTION_THRESHOLD, @@ -41,7 +42,7 @@ use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe, Response, Settings, SubRecipe}; -use crate::scheduler_trait::SchedulerTrait; +use crate::scheduler_trait::{unavailable_scheduler, SchedulerTrait}; use crate::security::security_inspector::SecurityInspector; use crate::session::extension_data::{EnabledExtensionsState, ExtensionState}; use crate::session::{Session, SessionManager, SessionType}; @@ -77,9 +78,34 @@ pub struct ToolCategorizeResult { pub filtered_response: Message, } +#[derive(Clone)] +pub struct AgentConfig { + pub session_manager: Arc, + pub permission_manager: Arc, + pub goose_mode: GooseMode, + pub scheduler: Arc, +} + +impl AgentConfig { + pub fn new( + session_manager: Arc, + permission_manager: Arc, + goose_mode: GooseMode, + scheduler: Arc, + ) -> Self { + Self { + session_manager, + permission_manager, + goose_mode, + scheduler, + } + } +} + /// The main goose Agent pub struct Agent { pub(super) provider: SharedProvider, + pub(crate) config: AgentConfig, pub extension_manager: Arc, pub(super) sub_recipes: Mutex>, @@ -92,7 +118,6 @@ pub struct Agent { pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult)>, pub(super) tool_result_rx: ToolResultReceiver, - pub(super) scheduler_service: Mutex>>, pub(super) retry_manager: RetryManager, pub(super) tool_inspection_manager: ToolInspectionManager, } @@ -148,6 +173,19 @@ where impl Agent { pub fn new() -> Self { + Self::with_config(AgentConfig::new( + Arc::new(SessionManager::instance()), + PermissionManager::instance(), + Config::global().get_goose_mode().unwrap_or(GooseMode::Auto), + // Historically, only goose-server had a scheduler. We can revisit this + unavailable_scheduler(), + )) + } + + pub fn with_config(config: AgentConfig) -> Self { + let permission_manager = Arc::clone(&config.permission_manager); + let session_manager = Arc::clone(&config.session_manager); + // Create channels with buffer size 32 (adjust if needed) let (confirm_tx, confirm_rx) = mpsc::channel(32); let (tool_tx, tool_rx) = mpsc::channel(32); @@ -155,7 +193,8 @@ impl Agent { Self { provider: provider.clone(), - extension_manager: Arc::new(ExtensionManager::new(provider.clone())), + config, + extension_manager: Arc::new(ExtensionManager::new(provider.clone(), session_manager)), sub_recipes: Mutex::new(HashMap::new()), final_output_tool: Arc::new(Mutex::new(None)), frontend_tools: Mutex::new(HashMap::new()), @@ -165,25 +204,33 @@ impl Agent { confirmation_rx: Mutex::new(confirm_rx), tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), - scheduler_service: Mutex::new(None), retry_manager: RetryManager::new(), - tool_inspection_manager: Self::create_default_tool_inspection_manager(), + tool_inspection_manager: Self::create_tool_inspection_manager(permission_manager), } } + pub fn session_manager(&self) -> Arc { + self.config.session_manager.clone() + } + + pub fn permission_manager(&self) -> Arc { + Arc::clone(&self.config.permission_manager) + } + /// Create a tool inspection manager with default inspectors - fn create_default_tool_inspection_manager() -> ToolInspectionManager { + fn create_tool_inspection_manager( + permission_manager: Arc, + ) -> ToolInspectionManager { let mut tool_inspection_manager = ToolInspectionManager::new(); // Add security inspector (highest priority - runs first) tool_inspection_manager.add_inspector(Box::new(SecurityInspector::new())); // Add permission inspector (medium-high priority) - // Note: mode will be updated dynamically based on session config tool_inspection_manager.add_inspector(Box::new(PermissionInspector::new( - GooseMode::SmartApprove, std::collections::HashSet::new(), // readonly tools - will be populated from extension manager std::collections::HashSet::new(), // regular tools - will be populated from extension manager + permission_manager, ))); // Add repetition inspector (lower priority - basic repetition checking) @@ -230,11 +277,12 @@ impl Agent { | RetryResult::SuccessChecksPassed => Ok(false), } } - async fn drain_elicitation_messages(session_id: &str) -> Vec { + async fn drain_elicitation_messages(&self, session_id: &str) -> Vec { let mut messages = Vec::new(); + let manager = self.session_manager(); let mut elicitation_rx = ActionRequiredManager::global().request_rx.lock().await; while let Ok(elicitation_message) = elicitation_rx.try_recv() { - if let Err(e) = SessionManager::add_message(session_id, &elicitation_message).await { + if let Err(e) = manager.add_message(session_id, &elicitation_message).await { warn!("Failed to save elicitation message to session: {}", e); } messages.push(elicitation_message); @@ -244,6 +292,7 @@ impl Agent { async fn prepare_reply_context( &self, + session_id: &str, unfixed_conversation: Conversation, working_dir: &std::path::Path, ) -> Result { @@ -260,22 +309,17 @@ impl Agent { ); } let initial_messages = conversation.messages().clone(); - let config = Config::global(); - let (tools, toolshim_tools, system_prompt) = - self.prepare_tools_and_prompt(working_dir).await?; - let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); - - self.tool_inspection_manager - .update_permission_inspector_mode(goose_mode) - .await; + let (tools, toolshim_tools, system_prompt) = self + .prepare_tools_and_prompt(session_id, working_dir) + .await?; Ok(ReplyContext { conversation, tools, toolshim_tools, system_prompt, - goose_mode, + goose_mode: self.config.goose_mode, initial_messages, }) } @@ -359,11 +403,6 @@ impl Agent { } } - pub async fn set_scheduler(&self, scheduler: Arc) { - let mut scheduler_service = self.scheduler_service.lock().await; - *scheduler_service = Some(scheduler); - } - /// Get a reference count clone to the provider pub async fn provider(&self) -> Result, anyhow::Error> { match &*self.provider.lock().await { @@ -496,6 +535,7 @@ impl Agent { .unwrap_or(Value::Object(serde_json::Map::new())); handle_subagent_tool( + &self.config, arguments, task_config, sub_recipes, @@ -513,7 +553,11 @@ impl Agent { // Clone the result to ensure no references to extension_manager are returned let result = self .extension_manager - .dispatch_tool_call(tool_call.clone(), cancellation_token.unwrap_or_default()) + .dispatch_tool_call( + &session.id, + tool_call.clone(), + cancellation_token.unwrap_or_default(), + ) .await; result.unwrap_or_else(|e| { crate::posthog::emit_error( @@ -550,14 +594,18 @@ impl Agent { let extensions_state = EnabledExtensionsState::new(extension_configs); - let mut session_data = SessionManager::get_session(&session.id, false).await?; + let mut session_data = self + .session_manager() + .get_session(&session.id, false) + .await?; if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { warn!("Failed to serialize extension state: {}", e); return Err(anyhow!("Extension state serialization failed: {}", e)); } - SessionManager::update_session(&session.id) + self.session_manager() + .update(&session.id) .extension_data(session_data.extension_data) .apply() .await?; @@ -602,10 +650,8 @@ impl Agent { Ok(()) } - pub async fn subagents_enabled(&self) -> bool { - let config = crate::config::Config::global(); - let is_autonomous = config.get_goose_mode().unwrap_or(GooseMode::Auto) == GooseMode::Auto; - if !is_autonomous { + pub async fn subagents_enabled(&self, session_id: &str) -> bool { + if self.config.goose_mode != GooseMode::Auto { return false; } if self @@ -616,16 +662,17 @@ impl Agent { { return false; } - if let Some(ref session_id) = self.extension_manager.get_context().await.session_id { - if matches!( - SessionManager::get_session(session_id, false) - .await - .ok() - .map(|session| session.session_type), - Some(SessionType::SubAgent) - ) { - return false; - } + let context = self.extension_manager.get_context(); + if matches!( + context + .session_manager + .get_session(session_id, false) + .await + .ok() + .map(|session| session.session_type), + Some(SessionType::SubAgent) + ) { + return false; } !self .extension_manager @@ -635,14 +682,14 @@ impl Agent { .unwrap_or(true) } - pub async fn list_tools(&self, extension_name: Option) -> Vec { + pub async fn list_tools(&self, session_id: &str, extension_name: Option) -> Vec { let mut prefixed_tools = self .extension_manager .get_prefixed_tools(extension_name.clone()) .await .unwrap_or_default(); - let subagents_enabled = self.subagents_enabled().await; + let subagents_enabled = self.subagents_enabled(session_id).await; if extension_name.is_none() || extension_name.as_deref() == Some("platform") { prefixed_tools.push(platform_tools::manage_schedule_tool()); } @@ -696,6 +743,8 @@ impl Agent { session_config: SessionConfig, cancel_token: Option, ) -> Result>> { + let session_manager = self.session_manager(); + for content in &user_message.content { if let MessageContent::ActionRequired(action_required) = content { if let ActionRequiredData::ElicitationResponse { id, user_data } = @@ -713,7 +762,9 @@ impl Agent { )) }))); } - SessionManager::add_message(&session_config.id, &user_message).await?; + session_manager + .add_message(&session_config.id, &user_message) + .await?; return Ok(Box::pin(futures::stream::empty())); } } @@ -745,16 +796,18 @@ impl Agent { }))); } Ok(Some(response)) if response.role == rmcp::model::Role::Assistant => { - SessionManager::add_message( - &session_config.id, - &user_message.clone().with_visibility(true, false), - ) - .await?; - SessionManager::add_message( - &session_config.id, - &response.clone().with_visibility(true, false), - ) - .await?; + session_manager + .add_message( + &session_config.id, + &user_message.clone().with_visibility(true, false), + ) + .await?; + session_manager + .add_message( + &session_config.id, + &response.clone().with_visibility(true, false), + ) + .await?; // Check if this was a command that modifies conversation history let modifies_history = crate::agents::execute_commands::COMPACT_TRIGGERS @@ -767,7 +820,7 @@ impl Agent { // After commands that modify history, notify UI that history was replaced if modifies_history { - let updated_session = SessionManager::get_session(&session_config.id, true) + let updated_session = session_manager.get_session(&session_config.id, true) .await .map_err(|e| anyhow!("Failed to fetch updated session: {}", e))?; let updated_conversation = updated_session @@ -778,22 +831,28 @@ impl Agent { })); } Ok(Some(resolved_message)) => { - SessionManager::add_message( - &session_config.id, - &user_message.clone().with_visibility(true, false), - ) - .await?; - SessionManager::add_message( - &session_config.id, - &resolved_message.clone().with_visibility(false, true), - ) - .await?; + session_manager + .add_message( + &session_config.id, + &user_message.clone().with_visibility(true, false), + ) + .await?; + session_manager + .add_message( + &session_config.id, + &resolved_message.clone().with_visibility(false, true), + ) + .await?; } Ok(None) => { - SessionManager::add_message(&session_config.id, &user_message).await?; + session_manager + .add_message(&session_config.id, &user_message) + .await?; } } - let session = SessionManager::get_session(&session_config.id, true).await?; + let session = session_manager + .get_session(&session_config.id, true) + .await?; let conversation = session .conversation .clone() @@ -840,8 +899,8 @@ impl Agent { match compact_messages(self.provider().await?.as_ref(), &conversation_to_compact, false).await { Ok((compacted_conversation, summarization_usage)) => { - SessionManager::replace_conversation(&session_config.id, &compacted_conversation).await?; - Self::update_session_metrics(&session_config, &summarization_usage, true).await?; + session_manager.replace_conversation(&session_config.id, &compacted_conversation).await?; + self.update_session_metrics(&session_config, &summarization_usage, true).await?; yield AgentEvent::HistoryReplaced(compacted_conversation.clone()); @@ -880,7 +939,7 @@ impl Agent { cancel_token: Option, ) -> Result>> { let context = self - .prepare_reply_context(conversation, &session.working_dir) + .prepare_reply_context(&session_config.id, conversation, &session.working_dir) .await?; let ReplyContext { mut conversation, @@ -894,10 +953,14 @@ impl Agent { self.reset_retry_attempts().await; let provider = self.provider().await?; + let session_manager = self.session_manager(); let session_id = session_config.id.clone(); - let working_dir = session.working_dir.clone(); + let manager_for_spawn = session_manager.clone(); tokio::spawn(async move { - if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await { + if let Err(e) = manager_for_spawn + .maybe_update_name(&session_id, provider) + .await + { warn!("Failed to generate session description: {}", e); } }); @@ -934,6 +997,7 @@ impl Agent { } let conversation_with_moim = super::moim::inject_moim( + &session_config.id, conversation.clone(), &self.extension_manager, ).await; @@ -982,7 +1046,7 @@ impl Agent { } if let Some(ref usage) = usage { - Self::update_session_metrics(&session_config, usage, false).await?; + self.update_session_metrics(&session_config, usage, false).await?; } if let Some(response) = response { @@ -1047,6 +1111,7 @@ impl Agent { .inspect_tools( &remaining_requests, conversation.messages(), + goose_mode, ) .await?; @@ -1117,7 +1182,7 @@ impl Agent { break; } - for msg in Self::drain_elicitation_messages(&session_config.id).await { + for msg in self.drain_elicitation_messages(&session_config.id).await { yield AgentEvent::Message(msg); } @@ -1141,7 +1206,7 @@ impl Agent { } // check for remaining elicitation messages after all tools complete - for msg in Self::drain_elicitation_messages(&session_config.id).await { + for msg in self.drain_elicitation_messages(&session_config.id).await { yield AgentEvent::Message(msg); } @@ -1219,8 +1284,8 @@ impl Agent { match compact_messages(self.provider().await?.as_ref(), &conversation, false).await { Ok((compacted_conversation, usage)) => { - SessionManager::replace_conversation(&session_config.id, &compacted_conversation).await?; - Self::update_session_metrics(&session_config, &usage, true).await?; + session_manager.replace_conversation(&session_config.id, &compacted_conversation).await?; + self.update_session_metrics(&session_config, &usage, true).await?; conversation = compacted_conversation; did_recovery_compact_this_iteration = true; yield AgentEvent::HistoryReplaced(conversation.clone()); @@ -1247,7 +1312,7 @@ impl Agent { } if tools_updated { (tools, toolshim_tools, system_prompt) = - self.prepare_tools_and_prompt(&working_dir).await?; + self.prepare_tools_and_prompt(&session_config.id, &session.working_dir).await?; } let mut exit_chat = false; if no_tools_called { @@ -1288,7 +1353,7 @@ impl Agent { } for msg in &messages_to_add { - SessionManager::add_message(&session_config.id, msg).await?; + session_manager.add_message(&session_config.id, msg).await?; } conversation.extend(messages_to_add); if exit_chat { @@ -1305,15 +1370,20 @@ impl Agent { prompt_manager.add_system_prompt_extra(instruction); } + pub async fn set_provider(&self, provider: Arc) { + let mut current_provider = self.provider.lock().await; + *current_provider = Some(provider); + } + pub async fn update_provider( &self, provider: Arc, session_id: &str, ) -> Result<()> { - let mut current_provider = self.provider.lock().await; - *current_provider = Some(provider.clone()); + self.set_provider(provider.clone()).await; - SessionManager::update_session(session_id) + self.session_manager() + .update(session_id) .provider_name(provider.get_name()) .model_config(provider.get_model_config()) .apply() @@ -1613,7 +1683,7 @@ mod tests { agent.add_final_output_tool(response).await; - let tools = agent.list_tools(None).await; + let tools = agent.list_tools("test-session-id", None).await; let final_output_tool = tools .iter() .find(|tool| tool.name == FINAL_OUTPUT_TOOL_NAME); diff --git a/crates/goose/src/agents/chatrecall_extension.rs b/crates/goose/src/agents/chatrecall_extension.rs index 5d45976c8077..d862fdcc1fca 100644 --- a/crates/goose/src/agents/chatrecall_extension.rs +++ b/crates/goose/src/agents/chatrecall_extension.rs @@ -1,18 +1,14 @@ use crate::agents::extension::PlatformExtensionContext; use crate::agents::mcp_client::{Error, McpClientTrait}; -use crate::session::SessionManager; use anyhow::Result; use async_trait::async_trait; use indoc::indoc; use rmcp::model::{ - CallToolResult, Content, GetPromptResult, Implementation, InitializeResult, JsonObject, - ListPromptsResult, ListResourcesResult, ListToolsResult, ProtocolVersion, ReadResourceResult, - ServerCapabilities, ServerNotification, Tool, ToolAnnotations, ToolsCapability, + CallToolResult, Content, Implementation, InitializeResult, JsonObject, ListToolsResult, + ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, ToolsCapability, }; use schemars::{schema_for, JsonSchema}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; pub static EXTENSION_NAME: &str = "chatrecall"; @@ -80,18 +76,19 @@ impl ChatRecallClient { #[allow(clippy::too_many_lines)] async fn handle_chatrecall( &self, + current_session_id: &str, arguments: Option, ) -> Result, String> { let arguments = arguments.ok_or("Missing arguments")?; - let session_id = arguments + let target_session_id = arguments .get("session_id") .and_then(|v| v.as_str()) .map(|s| s.to_string()); - if let Some(sid) = session_id { + if let Some(sid) = target_session_id { // LOAD MODE: Get session summary (first and last few messages) - match SessionManager::get_session(&sid, true).await { + match self.context.session_manager.get_session(&sid, true).await { Ok(loaded_session) => { let conversation = loaded_session.conversation.as_ref(); @@ -187,16 +184,19 @@ impl ChatRecallClient { .map(|dt| dt.with_timezone(&chrono::Utc)); // Exclude current session from results to avoid self-referential loops - let exclude_session_id = self.context.session_id.clone(); - - match SessionManager::search_chat_history( - &query, - Some(limit), - after_date, - before_date, - exclude_session_id, - ) - .await + let exclude_session_id = Some(current_session_id.to_string()); + + match self + .context + .session_manager + .search_chat_history( + &query, + Some(limit), + after_date, + before_date, + exclude_session_id, + ) + .await { Ok(results) => { let formatted_results = if results.total_matches == 0 { @@ -278,22 +278,6 @@ impl ChatRecallClient { #[async_trait] impl McpClientTrait for ChatRecallClient { - async fn list_resources( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn read_resource( - &self, - _uri: &str, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - async fn list_tools( &self, _next_cursor: Option, @@ -310,10 +294,11 @@ impl McpClientTrait for ChatRecallClient { &self, name: &str, arguments: Option, + session_id: &str, _cancellation_token: CancellationToken, ) -> Result { let content = match name { - "chatrecall" => self.handle_chatrecall(arguments).await, + "chatrecall" => self.handle_chatrecall(session_id, arguments).await, _ => Err(format!("Unknown tool: {}", name)), }; @@ -326,27 +311,6 @@ impl McpClientTrait for ChatRecallClient { } } - async fn list_prompts( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn subscribe(&self) -> mpsc::Receiver { - mpsc::channel(1).1 - } - fn get_info(&self) -> Option<&InitializeResult> { Some(&self.info) } diff --git a/crates/goose/src/agents/code_execution_extension.rs b/crates/goose/src/agents/code_execution_extension.rs index 2d66599201ff..19e9bed7d423 100644 --- a/crates/goose/src/agents/code_execution_extension.rs +++ b/crates/goose/src/agents/code_execution_extension.rs @@ -10,10 +10,9 @@ use boa_engine::{js_string, Context, JsNativeError, JsString, JsValue, NativeFun use indoc::indoc; use regex::Regex; use rmcp::model::{ - CallToolRequestParam, CallToolResult, Content, GetPromptResult, Implementation, - InitializeResult, JsonObject, ListPromptsResult, ListResourcesResult, ListToolsResult, - ProtocolVersion, RawContent, ReadResourceResult, ServerCapabilities, ServerNotification, - Tool as McpTool, ToolAnnotations, ToolsCapability, + CallToolRequestParam, CallToolResult, Content, Implementation, InitializeResult, JsonObject, + ListToolsResult, ProtocolVersion, RawContent, ServerCapabilities, Tool as McpTool, + ToolAnnotations, ToolsCapability, }; use schemars::{schema_for, JsonSchema}; use serde::{Deserialize, Serialize}; @@ -496,6 +495,7 @@ impl CodeExecutionClient { async fn handle_execute_code( &self, + session_id: &str, arguments: Option, ) -> Result, String> { let code = arguments @@ -508,6 +508,7 @@ impl CodeExecutionClient { let tools = self.get_tool_infos().await; let (call_tx, call_rx) = mpsc::unbounded_channel(); let tool_handler = tokio::spawn(Self::run_tool_handler( + session_id.to_string(), call_rx, self.context.extension_manager.clone(), )); @@ -672,6 +673,7 @@ impl CodeExecutionClient { } async fn run_tool_handler( + session_id: String, mut call_rx: mpsc::UnboundedReceiver, extension_manager: Option>, ) { @@ -683,7 +685,7 @@ impl CodeExecutionClient { arguments: serde_json::from_str(&arguments).ok(), }; match manager - .dispatch_tool_call(tool_call, CancellationToken::new()) + .dispatch_tool_call(&session_id, tool_call, CancellationToken::new()) .await { Ok(dispatch_result) => match dispatch_result.result.await { @@ -714,22 +716,6 @@ impl CodeExecutionClient { #[async_trait] impl McpClientTrait for CodeExecutionClient { - async fn list_resources( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn read_resource( - &self, - _uri: &str, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - #[allow(clippy::too_many_lines)] async fn list_tools( &self, @@ -857,10 +843,11 @@ impl McpClientTrait for CodeExecutionClient { &self, name: &str, arguments: Option, + session_id: &str, _cancellation_token: CancellationToken, ) -> Result { let content = match name { - "execute_code" => self.handle_execute_code(arguments).await, + "execute_code" => self.handle_execute_code(session_id, arguments).await, "read_module" => self.handle_read_module(arguments).await, "search_modules" => self.handle_search_modules(arguments).await, _ => Err(format!("Unknown tool: {name}")), @@ -874,32 +861,11 @@ impl McpClientTrait for CodeExecutionClient { } } - async fn list_prompts( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn subscribe(&self) -> mpsc::Receiver { - mpsc::channel(1).1 - } - fn get_info(&self) -> Option<&InitializeResult> { Some(&self.info) } - async fn get_moim(&self) -> Option { + async fn get_moim(&self, _session_id: &str) -> Option { let tools = self.get_tool_infos().await; if tools.is_empty() { return None; @@ -935,9 +901,13 @@ mod tests { #[tokio::test] async fn test_execute_code_simple() { + let temp_dir = tempfile::tempdir().unwrap(); + let session_manager = Arc::new(crate::session::SessionManager::new( + temp_dir.path().to_path_buf(), + )); let context = PlatformExtensionContext { - session_id: None, extension_manager: None, + session_manager, }; let client = CodeExecutionClient::new(context).unwrap(); @@ -945,7 +915,12 @@ mod tests { args.insert("code".to_string(), Value::String("2 + 2".to_string())); let result = client - .call_tool("execute_code", Some(args), CancellationToken::new()) + .call_tool( + "execute_code", + Some(args), + "test-session-id", + CancellationToken::new(), + ) .await .unwrap(); @@ -959,9 +934,13 @@ mod tests { #[tokio::test] async fn test_read_module_not_found() { + let temp_dir = tempfile::tempdir().unwrap(); + let session_manager = Arc::new(crate::session::SessionManager::new( + temp_dir.path().to_path_buf(), + )); let context = PlatformExtensionContext { - session_id: None, extension_manager: None, + session_manager, }; let client = CodeExecutionClient::new(context).unwrap(); diff --git a/crates/goose/src/agents/execute_commands.rs b/crates/goose/src/agents/execute_commands.rs index 545fcb9b6de9..ecf3ae91da44 100644 --- a/crates/goose/src/agents/execute_commands.rs +++ b/crates/goose/src/agents/execute_commands.rs @@ -5,7 +5,6 @@ use anyhow::{anyhow, Result}; use crate::context_mgmt::compact_messages; use crate::conversation::message::{Message, SystemNotificationType}; use crate::recipe::build_recipe::build_recipe_from_template_with_positional_params; -use crate::session::SessionManager; use super::Agent; @@ -81,7 +80,8 @@ impl Agent { } async fn handle_compact_command(&self, session_id: &str) -> Result> { - let session = SessionManager::get_session(session_id, true).await?; + let manager = self.session_manager(); + let session = manager.get_session(session_id, true).await?; let conversation = session .conversation .ok_or_else(|| anyhow!("Session has no conversation"))?; @@ -93,7 +93,9 @@ impl Agent { ) .await?; - SessionManager::replace_conversation(session_id, &compacted_conversation).await?; + manager + .replace_conversation(session_id, &compacted_conversation) + .await?; Ok(Some(Message::assistant().with_system_notification( SystemNotificationType::InlineMessage, @@ -104,9 +106,13 @@ impl Agent { async fn handle_clear_command(&self, session_id: &str) -> Result> { use crate::conversation::Conversation; - SessionManager::replace_conversation(session_id, &Conversation::default()).await?; + let manager = self.session_manager(); + manager + .replace_conversation(session_id, &Conversation::default()) + .await?; - SessionManager::update_session(session_id) + manager + .update(session_id) .total_tokens(Some(0)) .input_tokens(Some(0)) .output_tokens(Some(0)) @@ -238,10 +244,12 @@ impl Agent { return Ok(Some(Message::assistant().with_text(error_msg))); } - SessionManager::add_message(session_id, &msg).await?; + self.session_manager().add_message(session_id, &msg).await?; } - let last_message = SessionManager::get_session(session_id, true) + let last_message = self + .session_manager() + .get_session(session_id, true) .await? .conversation .ok_or_else(|| anyhow!("No conversation found"))? diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index ef18d76ca50e..8e046924902a 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -106,9 +106,9 @@ pub static PLATFORM_EXTENSIONS: Lazy #[derive(Clone)] pub struct PlatformExtensionContext { - pub session_id: Option, pub extension_manager: Option>, + pub session_manager: std::sync::Arc, } #[derive(Debug, Clone)] diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index fe2ba045b16b..63701b1c33ea 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -94,7 +94,7 @@ impl Extension { /// Manages goose extensions / MCP clients and their interactions pub struct ExtensionManager { extensions: Mutex>, - context: Mutex, + context: PlatformExtensionContext, provider: SharedProvider, } @@ -209,12 +209,6 @@ pub fn get_parameter_names(tool: &Tool) -> Vec { names } -impl Default for ExtensionManager { - fn default() -> Self { - Self::new(Arc::new(Mutex::new(None))) - } -} - async fn child_process_client( mut command: Command, timeout: &Option, @@ -442,28 +436,28 @@ async fn create_stdio_client( } impl ExtensionManager { - pub fn new(provider: SharedProvider) -> Self { + pub fn new( + provider: SharedProvider, + session_manager: Arc, + ) -> Self { Self { extensions: Mutex::new(HashMap::new()), - context: Mutex::new(PlatformExtensionContext { - session_id: None, + context: PlatformExtensionContext { extension_manager: None, - }), + session_manager, + }, provider, } } - /// Create a new ExtensionManager with no provider (useful for tests) - pub fn new_without_provider() -> Self { - Self::new(Arc::new(Mutex::new(None))) - } - - pub async fn set_context(&self, context: PlatformExtensionContext) { - *self.context.lock().await = context; + #[cfg(test)] + pub fn new_without_provider(data_dir: std::path::PathBuf) -> Self { + let session_manager = Arc::new(crate::session::SessionManager::new(data_dir)); + Self::new(Arc::new(Mutex::new(None)), session_manager) } - pub async fn get_context(&self) -> PlatformExtensionContext { - self.context.lock().await.clone() + pub fn get_context(&self) -> &PlatformExtensionContext { + &self.context } pub async fn supports_resources(&self) -> bool { @@ -474,7 +468,7 @@ impl ExtensionManager { .any(|ext| ext.supports_resources()) } - pub async fn add_extension(&self, config: ExtensionConfig) -> ExtensionResult<()> { + pub async fn add_extension(self: &Arc, config: ExtensionConfig) -> ExtensionResult<()> { let config_name = config.key().to_string(); let sanitized_name = normalize(config_name.clone()); @@ -522,25 +516,23 @@ impl ExtensionManager { create_stdio_client(cmd, args, all_envs, timeout, self.provider.clone()).await? } ExtensionConfig::Builtin { name, timeout, .. } => { - let cmd = std::env::current_exe() - .and_then(|path| { - path.to_str().map(|s| s.to_string()).ok_or_else(|| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Invalid UTF-8 in executable path", - ) - }) - }) - .map_err(|e| { - ExtensionError::ConfigError(format!( - "Failed to resolve executable path: {}", - e - )) + let timeout_duration = Duration::from_secs(timeout.unwrap_or(300)); + let def = goose_mcp::BUILTIN_EXTENSIONS + .get(name.as_str()) + .ok_or_else(|| { + ExtensionError::ConfigError(format!("Unknown builtin extension: {}", name)) })?; - let command = Command::new(cmd).configure(|command| { - command.arg("mcp").arg(name); - }); - Box::new(child_process_client(command, timeout, self.provider.clone()).await?) + let (server_read, client_write) = tokio::io::duplex(65536); + let (client_read, server_write) = tokio::io::duplex(65536); + (def.spawn_server)(server_read, server_write); + Box::new( + McpClient::connect( + (client_read, client_write), + timeout_duration, + self.provider.clone(), + ) + .await?, + ) } ExtensionConfig::Platform { name, .. } => { let normalized_key = normalize(name.clone()); @@ -549,7 +541,8 @@ impl ExtensionManager { .ok_or_else(|| { ExtensionError::ConfigError(format!("Unknown platform extension: {}", name)) })?; - let context = self.get_context().await; + let mut context = self.context.clone(); + context.extension_manager = Some(Arc::downgrade(self)); (def.client_factory)(context) } ExtensionConfig::InlinePython { @@ -1024,6 +1017,7 @@ impl ExtensionManager { pub async fn dispatch_tool_call( &self, + session_id: &str, tool_call: CallToolRequestParam, cancellation_token: CancellationToken, ) -> Result { @@ -1062,11 +1056,17 @@ impl ExtensionManager { let arguments = tool_call.arguments.clone(); let client = client.clone(); let notifications_receiver = client.lock().await.subscribe().await; + let session_id = session_id.to_string(); let fut = async move { + tracing::debug!( + "dispatch_tool_call fut: calling client.call_tool tool={} session_id={}", + tool_name, + session_id + ); let client_guard = client.lock().await; client_guard - .call_tool(&tool_name, arguments, cancellation_token) + .call_tool(&tool_name, arguments, &session_id, cancellation_token) .await .map_err(|e| match e { ServiceError::McpError(error_data) => error_data, @@ -1247,7 +1247,7 @@ impl ExtensionManager { .map(|ext| ext.get_client()) } - pub async fn collect_moim(&self) -> Option { + pub async fn collect_moim(&self, session_id: &str) -> Option { // Use minute-level granularity to prevent conversation changes every second let timestamp = chrono::Local::now().format("%Y-%m-%d %H:%M:00").to_string(); let mut content = format!("\nIt is currently {}\n", timestamp); @@ -1268,7 +1268,7 @@ impl ExtensionManager { for (name, client) in platform_clients { let client_guard = client.lock().await; - if let Some(moim_content) = client_guard.get_moim().await { + if let Some(moim_content) = client_guard.get_moim(session_id).await { tracing::debug!("MOIM content from {}: {} chars", name, moim_content.len()); content.push('\n'); content.push_str(&moim_content); @@ -1383,6 +1383,7 @@ mod tests { &self, name: &str, _arguments: Option, + _session_id: &str, _cancellation_token: CancellationToken, ) -> Result { match name { @@ -1420,7 +1421,9 @@ mod tests { #[tokio::test] async fn test_get_client_for_tool() { - let extension_manager = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let extension_manager = + ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); // Add some mock clients using the helper method extension_manager @@ -1480,7 +1483,9 @@ mod tests { async fn test_dispatch_tool_call() { // test that dispatch_tool_call parses out the sanitized name correctly, and extracts // tool_names - let extension_manager = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let extension_manager = + ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); // Add some mock clients using the helper method extension_manager @@ -1511,7 +1516,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; assert!(result.is_ok()); @@ -1521,7 +1526,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; assert!(result.is_ok()); @@ -1532,7 +1537,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; assert!(result.is_ok()); @@ -1543,7 +1548,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; assert!(result.is_ok()); @@ -1553,7 +1558,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; assert!(result.is_ok()); @@ -1564,7 +1569,11 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) + .dispatch_tool_call( + "test-session-id", + invalid_tool_call, + CancellationToken::default(), + ) .await .unwrap() .result @@ -1585,7 +1594,11 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) + .dispatch_tool_call( + "test-session-id", + invalid_tool_call, + CancellationToken::default(), + ) .await; if let Err(err) = result { let tool_err = err.downcast_ref::().expect("Expected ErrorData"); @@ -1597,7 +1610,9 @@ mod tests { #[tokio::test] async fn test_tool_availability_filtering() { - let extension_manager = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let extension_manager = + ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); // Only "available_tool" should be available to the LLM let available_tools = vec!["available_tool".to_string()]; @@ -1625,7 +1640,9 @@ mod tests { #[tokio::test] async fn test_tool_availability_defaults_to_available() { - let extension_manager = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let extension_manager = + ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); extension_manager .add_mock_extension_with_tools( @@ -1650,7 +1667,9 @@ mod tests { #[tokio::test] async fn test_dispatch_unavailable_tool_returns_error() { - let extension_manager = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let extension_manager = + ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); let available_tools = vec!["available_tool".to_string()]; @@ -1669,7 +1688,11 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(unavailable_tool_call, CancellationToken::default()) + .dispatch_tool_call( + "test-session-id", + unavailable_tool_call, + CancellationToken::default(), + ) .await; // Should return RESOURCE_NOT_FOUND error @@ -1688,7 +1711,11 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(available_tool_call, CancellationToken::default()) + .dispatch_tool_call( + "test-session-id", + available_tool_call, + CancellationToken::default(), + ) .await; assert!(result.is_ok()); @@ -1759,9 +1786,10 @@ mod tests { #[tokio::test] async fn test_collect_moim_uses_minute_granularity() { - let em = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); - if let Some(moim) = em.collect_moim().await { + if let Some(moim) = em.collect_moim("test-session-id").await { // Timestamp should end with :00 (seconds fixed to 00) assert!( moim.contains(":00\n"), diff --git a/crates/goose/src/agents/extension_manager_extension.rs b/crates/goose/src/agents/extension_manager_extension.rs index 49e09ed066aa..2bd91ba5b271 100644 --- a/crates/goose/src/agents/extension_manager_extension.rs +++ b/crates/goose/src/agents/extension_manager_extension.rs @@ -418,6 +418,7 @@ impl McpClientTrait for ExtensionManagerClient { &self, name: &str, arguments: Option, + _session_id: &str, _cancellation_token: CancellationToken, ) -> Result { let result = match name { diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index 677dabf9149f..b83104737754 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -3,7 +3,7 @@ use crate::agents::types::SharedProvider; use crate::session_context::SESSION_ID_HEADER; use rmcp::model::{ Content, CreateElicitationRequestParam, CreateElicitationResult, ElicitationAction, ErrorCode, - JsonObject, + Extensions, JsonObject, Meta, }; /// MCP client implementation for Goose use rmcp::{ @@ -39,18 +39,6 @@ pub type Error = rmcp::ServiceError; #[async_trait::async_trait] pub trait McpClientTrait: Send + Sync { - async fn list_resources( - &self, - next_cursor: Option, - cancel_token: CancellationToken, - ) -> Result; - - async fn read_resource( - &self, - uri: &str, - cancel_token: CancellationToken, - ) -> Result; - async fn list_tools( &self, next_cursor: Option, @@ -61,27 +49,50 @@ pub trait McpClientTrait: Send + Sync { &self, name: &str, arguments: Option, + session_id: &str, cancel_token: CancellationToken, ) -> Result; + fn get_info(&self) -> Option<&InitializeResult>; + + async fn list_resources( + &self, + _next_cursor: Option, + _cancel_token: CancellationToken, + ) -> Result { + Err(Error::TransportClosed) + } + + async fn read_resource( + &self, + _uri: &str, + _cancel_token: CancellationToken, + ) -> Result { + Err(Error::TransportClosed) + } + async fn list_prompts( &self, - next_cursor: Option, - cancel_token: CancellationToken, - ) -> Result; + _next_cursor: Option, + _cancel_token: CancellationToken, + ) -> Result { + Err(Error::TransportClosed) + } async fn get_prompt( &self, - name: &str, - arguments: Value, - cancel_token: CancellationToken, - ) -> Result; - - async fn subscribe(&self) -> mpsc::Receiver; + _name: &str, + _arguments: Value, + _cancel_token: CancellationToken, + ) -> Result { + Err(Error::TransportClosed) + } - fn get_info(&self) -> Option<&InitializeResult>; + async fn subscribe(&self) -> mpsc::Receiver { + mpsc::channel(1).1 + } - async fn get_moim(&self) -> Option { + async fn get_moim(&self, _session_id: &str) -> Option { None } } @@ -379,7 +390,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListResourcesRequest(ListResourcesRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_current_session_id_into_extensions(Default::default()), }), cancel_token, ) @@ -403,7 +414,7 @@ impl McpClientTrait for McpClient { uri: uri.to_string(), }, method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_current_session_id_into_extensions(Default::default()), }), cancel_token, ) @@ -425,7 +436,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListToolsRequest(ListToolsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_current_session_id_into_extensions(Default::default()), }), cancel_token, ) @@ -441,6 +452,7 @@ impl McpClientTrait for McpClient { &self, name: &str, arguments: Option, + session_id: &str, cancel_token: CancellationToken, ) -> Result { let res = self @@ -451,7 +463,7 @@ impl McpClientTrait for McpClient { arguments, }, method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_session_id_into_extensions(Default::default(), session_id), }), cancel_token, ) @@ -473,7 +485,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListPromptsRequest(ListPromptsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_current_session_id_into_extensions(Default::default()), }), cancel_token, ) @@ -503,7 +515,7 @@ impl McpClientTrait for McpClient { arguments, }, method: Default::default(), - extensions: inject_session_into_extensions(Default::default()), + extensions: inject_current_session_id_into_extensions(Default::default()), }), cancel_token, ) @@ -522,27 +534,32 @@ impl McpClientTrait for McpClient { } } -/// Replaces session ID, case-insensitively, in Extensions._meta. -fn inject_session_into_extensions( - mut extensions: rmcp::model::Extensions, -) -> rmcp::model::Extensions { - use rmcp::model::Meta; +/// Injects the given session_id into Extensions._meta. +fn inject_session_id_into_extensions(mut extensions: Extensions, session_id: &str) -> Extensions { + let mut meta_map = extensions + .get::() + .map(|meta| meta.0.clone()) + .unwrap_or_default(); - if let Some(session_id) = crate::session_context::current_session_id() { - let mut meta_map = extensions - .get::() - .map(|meta| meta.0.clone()) - .unwrap_or_default(); + // JsonObject is case-sensitive, so we use retain for case-insensitive removal + meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER)); - // JsonObject is case-sensitive, so we use retain for case-insensitive removal - meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER)); + meta_map.insert( + SESSION_ID_HEADER.to_string(), + Value::String(session_id.to_string()), + ); - meta_map.insert(SESSION_ID_HEADER.to_string(), Value::String(session_id)); + extensions.insert(Meta(meta_map)); + extensions +} - extensions.insert(Meta(meta_map)); +/// Injects session ID from task-local context into Extensions._meta. +fn inject_current_session_id_into_extensions(extensions: Extensions) -> Extensions { + if let Some(session_id) = crate::session_context::current_session_id() { + inject_session_id_into_extensions(extensions, &session_id) + } else { + extensions } - - extensions } #[cfg(test)] @@ -556,7 +573,7 @@ mod tests { let session_id = "test-session-789"; crate::session_context::with_session_id(Some(session_id.to_string()), async { - let extensions = inject_session_into_extensions(Default::default()); + let extensions = inject_current_session_id_into_extensions(Default::default()); let meta = extensions.get::().unwrap(); assert_eq!( @@ -573,7 +590,7 @@ mod tests { #[tokio::test] async fn test_no_session_id_in_mcp_when_absent() { - let extensions = inject_session_into_extensions(Default::default()); + let extensions = inject_current_session_id_into_extensions(Default::default()); let meta = extensions.get::(); assert!(meta.is_none()); @@ -585,9 +602,9 @@ mod tests { let session_id = "consistent-session-id"; crate::session_context::with_session_id(Some(session_id.to_string()), async { - let ext1 = inject_session_into_extensions(Default::default()); - let ext2 = inject_session_into_extensions(Default::default()); - let ext3 = inject_session_into_extensions(Default::default()); + let ext1 = inject_current_session_id_into_extensions(Default::default()); + let ext2 = inject_current_session_id_into_extensions(Default::default()); + let ext3 = inject_current_session_id_into_extensions(Default::default()); for ext in [&ext1, &ext2, &ext3] { assert_eq!( @@ -620,7 +637,7 @@ mod tests { .unwrap(), ); - let extensions = inject_session_into_extensions(extensions); + let extensions = inject_current_session_id_into_extensions(extensions); let meta = extensions.get::().unwrap(); assert_eq!( diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 0384990594eb..23f6bc88caaa 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -24,7 +24,7 @@ pub(crate) mod todo_extension; mod tool_execution; pub mod types; -pub use agent::{Agent, AgentEvent}; +pub use agent::{Agent, AgentConfig, AgentEvent}; pub use execute_commands::COMPACT_TRIGGERS; pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; diff --git a/crates/goose/src/agents/moim.rs b/crates/goose/src/agents/moim.rs index 97f273d52412..b4c3122dca56 100644 --- a/crates/goose/src/agents/moim.rs +++ b/crates/goose/src/agents/moim.rs @@ -9,6 +9,7 @@ thread_local! { } pub async fn inject_moim( + session_id: &str, conversation: Conversation, extension_manager: &ExtensionManager, ) -> Conversation { @@ -16,7 +17,7 @@ pub async fn inject_moim( return conversation; } - if let Some(moim) = extension_manager.collect_moim().await { + if let Some(moim) = extension_manager.collect_moim(session_id).await { let mut messages = conversation.messages().clone(); let idx = messages .iter() @@ -48,14 +49,15 @@ mod tests { #[tokio::test] async fn test_moim_injection_before_assistant() { - let em = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); let conv = Conversation::new_unvalidated(vec![ Message::user().with_text("Hello"), Message::assistant().with_text("Hi"), Message::user().with_text("Bye"), ]); - let result = inject_moim(conv, &em).await; + let result = inject_moim("test-session-id", conv, &em).await; let msgs = result.messages(); assert_eq!(msgs.len(), 3); @@ -74,10 +76,11 @@ mod tests { #[tokio::test] async fn test_moim_injection_no_assistant() { - let em = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); let conv = Conversation::new_unvalidated(vec![Message::user().with_text("Hello")]); - let result = inject_moim(conv, &em).await; + let result = inject_moim("test-session-id", conv, &em).await; assert_eq!(result.messages().len(), 1); @@ -93,7 +96,8 @@ mod tests { #[tokio::test] async fn test_moim_with_tool_calls() { - let em = ExtensionManager::new_without_provider(); + let temp_dir = tempfile::tempdir().unwrap(); + let em = ExtensionManager::new_without_provider(temp_dir.path().to_path_buf()); let conv = Conversation::new_unvalidated(vec![ Message::user().with_text("Search for something"), @@ -135,7 +139,7 @@ mod tests { ), ]); - let result = inject_moim(conv, &em).await; + let result = inject_moim("test-session-id", conv, &em).await; let msgs = result.messages(); assert_eq!(msgs.len(), 6); diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index e98bcf145c2f..7e334f48c610 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -17,7 +17,6 @@ use crate::providers::toolshim::{ }; use crate::agents::code_execution_extension::EXTENSION_NAME as CODE_EXECUTION_EXTENSION; -use crate::session::SessionManager; #[cfg(test)] use crate::session::SessionType; use rmcp::model::Tool; @@ -111,10 +110,11 @@ async fn toolshim_postprocess( impl Agent { pub async fn prepare_tools_and_prompt( &self, + session_id: &str, working_dir: &std::path::Path, ) -> Result<(Vec, Vec, String)> { // Get tools from extension manager - let mut tools = self.list_tools(None).await; + let mut tools = self.list_tools(session_id, None).await; // Add frontend tools let frontend_tools = self.frontend_tools.lock().await; @@ -151,7 +151,7 @@ impl Agent { .with_extension_and_tool_counts(extension_count, tool_count) .with_code_execution_mode(code_execution_active) .with_hints(working_dir) - .with_enable_subagents(self.subagents_enabled().await) + .with_enable_subagents(self.subagents_enabled(session_id).await) .build(); // Handle toolshim if enabled @@ -345,12 +345,14 @@ impl Agent { } pub(crate) async fn update_session_metrics( + &self, session_config: &crate::agents::types::SessionConfig, usage: &ProviderUsage, is_compaction_usage: bool, ) -> Result<()> { let session_id = session_config.id.as_str(); - let session = SessionManager::get_session(session_id, false).await?; + let manager = self.session_manager(); + let session = manager.get_session(session_id, false).await?; let accumulate = |a: Option, b: Option| -> Option { match (a, b) { @@ -378,7 +380,8 @@ impl Agent { ) }; - SessionManager::update_session(session_id) + manager + .update(session_id) .schedule_id(session_config.schedule_id.clone()) .total_tokens(current_total) .input_tokens(current_input) @@ -440,12 +443,14 @@ mod tests { async fn prepare_tools_sorts_and_includes_frontend_and_list_tools() -> anyhow::Result<()> { let agent = crate::agents::Agent::new(); - let session = SessionManager::create_session( - std::path::PathBuf::default(), - "test-prepare-tools".to_string(), - SessionType::Hidden, - ) - .await?; + let session = agent + .session_manager() + .create_session( + std::path::PathBuf::default(), + "test-prepare-tools".to_string(), + SessionType::Hidden, + ) + .await?; let model_config = ModelConfig::new("test-model").unwrap(); let provider = std::sync::Arc::new(MockProvider { model_config }); @@ -478,8 +483,9 @@ mod tests { .unwrap(); let working_dir = std::env::current_dir()?; - let (tools, _toolshim_tools, _system_prompt) = - agent.prepare_tools_and_prompt(&working_dir).await?; + let (tools, _toolshim_tools, _system_prompt) = agent + .prepare_tools_and_prompt(&session.id, &working_dir) + .await?; // Ensure both platform and frontend tools are present let names: Vec = tools.iter().map(|t| t.name.clone().into_owned()).collect(); diff --git a/crates/goose/src/agents/schedule_tool.rs b/crates/goose/src/agents/schedule_tool.rs index fe23265bf6c8..3c83570bbdd8 100644 --- a/crates/goose/src/agents/schedule_tool.rs +++ b/crates/goose/src/agents/schedule_tool.rs @@ -20,16 +20,7 @@ impl Agent { arguments: serde_json::Value, _request_id: String, ) -> ToolResult> { - let scheduler = match self.scheduler_service.lock().await.as_ref() { - Some(s) => s.clone(), - None => { - return Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - "Scheduler not available. This tool only works in server mode.".to_string(), - None, - )) - } - }; + let scheduler = Arc::clone(&self.config.scheduler); let action = arguments .get("action") @@ -437,7 +428,7 @@ impl Agent { ) })?; - let session = match crate::session::SessionManager::get_session(session_id, true).await { + let session = match self.session_manager().get_session(session_id, true).await { Ok(metadata) => metadata, Err(e) => { return Err(ErrorData::new( diff --git a/crates/goose/src/agents/skills_extension.rs b/crates/goose/src/agents/skills_extension.rs index 633dba906835..892d09696dc3 100644 --- a/crates/goose/src/agents/skills_extension.rs +++ b/crates/goose/src/agents/skills_extension.rs @@ -5,16 +5,13 @@ use anyhow::Result; use async_trait::async_trait; use indoc::indoc; use rmcp::model::{ - CallToolResult, Content, GetPromptResult, Implementation, InitializeResult, JsonObject, - ListPromptsResult, ListResourcesResult, ListToolsResult, ProtocolVersion, ReadResourceResult, - ServerCapabilities, ServerNotification, Tool, ToolAnnotations, ToolsCapability, + CallToolResult, Content, Implementation, InitializeResult, JsonObject, ListToolsResult, + ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, ToolsCapability, }; use schemars::{schema_for, JsonSchema}; use serde::{Deserialize, Serialize}; -use serde_json::Value; use std::collections::HashMap; use std::path::{Path, PathBuf}; -use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; pub static EXTENSION_NAME: &str = "skills"; @@ -263,22 +260,6 @@ impl SkillsClient { #[async_trait] impl McpClientTrait for SkillsClient { - async fn list_resources( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn read_resource( - &self, - _uri: &str, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - async fn list_tools( &self, _next_cursor: Option, @@ -300,6 +281,7 @@ impl McpClientTrait for SkillsClient { &self, name: &str, arguments: Option, + _session_id: &str, _cancellation_token: CancellationToken, ) -> Result { let content = match name { @@ -316,27 +298,6 @@ impl McpClientTrait for SkillsClient { } } - async fn list_prompts( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn subscribe(&self) -> mpsc::Receiver { - mpsc::channel(1).1 - } - fn get_info(&self) -> Option<&InitializeResult> { Some(&self.info) } diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index b54d747db199..631a08a5b962 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,7 +1,6 @@ use crate::{ - agents::{subagent_task_config::TaskConfig, AgentEvent, SessionConfig}, + agents::{subagent_task_config::TaskConfig, Agent, AgentConfig, AgentEvent, SessionConfig}, conversation::{message::Message, Conversation}, - execution::manager::AgentManager, prompt_template::render_global_file, recipe::Recipe, }; @@ -11,6 +10,7 @@ use rmcp::model::{ErrorCode, ErrorData}; use serde::Serialize; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use tokio_util::sync::CancellationToken; use tracing::{debug, info}; @@ -28,22 +28,28 @@ type AgentMessagesFuture = /// Standalone function to run a complete subagent task with output options pub async fn run_complete_subagent_task( + agent_config: AgentConfig, recipe: Recipe, task_config: TaskConfig, return_last_only: bool, session_id: String, cancellation_token: Option, ) -> Result { - let (messages, final_output) = - get_agent_messages(recipe, task_config, session_id, cancellation_token) - .await - .map_err(|e| { - ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to execute task: {}", e), - None, - ) - })?; + let (messages, final_output) = get_agent_messages( + agent_config, + recipe, + task_config, + session_id, + cancellation_token, + ) + .await + .map_err(|e| { + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Failed to execute task: {}", e), + None, + ) + })?; if let Some(output) = final_output { return Ok(output); @@ -111,6 +117,7 @@ pub async fn run_complete_subagent_task( } fn get_agent_messages( + agent_config: AgentConfig, recipe: Recipe, task_config: TaskConfig, session_id: String, @@ -123,14 +130,7 @@ fn get_agent_messages( .clone() .unwrap_or_else(|| "Begin.".to_string()); - let agent_manager = AgentManager::instance() - .await - .map_err(|e| anyhow!("Failed to create AgentManager: {}", e))?; - - let agent = agent_manager - .get_or_create_agent(session_id.clone()) - .await - .map_err(|e| anyhow!("Failed to get sub agent session file path: {}", e))?; + let agent = Arc::new(Agent::with_config(agent_config)); agent .update_provider(task_config.provider, &session_id) @@ -152,7 +152,7 @@ fn get_agent_messages( .apply_recipe_components(recipe.sub_recipes.clone(), recipe.response.clone(), true) .await; - let tools = agent.list_tools(None).await; + let tools = agent.list_tools(&session_id, None).await; let subagent_prompt = render_global_file( "subagent_system.md", &SubagentPromptContext { diff --git a/crates/goose/src/agents/subagent_tool.rs b/crates/goose/src/agents/subagent_tool.rs index 9a56122e7135..f625406cf7a5 100644 --- a/crates/goose/src/agents/subagent_tool.rs +++ b/crates/goose/src/agents/subagent_tool.rs @@ -12,11 +12,11 @@ use tokio_util::sync::CancellationToken; use crate::agents::subagent_handler::run_complete_subagent_task; use crate::agents::subagent_task_config::TaskConfig; use crate::agents::tool_execution::ToolCallResult; +use crate::agents::AgentConfig; use crate::providers; use crate::recipe::build_recipe::build_recipe_from_template; use crate::recipe::local_recipes::load_local_recipe_file; use crate::recipe::{Recipe, SubRecipe}; -use crate::session::SessionManager; pub const SUBAGENT_TOOL_NAME: &str = "subagent"; @@ -176,6 +176,7 @@ fn get_subrecipe_params_description(sub_recipe: &SubRecipe) -> String { /// (e.g., "[run sequentially, not in parallel]") but not enforced. The LLM controls /// sequencing by making sequential vs parallel tool calls. pub fn handle_subagent_tool( + agent_config: &AgentConfig, params: Value, task_config: TaskConfig, sub_recipes: HashMap, @@ -220,10 +221,12 @@ pub fn handle_subagent_tool( } }; + let agent_config = agent_config.clone(); ToolCallResult { notification_stream: None, result: Box::new( execute_subagent( + agent_config, recipe, task_config, parsed_params, @@ -236,23 +239,26 @@ pub fn handle_subagent_tool( } async fn execute_subagent( + agent_config: AgentConfig, recipe: Recipe, task_config: TaskConfig, params: SubagentParams, working_dir: PathBuf, cancellation_token: Option, ) -> Result { - let session = SessionManager::create_session( - working_dir, - "Subagent task".to_string(), - crate::session::session_manager::SessionType::SubAgent, - ) - .await - .map_err(|e| ErrorData { - code: ErrorCode::INTERNAL_ERROR, - message: Cow::from(format!("Failed to create session: {}", e)), - data: None, - })?; + let session = agent_config + .session_manager + .create_session( + working_dir, + "Subagent task".to_string(), + crate::session::session_manager::SessionType::SubAgent, + ) + .await + .map_err(|e| ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: Cow::from(format!("Failed to create session: {}", e)), + data: None, + })?; let task_config = apply_settings_overrides(task_config, ¶ms) .await @@ -263,6 +269,7 @@ async fn execute_subagent( })?; let result = run_complete_subagent_task( + agent_config, recipe, task_config, params.summary, diff --git a/crates/goose/src/agents/todo_extension.rs b/crates/goose/src/agents/todo_extension.rs index 4b1c0b686176..89046cd2a603 100644 --- a/crates/goose/src/agents/todo_extension.rs +++ b/crates/goose/src/agents/todo_extension.rs @@ -1,19 +1,16 @@ use crate::agents::extension::PlatformExtensionContext; use crate::agents::mcp_client::{Error, McpClientTrait}; +use crate::session::extension_data; use crate::session::extension_data::ExtensionState; -use crate::session::{extension_data, SessionManager}; use anyhow::Result; use async_trait::async_trait; use indoc::indoc; use rmcp::model::{ - CallToolResult, Content, GetPromptResult, Implementation, InitializeResult, JsonObject, - ListPromptsResult, ListResourcesResult, ListToolsResult, ProtocolVersion, ReadResourceResult, - ServerCapabilities, ServerNotification, Tool, ToolAnnotations, ToolsCapability, + CallToolResult, Content, Implementation, InitializeResult, JsonObject, ListToolsResult, + ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, ToolsCapability, }; use schemars::{schema_for, JsonSchema}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; pub static EXTENSION_NAME: &str = "todo"; @@ -26,7 +23,6 @@ struct TodoWriteParams { pub struct TodoClient { info: InitializeResult, context: PlatformExtensionContext, - fallback_content: tokio::sync::RwLock, } impl TodoClient { @@ -70,15 +66,12 @@ impl TodoClient { ), }; - Ok(Self { - info, - context, - fallback_content: tokio::sync::RwLock::new(String::new()), - }) + Ok(Self { info, context }) } async fn handle_write_todo( &self, + session_id: &str, arguments: Option, ) -> Result, String> { let content = arguments @@ -102,38 +95,31 @@ impl TodoClient { )); } - if let Some(session_id) = &self.context.session_id { - match SessionManager::get_session(session_id, false).await { - Ok(mut session) => { - let todo_state = extension_data::TodoState::new(content); - if todo_state - .to_extension_data(&mut session.extension_data) - .is_ok() + let manager = &self.context.session_manager; + match manager.get_session(session_id, false).await { + Ok(mut session) => { + let todo_state = extension_data::TodoState::new(content); + if todo_state + .to_extension_data(&mut session.extension_data) + .is_ok() + { + match manager + .update(session_id) + .extension_data(session.extension_data) + .apply() + .await { - match SessionManager::update_session(session_id) - .extension_data(session.extension_data) - .apply() - .await - { - Ok(_) => Ok(vec![Content::text(format!( - "Updated ({} chars)", - char_count - ))]), - Err(_) => Err("Failed to update session metadata".to_string()), - } - } else { - Err("Failed to serialize TODO state".to_string()) + Ok(_) => Ok(vec![Content::text(format!( + "Updated ({} chars)", + char_count + ))]), + Err(_) => Err("Failed to update session metadata".to_string()), } + } else { + Err("Failed to serialize TODO state".to_string()) } - Err(_) => Err("Failed to read session metadata".to_string()), } - } else { - let mut fallback = self.fallback_content.write().await; - *fallback = content; - Ok(vec![Content::text(format!( - "Updated ({} chars)", - char_count - ))]) + Err(_) => Err("Failed to read session metadata".to_string()), } } @@ -169,22 +155,6 @@ impl TodoClient { #[async_trait] impl McpClientTrait for TodoClient { - async fn list_resources( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn read_resource( - &self, - _uri: &str, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - async fn list_tools( &self, _next_cursor: Option, @@ -201,10 +171,11 @@ impl McpClientTrait for TodoClient { &self, name: &str, arguments: Option, + session_id: &str, _cancellation_token: CancellationToken, ) -> Result { let content = match name { - "todo_write" => self.handle_write_todo(arguments).await, + "todo_write" => self.handle_write_todo(session_id, arguments).await, _ => Err(format!("Unknown tool: {}", name)), }; @@ -217,34 +188,17 @@ impl McpClientTrait for TodoClient { } } - async fn list_prompts( - &self, - _next_cursor: Option, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - _cancellation_token: CancellationToken, - ) -> Result { - Err(Error::TransportClosed) - } - - async fn subscribe(&self) -> mpsc::Receiver { - mpsc::channel(1).1 - } - fn get_info(&self) -> Option<&InitializeResult> { Some(&self.info) } - async fn get_moim(&self) -> Option { - let session_id = self.context.session_id.as_ref()?; - let metadata = SessionManager::get_session(session_id, false).await.ok()?; + async fn get_moim(&self, session_id: &str) -> Option { + let metadata = self + .context + .session_manager + .get_session(session_id, false) + .await + .ok()?; match extension_data::TodoState::from_extension_data(&metadata.extension_data) { Some(state) if !state.content.trim().is_empty() => { diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 5323561cd3f0..14ab3ec1153e 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -131,6 +131,12 @@ impl Agent { request.metadata.as_ref(), ); } + + if confirmation.permission == Permission::AlwaysDeny { + self.tool_inspection_manager + .update_permission_manager(&tool_call.name, PermissionLevel::NeverAllow) + .await; + } } break; // Exit the loop once the matching `req_id` is found } diff --git a/crates/goose/src/config/base.rs b/crates/goose/src/config/base.rs index 90e21a28d7e6..b83175d72c72 100644 --- a/crates/goose/src/config/base.rs +++ b/crates/goose/src/config/base.rs @@ -1007,7 +1007,6 @@ mod tests { } #[test] - #[serial] fn test_multiple_secrets() -> Result<(), ConfigError> { let config = new_test_config(); diff --git a/crates/goose/src/config/permission.rs b/crates/goose/src/config/permission.rs index b2f2b63b1997..556cc5cda1cb 100644 --- a/crates/goose/src/config/permission.rs +++ b/crates/goose/src/config/permission.rs @@ -3,8 +3,14 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock, RwLock}; use utoipa::ToSchema; +const PERMISSION_FILE: &str = "permission.yaml"; + +static PERMISSION_MANAGER: LazyLock> = + LazyLock::new(|| Arc::new(PermissionManager::new(Paths::config_dir()))); + /// Enum representing the possible permission levels for a tool. #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, ToSchema)] #[serde(rename_all = "snake_case")] @@ -23,62 +29,44 @@ pub struct PermissionConfig { } /// PermissionManager manages permission configurations for various tools. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct PermissionManager { - config_path: PathBuf, // Path to the permission configuration file - permission_map: HashMap, // Mapping of permission names to configurations + config_path: PathBuf, + permission_map: RwLock>, } // Constants representing specific permission categories const USER_PERMISSION: &str = "user"; const SMART_APPROVE_PERMISSION: &str = "smart_approve"; -/// Implements the default constructor for `PermissionManager`. -impl Default for PermissionManager { - fn default() -> Self { - let config_path = Paths::config_dir().join("permission.yaml"); - - // Load the existing configuration file or create an empty map if the file doesn't exist - let permission_map = if config_path.exists() { - // Load the configuration file - let file_contents = - fs::read_to_string(&config_path).expect("Failed to read permission.yaml"); +impl PermissionManager { + pub fn new(config_dir: PathBuf) -> Self { + let permission_path = config_dir.join(PERMISSION_FILE); + let permission_map = if permission_path.exists() { + let file_contents = fs::read_to_string(&permission_path).unwrap(); serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new()) } else { - HashMap::new() // No config file, create an empty map + fs::create_dir_all(&config_dir).unwrap(); + HashMap::new() }; - PermissionManager { - config_path, - permission_map, + config_path: permission_path, + permission_map: RwLock::new(permission_map), } } -} -impl PermissionManager { - /// Creates a new `PermissionManager` with a specified config path. - pub fn new>(config_path: P) -> Self { - let config_path = config_path.as_ref().to_path_buf(); - - // Load the existing configuration file or create an empty map if the file doesn't exist - let permission_map = if config_path.exists() { - // Load the configuration file - let file_contents = - fs::read_to_string(&config_path).expect("Failed to read permission.yaml"); - serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new()) - } else { - HashMap::new() // No config file, create an empty map - }; - - PermissionManager { - config_path, - permission_map, - } + pub fn instance() -> Arc { + Arc::clone(&PERMISSION_MANAGER) } /// Returns a list of all the names (keys) in the permission map. pub fn get_permission_names(&self) -> Vec { - self.permission_map.keys().cloned().collect() + self.permission_map + .read() + .unwrap() + .keys() + .cloned() + .collect() } /// Retrieves the user permission level for a specific tool. @@ -98,8 +86,9 @@ impl PermissionManager { /// Helper function to retrieve the permission level for a specific permission category and tool. fn get_permission(&self, name: &str, principal_name: &str) -> Option { + let map = self.permission_map.read().unwrap(); // Check if the permission category exists in the map - if let Some(permission_config) = self.permission_map.get(name) { + if let Some(permission_config) = map.get(name) { // Check the permission levels for the given tool if permission_config .always_allow @@ -122,23 +111,20 @@ impl PermissionManager { } /// Updates the user permission level for a specific tool. - pub fn update_user_permission(&mut self, principal_name: &str, level: PermissionLevel) { + pub fn update_user_permission(&self, principal_name: &str, level: PermissionLevel) { self.update_permission(USER_PERMISSION, principal_name, level) } /// Updates the smart approve permission level for a specific tool. - pub fn update_smart_approve_permission( - &mut self, - principal_name: &str, - level: PermissionLevel, - ) { + pub fn update_smart_approve_permission(&self, principal_name: &str, level: PermissionLevel) { self.update_permission(SMART_APPROVE_PERMISSION, principal_name, level) } /// Helper function to update a permission level for a specific tool in a given permission category. - fn update_permission(&mut self, name: &str, principal_name: &str, level: PermissionLevel) { + fn update_permission(&self, name: &str, principal_name: &str, level: PermissionLevel) { + let mut map = self.permission_map.write().unwrap(); // Get or create a new PermissionConfig for the specified category - let permission_config = self.permission_map.entry(name.to_string()).or_default(); + let permission_config = map.entry(name.to_string()).or_default(); // Remove the principal from all existing lists to avoid duplicates permission_config @@ -163,14 +149,14 @@ impl PermissionManager { } // Serialize the updated permission map and write it back to the config file - let yaml_content = serde_yaml::to_string(&self.permission_map) - .expect("Failed to serialize permission config"); - fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml"); + let yaml_content = serde_yaml::to_string(&*map).unwrap(); + fs::write(&self.config_path, yaml_content).unwrap(); } /// Removes all entries where the principal name starts with the given extension name. - pub fn remove_extension(&mut self, extension_name: &str) { - for permission_config in self.permission_map.values_mut() { + pub fn remove_extension(&self, extension_name: &str) { + let mut map = self.permission_map.write().unwrap(); + for permission_config in map.values_mut() { permission_config .always_allow .retain(|p| !p.starts_with(extension_name)); @@ -182,34 +168,33 @@ impl PermissionManager { .retain(|p| !p.starts_with(extension_name)); } - let yaml_content = serde_yaml::to_string(&self.permission_map) - .expect("Failed to serialize permission config"); - fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml"); + let yaml_content = serde_yaml::to_string(&*map).unwrap(); + fs::write(&self.config_path, yaml_content).unwrap(); } } #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; + use tempfile::TempDir; // Helper function to create a test instance of PermissionManager with a temp dir - fn create_test_permission_manager() -> PermissionManager { - let temp_file = NamedTempFile::new().unwrap(); - let temp_path = temp_file.path(); - PermissionManager::new(temp_path) + fn create_test_permission_manager() -> (PermissionManager, TempDir) { + let temp_dir = TempDir::new().unwrap(); + let manager = PermissionManager::new(temp_dir.path().to_path_buf()); + (manager, temp_dir) } #[test] fn test_get_permission_names_empty() { - let manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); assert!(manager.get_permission_names().is_empty()); } #[test] fn test_update_user_permission() { - let mut manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); manager.update_user_permission("tool1", PermissionLevel::AlwaysAllow); let permission = manager.get_user_permission("tool1"); @@ -218,7 +203,7 @@ mod tests { #[test] fn test_update_smart_approve_permission() { - let mut manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); manager.update_smart_approve_permission("tool2", PermissionLevel::AskBefore); let permission = manager.get_smart_approve_permission("tool2"); @@ -227,7 +212,7 @@ mod tests { #[test] fn test_get_permission_not_found() { - let manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); let permission = manager.get_user_permission("non_existent_tool"); assert_eq!(permission, None); @@ -235,7 +220,7 @@ mod tests { #[test] fn test_permission_levels() { - let mut manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); manager.update_user_permission("tool4", PermissionLevel::AlwaysAllow); manager.update_user_permission("tool5", PermissionLevel::AskBefore); @@ -258,7 +243,7 @@ mod tests { #[test] fn test_permission_update_replaces_existing_level() { - let mut manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); // Initially AlwaysAllow manager.update_user_permission("tool7", PermissionLevel::AlwaysAllow); @@ -275,7 +260,8 @@ mod tests { ); // Ensure it's removed from other levels - let config = manager.permission_map.get(USER_PERMISSION).unwrap(); + let map = manager.permission_map.read().unwrap(); + let config = map.get(USER_PERMISSION).unwrap(); assert!(!config.always_allow.contains(&"tool7".to_string())); assert!(!config.ask_before.contains(&"tool7".to_string())); assert!(config.never_allow.contains(&"tool7".to_string())); @@ -283,7 +269,7 @@ mod tests { #[test] fn test_remove_extension() { - let mut manager = create_test_permission_manager(); + let (manager, _temp_dir) = create_test_permission_manager(); manager.update_user_permission("prefix__tool1", PermissionLevel::AlwaysAllow); manager.update_user_permission("nonprefix__tool2", PermissionLevel::AlwaysAllow); manager.update_user_permission("prefix__tool3", PermissionLevel::AskBefore); @@ -291,7 +277,8 @@ mod tests { // Remove entries starting with "prefix" manager.remove_extension("prefix"); - let config = manager.permission_map.get(USER_PERMISSION).unwrap(); + let map = manager.permission_map.read().unwrap(); + let config = map.get(USER_PERMISSION).unwrap(); // Verify entries with "prefix" are removed assert!(!config.always_allow.contains(&"prefix__tool1".to_string())); diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index 99d9b3d0d29c..a6cc42e38778 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -1,9 +1,10 @@ -use crate::agents::extension::PlatformExtensionContext; -use crate::agents::Agent; +use crate::agents::{Agent, AgentConfig}; use crate::config::paths::Paths; -use crate::config::Config; +use crate::config::permission::PermissionManager; +use crate::config::{Config, GooseMode}; use crate::scheduler::Scheduler; use crate::scheduler_trait::SchedulerTrait; +use crate::session::SessionManager; use anyhow::Result; use lru::LruCache; use std::num::NonZeroUsize; @@ -18,25 +19,17 @@ static AGENT_MANAGER: OnceCell> = OnceCell::const_new(); pub struct AgentManager { sessions: Arc>>>, scheduler: Arc, + session_manager: Arc, default_provider: Arc>>>, } impl AgentManager { - #[cfg(test)] - pub fn reset_for_test() { - unsafe { - // Cast away the const to get mutable access - // This is safe in test context where we control execution with #[serial] - let cell_ptr = &AGENT_MANAGER as *const OnceCell> - as *mut OnceCell>; - let _ = (*cell_ptr).take(); - } - } - - async fn new(max_sessions: Option) -> Result { - let schedule_file_path = Paths::data_dir().join("schedule.json"); - - let scheduler = Scheduler::new(schedule_file_path).await?; + pub async fn new( + session_manager: Arc, + schedule_file_path: std::path::PathBuf, + max_sessions: Option, + ) -> Result { + let scheduler = Scheduler::new(schedule_file_path, session_manager.clone()).await?; let capacity = NonZeroUsize::new(max_sessions.unwrap_or(DEFAULT_MAX_SESSION)) .unwrap_or_else(|| NonZeroUsize::new(100).unwrap()); @@ -44,6 +37,7 @@ impl AgentManager { let manager = Self { sessions: Arc::new(RwLock::new(LruCache::new(capacity))), scheduler, + session_manager, default_provider: Arc::new(RwLock::new(None)), }; @@ -56,7 +50,10 @@ impl AgentManager { let max_sessions = Config::global() .get_goose_max_active_agents() .unwrap_or(DEFAULT_MAX_SESSION); - let manager = Self::new(Some(max_sessions)).await?; + let schedule_file_path = Paths::data_dir().join("schedule.json"); + let session_manager = Arc::new(SessionManager::instance()); + let manager = + Self::new(session_manager, schedule_file_path, Some(max_sessions)).await?; Ok(Arc::new(manager)) }) .await @@ -67,6 +64,11 @@ impl AgentManager { Arc::clone(&self.scheduler) } + /// Get the shared SessionManager for session-only operations + pub fn session_manager(&self) -> &SessionManager { + &self.session_manager + } + pub async fn set_default_provider(&self, provider: Arc) { debug!("Setting default provider on AgentManager"); *self.default_provider.write().await = Some(provider); @@ -80,15 +82,15 @@ impl AgentManager { } } - let agent = Arc::new(Agent::new()); - agent.set_scheduler(Arc::clone(&self.scheduler)).await; - agent - .extension_manager - .set_context(PlatformExtensionContext { - session_id: Some(session_id.clone()), - extension_manager: Some(Arc::downgrade(&agent.extension_manager)), - }) - .await; + let mode = Config::global().get_goose_mode().unwrap_or(GooseMode::Auto); + let permission_manager = PermissionManager::instance(); + let config = AgentConfig::new( + Arc::clone(&self.session_manager), + permission_manager, + mode, + Arc::clone(&self.scheduler), + ); + let agent = Arc::new(Agent::with_config(config)); if let Some(provider) = &*self.default_provider.read().await { agent .update_provider(Arc::clone(provider), &session_id) @@ -124,10 +126,21 @@ impl AgentManager { #[cfg(test)] mod tests { - use serial_test::serial; use std::sync::Arc; + use tempfile::TempDir; + + use crate::execution::SessionExecutionMode; + use crate::session::SessionManager; - use crate::execution::{manager::AgentManager, SessionExecutionMode}; + use super::AgentManager; + + async fn create_test_manager(temp_dir: &TempDir) -> AgentManager { + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let schedule_path = temp_dir.path().join("schedule.json"); + AgentManager::new(session_manager, schedule_path, Some(100)) + .await + .unwrap() + } #[test] fn test_execution_mode_constructors() { @@ -150,10 +163,9 @@ mod tests { } #[tokio::test] - #[serial] async fn test_session_isolation() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; let session1 = uuid::Uuid::new_v4().to_string(); let session2 = uuid::Uuid::new_v4().to_string(); @@ -169,15 +181,12 @@ mod tests { let agent1_again = manager.get_or_create_agent(session1).await.unwrap(); assert!(Arc::ptr_eq(&agent1, &agent1_again)); - - AgentManager::reset_for_test(); } #[tokio::test] - #[serial] async fn test_session_limit() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); @@ -193,10 +202,9 @@ mod tests { } #[tokio::test] - #[serial] async fn test_remove_session() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; let session = String::from("remove-test"); manager.get_or_create_agent(session.clone()).await.unwrap(); @@ -209,10 +217,9 @@ mod tests { } #[tokio::test] - #[serial] async fn test_concurrent_access() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = Arc::new(create_test_manager(&temp_dir).await); let session = String::from("concurrent-test"); let mut handles = vec![]; @@ -238,12 +245,11 @@ mod tests { } #[tokio::test] - #[serial] async fn test_concurrent_session_creation_race_condition() { // Test that concurrent attempts to create the same new session ID // result in only one agent being created (tests double-check pattern) - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = Arc::new(create_test_manager(&temp_dir).await); let session_id = String::from("race-condition-test"); // Spawn multiple tasks trying to create the same NEW session simultaneously @@ -273,24 +279,18 @@ mod tests { } #[tokio::test] - #[serial] async fn test_set_default_provider() { use crate::providers::testprovider::TestProvider; - use std::sync::Arc; - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; // Create a test provider for replaying (doesn't need inner provider) - let temp_file = format!( - "{}/test_provider_{}.json", - std::env::temp_dir().display(), - std::process::id() - ); + let temp_file = temp_dir.path().join("test_provider.json"); // Create an empty test provider (will fail on actual use but that's ok for this test) - let test_provider = TestProvider::new_replaying(&temp_file) - .unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap()); + std::fs::write(&temp_file, "{}").unwrap(); + let test_provider = TestProvider::new_replaying(temp_file.to_str().unwrap()).unwrap(); manager.set_default_provider(Arc::new(test_provider)).await; @@ -301,12 +301,11 @@ mod tests { } #[tokio::test] - #[serial] async fn test_eviction_updates_last_used() { - AgentManager::reset_for_test(); // Test that accessing a session updates its last_used timestamp // and affects eviction order - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); @@ -336,11 +335,10 @@ mod tests { } #[tokio::test] - #[serial] async fn test_remove_nonexistent_session_error() { // Test that removing a non-existent session returns an error - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); + let temp_dir = TempDir::new().unwrap(); + let manager = create_test_manager(&temp_dir).await; let session = String::from("never-created"); let result = manager.remove_session(&session).await; diff --git a/crates/goose/src/hints/load_hints.rs b/crates/goose/src/hints/load_hints.rs index ec4b4375f5a7..be7e2d7ee09b 100644 --- a/crates/goose/src/hints/load_hints.rs +++ b/crates/goose/src/hints/load_hints.rs @@ -123,7 +123,7 @@ pub fn load_hint_files( mod tests { use super::*; use ignore::gitignore::GitignoreBuilder; - use std::fs::{self}; + use std::fs; use tempfile::TempDir; fn create_dummy_gitignore() -> Gitignore { diff --git a/crates/goose/src/permission/permission_confirmation.rs b/crates/goose/src/permission/permission_confirmation.rs index f56da1172000..59e3a8fefcdb 100644 --- a/crates/goose/src/permission/permission_confirmation.rs +++ b/crates/goose/src/permission/permission_confirmation.rs @@ -7,6 +7,7 @@ pub enum Permission { AllowOnce, Cancel, DenyOnce, + AlwaysDeny, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, ToSchema)] diff --git a/crates/goose/src/permission/permission_inspector.rs b/crates/goose/src/permission/permission_inspector.rs index 88b99a6b05d9..c105c87081dc 100644 --- a/crates/goose/src/permission/permission_inspector.rs +++ b/crates/goose/src/permission/permission_inspector.rs @@ -8,50 +8,27 @@ use anyhow::Result; use async_trait::async_trait; use std::collections::HashSet; use std::sync::Arc; -use tokio::sync::Mutex; /// Permission Inspector that handles tool permission checking pub struct PermissionInspector { - mode: Arc>, readonly_tools: HashSet, regular_tools: HashSet, - pub permission_manager: Arc>, + pub permission_manager: Arc, } impl PermissionInspector { pub fn new( - mode: GooseMode, readonly_tools: HashSet, regular_tools: HashSet, + permission_manager: Arc, ) -> Self { Self { - mode: Arc::new(Mutex::new(mode)), - readonly_tools, - regular_tools, - permission_manager: Arc::new(Mutex::new(PermissionManager::default())), - } - } - - pub fn with_permission_manager( - mode: GooseMode, - readonly_tools: HashSet, - regular_tools: HashSet, - permission_manager: Arc>, - ) -> Self { - Self { - mode: Arc::new(Mutex::new(mode)), readonly_tools, regular_tools, permission_manager, } } - /// Update the mode of this permission inspector - pub async fn update_mode(&self, new_mode: GooseMode) { - let mut mode = self.mode.lock().await; - *mode = new_mode; - } - /// Process inspection results into permission decisions /// This method takes all inspection results and converts them into a PermissionCheckResult /// that can be used by the agent to determine which tools to approve, deny, or ask for approval @@ -130,16 +107,16 @@ impl ToolInspector for PermissionInspector { &self, tool_requests: &[ToolRequest], _messages: &[Message], + goose_mode: GooseMode, ) -> Result> { let mut results = Vec::new(); - let permission_manager = self.permission_manager.lock().await; - let mode = self.mode.lock().await; + let permission_manager = &self.permission_manager; for request in tool_requests { if let Ok(tool_call) = &request.tool_call { let tool_name = &tool_call.name; - let action = match *mode { + let action = match goose_mode { GooseMode::Chat => continue, GooseMode::Auto => InspectionAction::Allow, GooseMode::Approve | GooseMode::SmartApprove => { @@ -174,7 +151,7 @@ impl ToolInspector for PermissionInspector { let reason = match &action { InspectionAction::Allow => { - if *mode == GooseMode::Auto { + if goose_mode == GooseMode::Auto { "Auto mode - all tools approved".to_string() } else if self.readonly_tools.contains(tool_name.as_ref()) { "Tool marked as read-only".to_string() diff --git a/crates/goose/src/posthog.rs b/crates/goose/src/posthog.rs index 540a4868b662..8584173a4478 100644 --- a/crates/goose/src/posthog.rs +++ b/crates/goose/src/posthog.rs @@ -395,7 +395,8 @@ async fn send_session_event(installation: &InstallationData) -> Result<(), Strin .insert_prop("db_schema_version", CURRENT_SCHEMA_VERSION) .ok(); - if let Ok(insights) = SessionManager::get_insights().await { + let session_manager = SessionManager::instance(); + if let Ok(insights) = session_manager.get_insights().await { event .insert_prop("total_sessions", insights.total_sessions) .ok(); diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 2a21a9484da2..eb31f1559776 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -135,10 +135,14 @@ pub struct Scheduler { jobs: Arc>, storage_path: PathBuf, running_tasks: Arc>, + session_manager: Arc, } impl Scheduler { - pub async fn new(storage_path: PathBuf) -> Result, SchedulerError> { + pub async fn new( + storage_path: PathBuf, + session_manager: Arc, + ) -> Result, SchedulerError> { let internal_scheduler = TokioJobScheduler::new() .await .map_err(|e| SchedulerError::SchedulerInternalError(e.to_string()))?; @@ -151,6 +155,7 @@ impl Scheduler { jobs, storage_path, running_tasks, + session_manager, }); arc_self.load_jobs_from_storage().await; @@ -498,7 +503,9 @@ impl Scheduler { sched_id: &str, limit: usize, ) -> Result, SchedulerError> { - let all_sessions = SessionManager::list_sessions() + let all_sessions = self + .session_manager + .list_sessions() .await .map_err(|e| SchedulerError::StorageError(io::Error::other(e)))?; @@ -740,12 +747,14 @@ async fn execute_job( } } - let session = SessionManager::create_session( - std::env::current_dir()?, - format!("Scheduled job: {}", job.id), - SessionType::Scheduled, - ) - .await?; + let session = agent + .session_manager() + .create_session( + std::env::current_dir()?, + format!("Scheduled job: {}", job.id), + SessionType::Scheduled, + ) + .await?; agent.update_provider(agent_provider, &session.id).await?; @@ -812,7 +821,9 @@ async fn execute_job( } } - SessionManager::update_session(&session.id) + agent + .session_manager() + .update(&session.id) .schedule_id(Some(job.id.clone())) .recipe(Some(recipe)) .apply() @@ -928,7 +939,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let storage_path = temp_dir.path().join("schedules.json"); let recipe_path = create_test_recipe(temp_dir.path(), "scheduled_job"); - let scheduler = Scheduler::new(storage_path).await.unwrap(); + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let scheduler = Scheduler::new(storage_path, session_manager).await.unwrap(); let job = ScheduledJob { id: "scheduled_job".to_string(), @@ -953,7 +965,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let storage_path = temp_dir.path().join("schedules.json"); let recipe_path = create_test_recipe(temp_dir.path(), "paused_job"); - let scheduler = Scheduler::new(storage_path).await.unwrap(); + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let scheduler = Scheduler::new(storage_path, session_manager).await.unwrap(); let job = ScheduledJob { id: "paused_job".to_string(), diff --git a/crates/goose/src/scheduler_trait.rs b/crates/goose/src/scheduler_trait.rs index 8122cab7f28f..41f11ee19967 100644 --- a/crates/goose/src/scheduler_trait.rs +++ b/crates/goose/src/scheduler_trait.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use std::path::PathBuf; +use std::sync::{Arc, LazyLock}; use crate::scheduler::{ScheduledJob, SchedulerError}; use crate::session::Session; @@ -39,3 +40,104 @@ pub trait SchedulerTrait: Send + Sync { sched_id: &str, ) -> Result)>, SchedulerError>; } + +const UNAVAILABLE_MESSAGE: &str = "Scheduler not available. This tool only works in server mode."; + +pub struct UnavailableScheduler; + +#[async_trait] +impl SchedulerTrait for UnavailableScheduler { + async fn add_scheduled_job( + &self, + _job: ScheduledJob, + _copy_recipe: bool, + ) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn schedule_recipe( + &self, + _recipe_path: PathBuf, + _cron_schedule: Option, + ) -> anyhow::Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn list_scheduled_jobs(&self) -> Vec { + vec![] + } + + async fn remove_scheduled_job( + &self, + _id: &str, + _remove_recipe: bool, + ) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn pause_schedule(&self, _id: &str) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn unpause_schedule(&self, _id: &str) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn run_now(&self, _id: &str) -> Result { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn sessions( + &self, + _sched_id: &str, + _limit: usize, + ) -> Result, SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn update_schedule( + &self, + _sched_id: &str, + _new_cron: String, + ) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn kill_running_job(&self, _sched_id: &str) -> Result<(), SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } + + async fn get_running_job_info( + &self, + _sched_id: &str, + ) -> Result)>, SchedulerError> { + Err(SchedulerError::SchedulerInternalError( + UNAVAILABLE_MESSAGE.into(), + )) + } +} + +static UNAVAILABLE_SCHEDULER: LazyLock> = + LazyLock::new(|| Arc::new(UnavailableScheduler)); + +pub fn unavailable_scheduler() -> Arc { + Arc::clone(&UNAVAILABLE_SCHEDULER) +} diff --git a/crates/goose/src/security/security_inspector.rs b/crates/goose/src/security/security_inspector.rs index 3fb601d0d0fb..5e7ef6962309 100644 --- a/crates/goose/src/security/security_inspector.rs +++ b/crates/goose/src/security/security_inspector.rs @@ -1,6 +1,7 @@ use anyhow::Result; use async_trait::async_trait; +use crate::config::GooseMode; use crate::conversation::message::{Message, ToolRequest}; use crate::security::{SecurityManager, SecurityResult}; use crate::tool_inspection::{InspectionAction, InspectionResult, ToolInspector}; @@ -64,6 +65,7 @@ impl ToolInspector for SecurityInspector { &self, tool_requests: &[ToolRequest], messages: &[Message], + _goose_mode: GooseMode, ) -> Result> { let security_results = self .security_manager @@ -117,7 +119,10 @@ mod tests { tool_meta: None, }]; - let results = inspector.inspect(&tool_requests, &[]).await.unwrap(); + let results = inspector + .inspect(&tool_requests, &[], GooseMode::Approve) + .await + .unwrap(); // Results depend on whether security is enabled in config if inspector.is_enabled() { diff --git a/crates/goose/src/session/diagnostics.rs b/crates/goose/src/session/diagnostics.rs index 4c311eb2860e..59e02d42e9b2 100644 --- a/crates/goose/src/session/diagnostics.rs +++ b/crates/goose/src/session/diagnostics.rs @@ -1,13 +1,16 @@ use crate::config::paths::Paths; use crate::providers::utils::LOGS_TO_KEEP; use crate::session::SessionManager; -use std::fs::{self}; +use std::fs; use std::io::Cursor; use std::io::Write; use zip::write::FileOptions; use zip::ZipWriter; -pub async fn generate_diagnostics(session_id: &str) -> anyhow::Result> { +pub async fn generate_diagnostics( + session_manager: &SessionManager, + session_id: &str, +) -> anyhow::Result> { let logs_dir = Paths::in_state_dir("logs"); let config_dir = Paths::config_dir(); let config_path = config_dir.join("config.yaml"); @@ -45,7 +48,7 @@ pub async fn generate_diagnostics(session_id: &str) -> anyhow::Result> { zip.write_all(&fs::read(&path)?)?; } - let session_data = SessionManager::export_session(session_id).await?; + let session_data = session_manager.export_session(session_id).await?; zip.start_file("session.json", options)?; zip.write_all(session_data.as_bytes())?; diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs index b8b7c8d5a28c..e67527904893 100644 --- a/crates/goose/src/session/mod.rs +++ b/crates/goose/src/session/mod.rs @@ -6,4 +6,6 @@ pub mod session_manager; pub use diagnostics::generate_diagnostics; pub use extension_data::{EnabledExtensionsState, ExtensionData, ExtensionState, TodoState}; -pub use session_manager::{Session, SessionInsights, SessionManager, SessionType}; +pub use session_manager::{ + Session, SessionInsights, SessionManager, SessionType, SessionUpdateBuilder, +}; diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index ecfbe7692ff2..77e4e5ca1035 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -9,13 +9,12 @@ use anyhow::Result; use chrono::{DateTime, Utc}; use rmcp::model::Role; use serde::{Deserialize, Serialize}; -use sqlx::sqlite::SqliteConnectOptions; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::{Pool, Sqlite}; use std::collections::HashMap; use std::fs; use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tokio::sync::OnceCell; +use std::sync::{Arc, LazyLock}; use tracing::{info, warn}; use utoipa::ToSchema; @@ -61,7 +60,8 @@ impl std::str::FromStr for SessionType { } } -static SESSION_STORAGE: OnceCell> = OnceCell::const_new(); +static SESSION_STORAGE: LazyLock> = + LazyLock::new(|| Arc::new(SessionStorage::new(Paths::data_dir()))); #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct Session { @@ -92,7 +92,8 @@ pub struct Session { pub model_config: Option, } -pub struct SessionUpdateBuilder { +pub struct SessionUpdateBuilder<'a> { + session_manager: &'a SessionManager, session_id: String, name: Option, user_set_name: Option, @@ -119,9 +120,10 @@ pub struct SessionInsights { pub total_tokens: i64, } -impl SessionUpdateBuilder { - fn new(session_id: String) -> Self { +impl<'a> SessionUpdateBuilder<'a> { + fn new(session_manager: &'a SessionManager, session_id: String) -> Self { Self { + session_manager, session_id, name: None, user_set_name: None, @@ -142,6 +144,10 @@ impl SessionUpdateBuilder { } } + pub async fn apply(self) -> Result<()> { + self.session_manager.apply_update_inner(self).await + } + pub fn user_provided_name(mut self, name: impl Into) -> Self { let name = name.into().trim().to_string(); if !name.is_empty() { @@ -232,99 +238,96 @@ impl SessionUpdateBuilder { self.model_config = Some(Some(model_config)); self } - - pub async fn apply(self) -> Result<()> { - SessionManager::apply_update(self).await - } } -pub struct SessionManager; +pub struct SessionManager { + storage: Arc, +} impl SessionManager { - pub async fn instance() -> Result> { - SESSION_STORAGE - .get_or_try_init(|| async { SessionStorage::new().await.map(Arc::new) }) - .await - .map(Arc::clone) + pub fn new(data_dir: PathBuf) -> Self { + Self { + storage: Arc::new(SessionStorage::new(data_dir)), + } + } + + pub fn instance() -> Self { + Self { + storage: Arc::clone(&SESSION_STORAGE), + } + } + + pub fn storage(&self) -> &Arc { + &self.storage } pub async fn create_session( + &self, working_dir: PathBuf, name: String, session_type: SessionType, ) -> Result { - Self::instance() - .await? + self.storage .create_session(working_dir, name, session_type) .await } - pub async fn get_session(id: &str, include_messages: bool) -> Result { - Self::instance() - .await? - .get_session(id, include_messages) - .await + pub async fn get_session(&self, id: &str, include_messages: bool) -> Result { + self.storage.get_session(id, include_messages).await } - pub fn update_session(id: &str) -> SessionUpdateBuilder { - SessionUpdateBuilder::new(id.to_string()) + pub fn update(&self, id: &str) -> SessionUpdateBuilder<'_> { + SessionUpdateBuilder::new(self, id.to_string()) } - async fn apply_update(builder: SessionUpdateBuilder) -> Result<()> { - Self::instance().await?.apply_update(builder).await + async fn apply_update_inner(&self, builder: SessionUpdateBuilder<'_>) -> Result<()> { + self.storage.apply_update(builder).await } - pub async fn add_message(id: &str, message: &Message) -> Result<()> { - Self::instance().await?.add_message(id, message).await + pub async fn add_message(&self, id: &str, message: &Message) -> Result<()> { + self.storage.add_message(id, message).await } - pub async fn replace_conversation(id: &str, conversation: &Conversation) -> Result<()> { - Self::instance() - .await? - .replace_conversation(id, conversation) - .await + pub async fn replace_conversation(&self, id: &str, conversation: &Conversation) -> Result<()> { + self.storage.replace_conversation(id, conversation).await } - pub async fn list_sessions() -> Result> { - Self::instance().await?.list_sessions().await + pub async fn list_sessions(&self) -> Result> { + self.storage.list_sessions().await } - pub async fn list_sessions_by_types(types: &[SessionType]) -> Result> { - Self::instance().await?.list_sessions_by_types(types).await + pub async fn list_sessions_by_types(&self, types: &[SessionType]) -> Result> { + self.storage.list_sessions_by_types(types).await } - pub async fn delete_session(id: &str) -> Result<()> { - Self::instance().await?.delete_session(id).await + pub async fn delete_session(&self, id: &str) -> Result<()> { + self.storage.delete_session(id).await } - pub async fn get_insights() -> Result { - Self::instance().await?.get_insights().await + pub async fn get_insights(&self) -> Result { + self.storage.get_insights().await } - pub async fn export_session(id: &str) -> Result { - Self::instance().await?.export_session(id).await + pub async fn export_session(&self, id: &str) -> Result { + self.storage.export_session(id).await } - pub async fn import_session(json: &str) -> Result { - Self::instance().await?.import_session(json).await + pub async fn import_session(&self, json: &str) -> Result { + self.storage.import_session(self, json).await } - pub async fn copy_session(session_id: &str, new_name: String) -> Result { - Self::instance() - .await? - .copy_session(session_id, new_name) - .await + pub async fn copy_session(&self, session_id: &str, new_name: String) -> Result { + self.storage.copy_session(self, session_id, new_name).await } - pub async fn truncate_conversation(session_id: &str, timestamp: i64) -> Result<()> { - Self::instance() - .await? + pub async fn truncate_conversation(&self, session_id: &str, timestamp: i64) -> Result<()> { + self.storage .truncate_conversation(session_id, timestamp) .await } - pub async fn maybe_update_name(id: &str, provider: Arc) -> Result<()> { - let session = Self::get_session(id, true).await?; + pub async fn maybe_update_name(&self, id: &str, provider: Arc) -> Result<()> { + let session = self.get_session(id, true).await?; if session.user_set_name { return Ok(()); @@ -342,24 +345,21 @@ impl SessionManager { if user_message_count <= MSG_COUNT_FOR_SESSION_NAME_GENERATION { let name = provider.generate_session_name(&conversation).await?; - Self::update_session(id) - .system_generated_name(name) - .apply() - .await + self.update(id).system_generated_name(name).apply().await } else { Ok(()) } } pub async fn search_chat_history( + &self, query: &str, limit: Option, after_date: Option>, before_date: Option>, exclude_session_id: Option, ) -> Result { - Self::instance() - .await? + self.storage .search_chat_history(query, limit, after_date, before_date, exclude_session_id) .await } @@ -367,16 +367,8 @@ impl SessionManager { pub struct SessionStorage { pool: Pool, -} - -pub fn ensure_session_dir() -> Result { - let session_dir = Paths::data_dir().join(SESSIONS_FOLDER); - - if !session_dir.exists() { - fs::create_dir_all(&session_dir)?; - } - - Ok(session_dir) + initialized: tokio::sync::OnceCell<()>, + session_dir: PathBuf, } fn role_to_string(role: &Role) -> &'static str { @@ -479,52 +471,61 @@ impl sqlx::FromRow<'_, sqlx::sqlite::SqliteRow> for Session { } impl SessionStorage { - async fn new() -> Result { - let session_dir = ensure_session_dir()?; - let db_path = session_dir.join(DB_NAME); - - let storage = if db_path.exists() { - Self::open(&db_path).await? - } else { - let storage = Self::create(&db_path).await?; - - if let Err(e) = storage.import_legacy(&session_dir).await { - warn!("Failed to import some legacy sessions: {}", e); - } - - storage - }; - - Ok(storage) - } + fn create_pool(path: &Path) -> Pool { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).unwrap(); + } - async fn get_pool(db_path: &Path, create_if_missing: bool) -> Result> { let options = SqliteConnectOptions::new() - .filename(db_path) - .create_if_missing(create_if_missing) + .filename(path) + .create_if_missing(true) .busy_timeout(std::time::Duration::from_secs(5)) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); - sqlx::SqlitePool::connect_with(options).await.map_err(|e| { - anyhow::anyhow!( - "Failed to open SQLite database at '{}': {}", - db_path.display(), - e - ) - }) + SqlitePoolOptions::new().connect_lazy_with(options) } - async fn open(db_path: &Path) -> Result { - let pool = Self::get_pool(db_path, false).await?; + pub fn new(data_dir: PathBuf) -> Self { + let session_dir = data_dir.join(SESSIONS_FOLDER); + let db_path = session_dir.join(DB_NAME); + Self { + pool: Self::create_pool(&db_path), + initialized: tokio::sync::OnceCell::new(), + session_dir, + } + } - let storage = Self { pool }; - storage.run_migrations().await?; - Ok(storage) + async fn pool(&self) -> Result<&Pool> { + self.initialized + .get_or_try_init(|| async { + let schema_exists = sqlx::query_scalar::<_, bool>( + r#"SELECT EXISTS (SELECT name FROM sqlite_master WHERE type='table' AND name='schema_version')"#, + ) + .fetch_one(&self.pool) + .await + .unwrap_or(false); + + if schema_exists { + Self::run_migrations(&self.pool).await?; + } else { + Self::create_schema(&self.pool).await?; + if let Err(e) = Self::import_legacy(&self.pool, &self.session_dir).await { + warn!("Failed to import some legacy sessions: {}", e); + } + } + Ok::<(), anyhow::Error>(()) + }) + .await?; + Ok(&self.pool) } - async fn create(db_path: &Path) -> Result { - let pool = Self::get_pool(db_path, true).await?; + pub async fn create(session_dir: &Path) -> Result { + let storage = Self::new(session_dir.to_path_buf()); + Self::create_schema(&storage.pool).await?; + Ok(storage) + } + async fn create_schema(pool: &Pool) -> Result<()> { sqlx::query( r#" CREATE TABLE schema_version ( @@ -533,12 +534,12 @@ impl SessionStorage { ) "#, ) - .execute(&pool) + .execute(pool) .await?; sqlx::query("INSERT INTO schema_version (version) VALUES (?)") .bind(CURRENT_SCHEMA_VERSION) - .execute(&pool) + .execute(pool) .await?; sqlx::query( @@ -567,7 +568,7 @@ impl SessionStorage { ) "#, ) - .execute(&pool) + .execute(pool) .await?; sqlx::query( @@ -584,26 +585,26 @@ impl SessionStorage { ) "#, ) - .execute(&pool) + .execute(pool) .await?; sqlx::query("CREATE INDEX idx_messages_session ON messages(session_id)") - .execute(&pool) + .execute(pool) .await?; sqlx::query("CREATE INDEX idx_messages_timestamp ON messages(timestamp)") - .execute(&pool) + .execute(pool) .await?; sqlx::query("CREATE INDEX idx_sessions_updated ON sessions(updated_at DESC)") - .execute(&pool) + .execute(pool) .await?; sqlx::query("CREATE INDEX idx_sessions_type ON sessions(session_type)") - .execute(&pool) + .execute(pool) .await?; - Ok(Self { pool }) + Ok(()) } - async fn import_legacy(&self, session_dir: &PathBuf) -> Result<()> { + async fn import_legacy(pool: &Pool, session_dir: &PathBuf) -> Result<()> { use crate::session::legacy; let sessions = match legacy::list_sessions(session_dir) { @@ -623,7 +624,7 @@ impl SessionStorage { for (session_name, session_path) in sessions { match legacy::load_session(&session_name, &session_path) { - Ok(session) => match self.import_legacy_session(&session).await { + Ok(session) => match Self::import_legacy_session(pool, &session).await { Ok(_) => { imported_count += 1; info!(" ✓ Imported: {}", session_name); @@ -647,8 +648,8 @@ impl SessionStorage { Ok(()) } - async fn import_legacy_session(&self, session: &Session) -> Result<()> { - let mut tx = self.pool.begin().await?; + async fn import_legacy_session(pool: &Pool, session: &Session) -> Result<()> { + let mut tx = pool.begin().await?; let recipe_json = match &session.recipe { Some(recipe) => Some(serde_json::to_string(recipe)?), @@ -676,38 +677,38 @@ impl SessionStorage { ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) - .bind(&session.id) - .bind(&session.name) - .bind(session.user_set_name) - .bind(session.session_type.to_string()) - .bind(session.working_dir.to_string_lossy().as_ref()) - .bind(session.created_at) - .bind(session.updated_at) - .bind(serde_json::to_string(&session.extension_data)?) - .bind(session.total_tokens) - .bind(session.input_tokens) - .bind(session.output_tokens) - .bind(session.accumulated_total_tokens) - .bind(session.accumulated_input_tokens) - .bind(session.accumulated_output_tokens) - .bind(&session.schedule_id) - .bind(recipe_json) - .bind(user_recipe_values_json) - .bind(&session.provider_name) - .bind(model_config_json) - .execute(&mut *tx) - .await?; + .bind(&session.id) + .bind(&session.name) + .bind(session.user_set_name) + .bind(session.session_type.to_string()) + .bind(session.working_dir.to_string_lossy().as_ref()) + .bind(session.created_at) + .bind(session.updated_at) + .bind(serde_json::to_string(&session.extension_data)?) + .bind(session.total_tokens) + .bind(session.input_tokens) + .bind(session.output_tokens) + .bind(session.accumulated_total_tokens) + .bind(session.accumulated_input_tokens) + .bind(session.accumulated_output_tokens) + .bind(&session.schedule_id) + .bind(recipe_json) + .bind(user_recipe_values_json) + .bind(&session.provider_name) + .bind(model_config_json) + .execute(&mut *tx) + .await?; tx.commit().await?; if let Some(conversation) = &session.conversation { - self.replace_conversation(&session.id, conversation).await?; + Self::replace_conversation_inner(pool, &session.id, conversation).await?; } Ok(()) } - async fn run_migrations(&self) -> Result<()> { - let current_version = self.get_schema_version().await?; + async fn run_migrations(pool: &Pool) -> Result<()> { + let current_version = Self::get_schema_version(pool).await?; if current_version < CURRENT_SCHEMA_VERSION { info!( @@ -717,8 +718,8 @@ impl SessionStorage { for version in (current_version + 1)..=CURRENT_SCHEMA_VERSION { info!(" Applying migration v{}...", version); - self.apply_migration(version).await?; - self.update_schema_version(version).await?; + Self::apply_migration(pool, version).await?; + Self::update_schema_version(pool, version).await?; info!(" ✓ Migration v{} complete", version); } @@ -728,7 +729,7 @@ impl SessionStorage { Ok(()) } - async fn get_schema_version(&self) -> Result { + async fn get_schema_version(pool: &Pool) -> Result { let table_exists = sqlx::query_scalar::<_, bool>( r#" SELECT EXISTS ( @@ -737,7 +738,7 @@ impl SessionStorage { ) "#, ) - .fetch_one(&self.pool) + .fetch_one(pool) .await?; if !table_exists { @@ -745,21 +746,21 @@ impl SessionStorage { } let version = sqlx::query_scalar::<_, i32>("SELECT MAX(version) FROM schema_version") - .fetch_one(&self.pool) + .fetch_one(pool) .await?; Ok(version) } - async fn update_schema_version(&self, version: i32) -> Result<()> { + async fn update_schema_version(pool: &Pool, version: i32) -> Result<()> { sqlx::query("INSERT INTO schema_version (version) VALUES (?)") .bind(version) - .execute(&self.pool) + .execute(pool) .await?; Ok(()) } - async fn apply_migration(&self, version: i32) -> Result<()> { + async fn apply_migration(pool: &Pool, version: i32) -> Result<()> { match version { 1 => { sqlx::query( @@ -770,7 +771,7 @@ impl SessionStorage { ) "#, ) - .execute(&self.pool) + .execute(pool) .await?; } 2 => { @@ -779,7 +780,7 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN user_recipe_values_json TEXT "#, ) - .execute(&self.pool) + .execute(pool) .await?; } 3 => { @@ -788,7 +789,7 @@ impl SessionStorage { ALTER TABLE messages ADD COLUMN metadata_json TEXT "#, ) - .execute(&self.pool) + .execute(pool) .await?; } 4 => { @@ -797,7 +798,7 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN name TEXT DEFAULT '' "#, ) - .execute(&self.pool) + .execute(pool) .await?; sqlx::query( @@ -805,7 +806,7 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN user_set_name BOOLEAN DEFAULT FALSE "#, ) - .execute(&self.pool) + .execute(pool) .await?; } 5 => { @@ -814,11 +815,11 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN session_type TEXT NOT NULL DEFAULT 'user' "#, ) - .execute(&self.pool) + .execute(pool) .await?; sqlx::query("CREATE INDEX idx_sessions_type ON sessions(session_type)") - .execute(&self.pool) + .execute(pool) .await?; } 6 => { @@ -827,7 +828,7 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN provider_name TEXT "#, ) - .execute(&self.pool) + .execute(pool) .await?; sqlx::query( @@ -835,7 +836,7 @@ impl SessionStorage { ALTER TABLE sessions ADD COLUMN model_config_json TEXT "#, ) - .execute(&self.pool) + .execute(pool) .await?; } _ => { @@ -852,7 +853,8 @@ impl SessionStorage { name: String, session_type: SessionType, ) -> Result { - let mut tx = self.pool.begin().await?; + let pool = self.pool().await?; + let mut tx = pool.begin().await?; let today = chrono::Utc::now().format("%Y%m%d").to_string(); let session = sqlx::query_as( @@ -887,6 +889,7 @@ impl SessionStorage { } async fn get_session(&self, id: &str, include_messages: bool) -> Result { + let pool = self.pool().await?; let mut session = sqlx::query_as::<_, Session>( r#" SELECT id, working_dir, name, description, user_set_name, session_type, created_at, updated_at, extension_data, @@ -899,7 +902,7 @@ impl SessionStorage { "#, ) .bind(id) - .fetch_optional(&self.pool) + .fetch_optional(pool) .await? .ok_or_else(|| anyhow::anyhow!("Session not found"))?; @@ -911,7 +914,7 @@ impl SessionStorage { let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM messages WHERE session_id = ?") .bind(&session.id) - .fetch_one(&self.pool) + .fetch_one(pool) .await? as usize; session.message_count = count; } @@ -920,7 +923,7 @@ impl SessionStorage { } #[allow(clippy::too_many_lines)] - async fn apply_update(&self, builder: SessionUpdateBuilder) -> Result<()> { + async fn apply_update(&self, builder: SessionUpdateBuilder<'_>) -> Result<()> { let mut updates = Vec::new(); let mut query = String::from("UPDATE sessions SET "); @@ -1022,7 +1025,8 @@ impl SessionStorage { q = q.bind(model_config_json); } - let mut tx = self.pool.begin().await?; + let pool = self.pool().await?; + let mut tx = pool.begin().await?; q = q.bind(&builder.session_id); q.execute(&mut *tx).await?; @@ -1031,11 +1035,12 @@ impl SessionStorage { } async fn get_conversation(&self, session_id: &str) -> Result { + let pool = self.pool().await?; let rows = sqlx::query_as::<_, (String, String, i64, Option)>( "SELECT role, content_json, created_timestamp, metadata_json FROM messages WHERE session_id = ? ORDER BY timestamp", ) .bind(session_id) - .fetch_all(&self.pool) + .fetch_all(pool) .await?; let mut messages = Vec::new(); @@ -1063,7 +1068,8 @@ impl SessionStorage { } async fn add_message(&self, session_id: &str, message: &Message) -> Result<()> { - let mut tx = self.pool.begin().await?; + let pool = self.pool().await?; + let mut tx = pool.begin().await?; let metadata_json = serde_json::to_string(&message.metadata)?; @@ -1090,12 +1096,12 @@ impl SessionStorage { Ok(()) } - async fn replace_conversation( - &self, + async fn replace_conversation_inner( + pool: &Pool, session_id: &str, conversation: &Conversation, ) -> Result<()> { - let mut tx = self.pool.begin().await?; + let mut tx = pool.begin().await?; sqlx::query("DELETE FROM messages WHERE session_id = ?") .bind(session_id) @@ -1124,6 +1130,15 @@ impl SessionStorage { Ok(()) } + pub async fn replace_conversation( + &self, + session_id: &str, + conversation: &Conversation, + ) -> Result<()> { + let pool = self.pool().await?; + Self::replace_conversation_inner(pool, session_id, conversation).await + } + async fn list_sessions_by_types(&self, types: &[SessionType]) -> Result> { if types.is_empty() { return Ok(Vec::new()); @@ -1152,7 +1167,8 @@ impl SessionStorage { q = q.bind(t.to_string()); } - q.fetch_all(&self.pool).await.map_err(Into::into) + let pool = self.pool().await?; + q.fetch_all(pool).await.map_err(Into::into) } async fn list_sessions(&self) -> Result> { @@ -1161,7 +1177,8 @@ impl SessionStorage { } async fn delete_session(&self, session_id: &str) -> Result<()> { - let mut tx = self.pool.begin().await?; + let pool = self.pool().await?; + let mut tx = pool.begin().await?; let exists = sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?)") @@ -1188,6 +1205,7 @@ impl SessionStorage { } async fn get_insights(&self) -> Result { + let pool = self.pool().await?; let row = sqlx::query_as::<_, (i64, Option)>( r#" SELECT COUNT(*) as total_sessions, @@ -1195,7 +1213,7 @@ impl SessionStorage { FROM sessions "#, ) - .fetch_one(&self.pool) + .fetch_one(pool) .await?; Ok(SessionInsights { @@ -1209,7 +1227,11 @@ impl SessionStorage { serde_json::to_string_pretty(&session).map_err(Into::into) } - async fn import_session(&self, json: &str) -> Result { + async fn import_session( + &self, + session_manager: &SessionManager, + json: &str, + ) -> Result { let import: Session = serde_json::from_str(json)?; let session = self @@ -1220,7 +1242,8 @@ impl SessionStorage { ) .await?; - let mut builder = SessionUpdateBuilder::new(session.id.clone()) + let mut builder = session_manager + .update(&session.id) .extension_data(import.extension_data) .total_tokens(import.total_tokens) .input_tokens(import.input_tokens) @@ -1236,7 +1259,7 @@ impl SessionStorage { builder = builder.user_provided_name(import.name.clone()); } - self.apply_update(builder).await?; + builder.apply().await?; if let Some(conversation) = import.conversation { self.replace_conversation(&session.id, &conversation) @@ -1246,7 +1269,12 @@ impl SessionStorage { self.get_session(&session.id, true).await } - async fn copy_session(&self, session_id: &str, new_name: String) -> Result { + async fn copy_session( + &self, + session_manager: &SessionManager, + session_id: &str, + new_name: String, + ) -> Result { let original_session = self.get_session(session_id, true).await?; let new_session = self @@ -1257,13 +1285,14 @@ impl SessionStorage { ) .await?; - let builder = SessionUpdateBuilder::new(new_session.id.clone()) + session_manager + .update(&new_session.id) .extension_data(original_session.extension_data) .schedule_id(original_session.schedule_id) .recipe(original_session.recipe) - .user_recipe_values(original_session.user_recipe_values); - - self.apply_update(builder).await?; + .user_recipe_values(original_session.user_recipe_values) + .apply() + .await?; if let Some(conversation) = original_session.conversation { self.replace_conversation(&new_session.id, &conversation) @@ -1274,10 +1303,11 @@ impl SessionStorage { } async fn truncate_conversation(&self, session_id: &str, timestamp: i64) -> Result<()> { + let pool = self.pool().await?; sqlx::query("DELETE FROM messages WHERE session_id = ? AND created_timestamp >= ?") .bind(session_id) .bind(timestamp) - .execute(&self.pool) + .execute(pool) .await?; Ok(()) @@ -1293,8 +1323,9 @@ impl SessionStorage { ) -> Result { use crate::session::chat_history_search::ChatHistorySearch; + let pool = self.pool().await?; ChatHistorySearch::new( - &self.pool, + pool, query, limit, after_date, @@ -1317,64 +1348,55 @@ mod tests { #[tokio::test] async fn test_concurrent_session_creation() { let temp_dir = TempDir::new().unwrap(); - let db_path = temp_dir.path().join("test_sessions.db"); - - let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap()); + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); let mut handles = vec![]; for i in 0..NUM_CONCURRENT_SESSIONS { - let session_storage = Arc::clone(&storage); + let sm = Arc::clone(&session_manager); let handle = tokio::spawn(async move { let working_dir = PathBuf::from(format!("/tmp/test_{}", i)); let description = format!("Test session {}", i); - let session = session_storage + let session = sm .create_session(working_dir.clone(), description, SessionType::User) .await .unwrap(); - session_storage - .add_message( - &session.id, - &Message { - id: None, - role: Role::User, - created: chrono::Utc::now().timestamp_millis(), - content: vec![MessageContent::text("hello world")], - metadata: Default::default(), - }, - ) - .await - .unwrap(); - - session_storage - .add_message( - &session.id, - &Message { - id: None, - role: Role::Assistant, - created: chrono::Utc::now().timestamp_millis(), - content: vec![MessageContent::text("sup world?")], - metadata: Default::default(), - }, - ) - .await - .unwrap(); + sm.add_message( + &session.id, + &Message { + id: None, + role: Role::User, + created: chrono::Utc::now().timestamp_millis(), + content: vec![MessageContent::text("hello world")], + metadata: Default::default(), + }, + ) + .await + .unwrap(); + + sm.add_message( + &session.id, + &Message { + id: None, + role: Role::Assistant, + created: chrono::Utc::now().timestamp_millis(), + content: vec![MessageContent::text("sup world?")], + metadata: Default::default(), + }, + ) + .await + .unwrap(); - session_storage - .apply_update( - SessionUpdateBuilder::new(session.id.clone()) - .user_provided_name(format!("Updated session {}", i)) - .total_tokens(Some(100 * i)), - ) + sm.update(&session.id) + .user_provided_name(format!("Updated session {}", i)) + .total_tokens(Some(100 * i)) + .apply() .await .unwrap(); - let updated = session_storage - .get_session(&session.id, true) - .await - .unwrap(); + let updated = sm.get_session(&session.id, true).await.unwrap(); assert_eq!(updated.message_count, 2); assert_eq!(updated.total_tokens, Some(100 * i)); @@ -1393,7 +1415,7 @@ mod tests { let unique_ids: std::collections::HashSet<_> = results.iter().collect(); assert_eq!(unique_ids.len(), NUM_CONCURRENT_SESSIONS as usize); - let sessions = storage.list_sessions().await.unwrap(); + let sessions = session_manager.list_sessions().await.unwrap(); assert_eq!(sessions.len(), NUM_CONCURRENT_SESSIONS as usize); for session in &sessions { @@ -1401,7 +1423,7 @@ mod tests { assert!(session.name.starts_with("Updated session")); } - let insights = storage.get_insights().await.unwrap(); + let insights = session_manager.get_insights().await.unwrap(); assert_eq!(insights.total_sessions, NUM_CONCURRENT_SESSIONS as usize); let expected_tokens = 100 * NUM_CONCURRENT_SESSIONS * (NUM_CONCURRENT_SESSIONS - 1) / 2; assert_eq!(insights.total_tokens, expected_tokens as i64); @@ -1418,10 +1440,9 @@ mod tests { const ASSISTANT_MESSAGE: &str = "test response"; let temp_dir = TempDir::new().unwrap(); - let db_path = temp_dir.path().join("test_export.db"); - let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap()); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); - let original = storage + let original = sm .create_session( PathBuf::from("/tmp/test"), DESCRIPTION.to_string(), @@ -1430,47 +1451,43 @@ mod tests { .await .unwrap(); - storage - .apply_update( - SessionUpdateBuilder::new(original.id.clone()) - .total_tokens(Some(TOTAL_TOKENS)) - .input_tokens(Some(INPUT_TOKENS)) - .output_tokens(Some(OUTPUT_TOKENS)) - .accumulated_total_tokens(Some(ACCUMULATED_TOKENS)), - ) - .await - .unwrap(); - - storage - .add_message( - &original.id, - &Message { - id: None, - role: Role::User, - created: chrono::Utc::now().timestamp_millis(), - content: vec![MessageContent::text(USER_MESSAGE)], - metadata: Default::default(), - }, - ) + sm.update(&original.id) + .total_tokens(Some(TOTAL_TOKENS)) + .input_tokens(Some(INPUT_TOKENS)) + .output_tokens(Some(OUTPUT_TOKENS)) + .accumulated_total_tokens(Some(ACCUMULATED_TOKENS)) + .apply() .await .unwrap(); - storage - .add_message( - &original.id, - &Message { - id: None, - role: Role::Assistant, - created: chrono::Utc::now().timestamp_millis(), - content: vec![MessageContent::text(ASSISTANT_MESSAGE)], - metadata: Default::default(), - }, - ) - .await - .unwrap(); + sm.add_message( + &original.id, + &Message { + id: None, + role: Role::User, + created: chrono::Utc::now().timestamp_millis(), + content: vec![MessageContent::text(USER_MESSAGE)], + metadata: Default::default(), + }, + ) + .await + .unwrap(); + + sm.add_message( + &original.id, + &Message { + id: None, + role: Role::Assistant, + created: chrono::Utc::now().timestamp_millis(), + content: vec![MessageContent::text(ASSISTANT_MESSAGE)], + metadata: Default::default(), + }, + ) + .await + .unwrap(); - let exported = storage.export_session(&original.id).await.unwrap(); - let imported = storage.import_session(&exported).await.unwrap(); + let exported = sm.export_session(&original.id).await.unwrap(); + let imported = sm.import_session(&exported).await.unwrap(); assert_ne!(imported.id, original.id); assert_eq!(imported.name, DESCRIPTION); @@ -1501,10 +1518,9 @@ mod tests { }"#; let temp_dir = TempDir::new().unwrap(); - let db_path = temp_dir.path().join("test_import.db"); - let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap()); + let sm = SessionManager::new(temp_dir.path().to_path_buf()); - let imported = storage.import_session(OLD_FORMAT_JSON).await.unwrap(); + let imported = sm.import_session(OLD_FORMAT_JSON).await.unwrap(); assert_eq!(imported.name, "Old format session"); assert!(imported.user_set_name); diff --git a/crates/goose/src/tool_inspection.rs b/crates/goose/src/tool_inspection.rs index bea409600f57..26264ee81e1f 100644 --- a/crates/goose/src/tool_inspection.rs +++ b/crates/goose/src/tool_inspection.rs @@ -40,6 +40,7 @@ pub trait ToolInspector: Send + Sync { &self, tool_requests: &[ToolRequest], messages: &[Message], + goose_mode: GooseMode, ) -> Result>; /// Whether this inspector is enabled @@ -74,6 +75,7 @@ impl ToolInspectionManager { &self, tool_requests: &[ToolRequest], messages: &[Message], + goose_mode: GooseMode, ) -> Result> { let mut all_results = Vec::new(); @@ -88,7 +90,7 @@ impl ToolInspectionManager { "Running tool inspector" ); - match inspector.inspect(tool_requests, messages).await { + match inspector.inspect(tool_requests, messages, goose_mode).await { Ok(results) => { tracing::debug!( inspector_name = inspector.name(), @@ -116,22 +118,6 @@ impl ToolInspectionManager { self.inspectors.iter().map(|i| i.name()).collect() } - /// Update the permission inspector's mode - pub async fn update_permission_inspector_mode(&self, mode: GooseMode) { - for inspector in &self.inspectors { - if inspector.name() == "permission" { - // Downcast to PermissionInspector to access update_mode method - if let Some(permission_inspector) = - inspector.as_any().downcast_ref::() - { - permission_inspector.update_mode(mode).await; - return; - } - } - } - tracing::warn!("Permission inspector not found for mode update"); - } - /// Update the permission manager for a specific tool pub async fn update_permission_manager( &self, @@ -144,9 +130,9 @@ impl ToolInspectionManager { if let Some(permission_inspector) = inspector.as_any().downcast_ref::() { - let mut permission_manager = - permission_inspector.permission_manager.lock().await; - permission_manager.update_user_permission(tool_name, permission_level); + permission_inspector + .permission_manager + .update_user_permission(tool_name, permission_level); return; } } diff --git a/crates/goose/src/tool_monitor.rs b/crates/goose/src/tool_monitor.rs index 6ba6ec4f4054..a6dc3df937b1 100644 --- a/crates/goose/src/tool_monitor.rs +++ b/crates/goose/src/tool_monitor.rs @@ -1,3 +1,4 @@ +use crate::config::GooseMode; use crate::conversation::message::{Message, ToolRequest}; use crate::tool_inspection::{InspectionAction, InspectionResult, ToolInspector}; use anyhow::Result; @@ -99,6 +100,7 @@ impl ToolInspector for RepetitionInspector { &self, tool_requests: &[ToolRequest], _messages: &[Message], + _goose_mode: GooseMode, ) -> Result> { let mut results = Vec::new(); diff --git a/crates/goose/tests/acp_integration_test.rs b/crates/goose/tests/acp_integration_test.rs deleted file mode 100644 index ed8e3cb9bef8..000000000000 --- a/crates/goose/tests/acp_integration_test.rs +++ /dev/null @@ -1,381 +0,0 @@ -mod common; - -use rmcp::transport::streamable_http_server::{ - session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService, -}; -use rmcp::{ - handler::server::router::tool::ToolRouter, model::*, tool, tool_handler, tool_router, - ErrorData as McpError, ServerHandler, -}; -use sacp::schema::{ - ContentBlock, ContentChunk, InitializeRequest, McpServer, McpServerHttp, NewSessionRequest, - PromptRequest, ProtocolVersion, RequestPermissionOutcome, RequestPermissionRequest, - RequestPermissionResponse, SelectedPermissionOutcome, SessionNotification, SessionUpdate, - StopReason, TextContent, -}; -use sacp::{ClientToAgent, JrConnectionCx}; -use std::collections::VecDeque; -use std::path::Path; -use std::process::Stdio; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tokio::process::{Child, Command}; -use tokio::task::JoinHandle; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use wiremock::matchers::{method, path}; -use wiremock::{Mock, MockServer, ResponseTemplate}; - -/// Fake code returned by the MCP server - an LLM couldn't know this from memory -const FAKE_CODE: &str = "test-uuid-12345-67890"; - -#[tokio::test] -async fn test_acp_basic_completion() { - let prompt = "what is 1+1"; - let mock_server = setup_mock_openai(vec![( - format!(r#"\n{prompt}","role":"user""#), - include_str!("./test_data/openai_basic_response.txt"), - )]) - .await; - - run_acp_session( - &mock_server, - vec![], - &[], - tempfile::tempdir().unwrap().path(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for_text(&updates, "2", Duration::from_secs(5)).await; - }, - ) - .await; -} - -#[tokio::test] -async fn test_acp_with_mcp_http_server() { - let prompt = "Use the get_code tool and output only its result."; - let (mcp_url, _handle) = spawn_mcp_http_server().await; - - let mock_server = setup_mock_openai(vec![ - ( - format!(r#"\n{prompt}","role":"user""#), - include_str!("./test_data/openai_tool_call_response.txt"), - ), - ( - format!(r#""content":"{FAKE_CODE}","role":"tool""#), - include_str!("./test_data/openai_tool_result_response.txt"), - ), - ]) - .await; - - run_acp_session( - &mock_server, - vec![McpServer::Http(McpServerHttp::new("lookup", &mcp_url))], - &[], - tempfile::tempdir().unwrap().path(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for_text(&updates, FAKE_CODE, Duration::from_secs(5)).await; - }, - ) - .await; -} - -#[tokio::test] -async fn test_acp_with_builtin_and_mcp() { - let prompt = - "Search for get_code and text_editor tools. Use them to save the code to /tmp/result.txt."; - let (mcp_url, _handle) = spawn_mcp_http_server().await; - - let mock_server = setup_mock_openai(vec![ - ( - format!(r#"\n{prompt}","role":"user""#), - include_str!("./test_data/openai_builtin_search.txt"), - ), - ( - r#"lookup/get_code: Get the code"#.into(), - include_str!("./test_data/openai_builtin_read_modules.txt"), - ), - ( - r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), - include_str!("./test_data/openai_builtin_execute.txt"), - ), - ( - r#"Successfully wrote to /tmp/result.txt"#.into(), - include_str!("./test_data/openai_builtin_final.txt"), - ), - ]) - .await; - - run_acp_session( - &mock_server, - vec![McpServer::Http(McpServerHttp::new("lookup", &mcp_url))], - &["code_execution", "developer"], - tempfile::tempdir().unwrap().path(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for_text(&updates, FAKE_CODE, Duration::from_secs(10)).await; - }, - ) - .await; -} - -async fn wait_for_text( - updates: &Arc>>, - expected: &str, - timeout: Duration, -) { - let deadline = tokio::time::Instant::now() + timeout; - loop { - let actual = extract_text(&updates.lock().unwrap()); - if actual.contains(expected) { - return; - } - if tokio::time::Instant::now() > deadline { - assert_eq!(actual, expected); - return; - } - tokio::task::yield_now().await; - } -} - -/// Each entry is (expected_body_substring, response_body). -/// Session description requests are handled automatically. -async fn setup_mock_openai(exchanges: Vec<(String, &'static str)>) -> MockServer { - let mock_server = MockServer::start().await; - let queue: VecDeque<(String, &'static str)> = exchanges.into_iter().collect(); - let queue = Arc::new(Mutex::new(queue)); - - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with({ - let queue = queue.clone(); - move |req: &wiremock::Request| { - let body = String::from_utf8_lossy(&req.body); - - if body.contains("Reply with only a description in four words or less") { - return ResponseTemplate::new(200) - .insert_header("content-type", "application/json") - .set_body_string(include_str!( - "./test_data/openai_session_description.json" - )); - } - - let (expected, response) = { - let mut q = queue.lock().unwrap(); - match q.pop_front() { - Some(item) => item, - None => { - return ResponseTemplate::new(500) - .set_body_string(format!("unexpected request: {body}")); - } - } - }; - - if !body.contains(&expected) { - return ResponseTemplate::new(500).set_body_string(format!( - "expected body to contain: {expected}\nactual: {body}" - )); - } - - ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_string(response) - } - }) - .mount(&mock_server) - .await; - - mock_server -} - -fn extract_text(updates: &[SessionNotification]) -> String { - updates - .iter() - .filter_map(|n| match &n.update { - SessionUpdate::AgentMessageChunk(ContentChunk { - content: ContentBlock::Text(t), - .. - }) => Some(t.text.clone()), - _ => None, - }) - .collect() -} - -async fn spawn_goose_acp(mock_server: &MockServer, builtins: &[&str], data_root: &Path) -> Child { - let mut cmd = Command::new(&*common::GOOSE_BINARY); - cmd.args(["acp"]); - if !builtins.is_empty() { - cmd.arg("--with-builtin").arg(builtins.join(",")); - } - cmd.stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .env("GOOSE_PROVIDER", "openai") - .env("GOOSE_MODEL", "gpt-5-nano") - .env("GOOSE_MODE", "approve") - .env("OPENAI_HOST", mock_server.uri()) - .env("OPENAI_API_KEY", "test-key") - .env("GOOSE_PATH_ROOT", data_root) - .env( - "RUST_LOG", - std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()), - ) - .kill_on_drop(true) - .spawn() - .unwrap() -} - -async fn run_acp_session( - mock_server: &MockServer, - mcp_servers: Vec, - builtins: &[&str], - data_root: &Path, - test_fn: F, -) where - F: FnOnce( - JrConnectionCx, - sacp::schema::SessionId, - Arc>>, - ) -> Fut, - Fut: std::future::Future, -{ - let mut child = spawn_goose_acp(mock_server, builtins, data_root).await; - let work_dir = tempfile::tempdir().unwrap(); - let updates = Arc::new(Mutex::new(Vec::new())); - let outgoing = child.stdin.take().unwrap().compat_write(); - let incoming = child.stdout.take().unwrap().compat(); - - let transport = sacp::ByteStreams::new(outgoing, incoming); - - ClientToAgent::builder() - .on_receive_notification( - { - let updates = updates.clone(); - async move |notification: SessionNotification, _cx| { - updates.lock().unwrap().push(notification); - Ok(()) - } - }, - sacp::on_receive_notification!(), - ) - .on_receive_request( - async move |request: RequestPermissionRequest, request_cx, _connection_cx| { - let option_id = request.options.first().map(|opt| opt.option_id.clone()); - match option_id { - Some(id) => request_cx.respond(RequestPermissionResponse::new( - RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(id)), - )), - None => request_cx.respond(RequestPermissionResponse::new( - RequestPermissionOutcome::Cancelled, - )), - } - }, - sacp::on_receive_request!(), - ) - .connect_to(transport) - .unwrap() - .run_until({ - let updates = updates.clone(); - move |cx: JrConnectionCx| async move { - cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) - .block_task() - .await - .unwrap(); - - let session = cx - .send_request( - NewSessionRequest::new(work_dir.path().to_path_buf()) - .mcp_servers(mcp_servers), - ) - .block_task() - .await - .unwrap(); - - test_fn(cx.clone(), session.session_id, updates).await; - Ok(()) - } - }) - .await - .unwrap(); -} - -#[derive(Clone)] -struct Lookup { - tool_router: ToolRouter, -} - -#[tool_router] -impl Lookup { - fn new() -> Self { - Self { - tool_router: Self::tool_router(), - } - } - - /// Returns a fake code that an LLM couldn't know from memory - #[tool(description = "Get the code")] - fn get_code(&self) -> Result { - Ok(CallToolResult::success(vec![Content::text(FAKE_CODE)])) - } -} - -#[tool_handler] -impl ServerHandler for Lookup { - fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: rmcp::model::ProtocolVersion::V_2025_03_26, - capabilities: ServerCapabilities::builder().enable_tools().build(), - server_info: Implementation { - name: "lookup".into(), - version: "1.0.0".into(), - ..Default::default() - }, - instructions: Some("Lookup server with get_code tool.".into()), - } - } -} - -async fn spawn_mcp_http_server() -> (String, JoinHandle<()>) { - let service = StreamableHttpService::new( - || Ok(Lookup::new()), - LocalSessionManager::default().into(), - StreamableHttpServerConfig::default(), - ); - let router = axum::Router::new().nest_service("/mcp", service); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let url = format!("http://{addr}/mcp"); - - let handle = tokio::spawn(async move { - axum::serve(listener, router).await.unwrap(); - }); - - (url, handle) -} diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 967f56d51e02..dc08e84603b6 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -15,11 +15,15 @@ mod tests { use async_trait::async_trait; use chrono::{DateTime, Utc}; use goose::agents::platform_tools::PLATFORM_MANAGE_SCHEDULE_TOOL_NAME; + use goose::agents::AgentConfig; + use goose::config::permission::PermissionManager; + use goose::config::GooseMode; use goose::scheduler::{ScheduledJob, SchedulerError}; use goose::scheduler_trait::SchedulerTrait; - use goose::session::Session; + use goose::session::{Session, SessionManager}; use std::path::PathBuf; use std::sync::Arc; + use tempfile::TempDir; struct MockScheduler { jobs: tokio::sync::Mutex>, @@ -114,12 +118,20 @@ mod tests { #[tokio::test] async fn test_schedule_management_tool_list() { - let agent = Agent::new(); + let temp_dir = TempDir::new().unwrap(); + let data_dir = temp_dir.path().to_path_buf(); + let session_manager = Arc::new(SessionManager::new(data_dir.clone())); + let permission_manager = Arc::new(PermissionManager::new(data_dir)); let mock_scheduler = Arc::new(MockScheduler::new()); - agent.set_scheduler(mock_scheduler.clone()).await; + let config = AgentConfig::new( + session_manager, + permission_manager, + GooseMode::Auto, + mock_scheduler, + ); + let agent = Agent::with_config(config); - // Test that the schedule management tool is available in the tools list - let tools = agent.list_tools(None).await; + let tools = agent.list_tools("test-session-id", None).await; let schedule_tool = tools .iter() .find(|tool| tool.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME); @@ -139,7 +151,7 @@ mod tests { // Don't set scheduler - test that the tool still appears in the list // but would fail if actually called (which we can't test directly through public API) - let tools = agent.list_tools(None).await; + let tools = agent.list_tools("test-session-id", None).await; let schedule_tool = tools .iter() .find(|tool| tool.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME); @@ -149,7 +161,9 @@ mod tests { #[tokio::test] async fn test_schedule_management_tool_in_platform_tools() { let agent = Agent::new(); - let tools = agent.list_tools(Some("platform".to_string())).await; + let tools = agent + .list_tools("test-session-id", Some("platform".to_string())) + .await; // Check that the schedule management tool is included in platform tools let schedule_tool = tools @@ -188,7 +202,7 @@ mod tests { #[tokio::test] async fn test_schedule_management_tool_schema_validation() { let agent = Agent::new(); - let tools = agent.list_tools(None).await; + let tools = agent.list_tools("test-session-id", None).await; let schedule_tool = tools .iter() .find(|tool| tool.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME); @@ -303,7 +317,6 @@ mod tests { use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; - use goose::session::SessionManager; use rmcp::model::{CallToolRequestParam, Tool}; use rmcp::object; use std::path::PathBuf; @@ -375,12 +388,14 @@ mod tests { let provider = Arc::new(MockToolProvider::new()); let user_message = Message::user().with_text("Hello"); - let session = SessionManager::create_session( - PathBuf::default(), - "max-turn-test".to_string(), - SessionType::Hidden, - ) - .await?; + let session = agent + .session_manager() + .create_session( + PathBuf::default(), + "max-turn-test".to_string(), + SessionType::Hidden, + ) + .await?; agent.update_provider(provider, &session.id).await?; @@ -447,10 +462,15 @@ mod tests { #[cfg(test)] mod extension_manager_tests { use super::*; - use goose::agents::extension::{ExtensionConfig, PlatformExtensionContext}; + use goose::agents::extension::ExtensionConfig; use goose::agents::extension_manager_extension::{ MANAGE_EXTENSIONS_TOOL_NAME, SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, }; + use goose::agents::AgentConfig; + use goose::config::permission::PermissionManager; + use goose::config::GooseMode; + use goose::scheduler_trait::unavailable_scheduler; + use goose::session::SessionManager; async fn setup_agent_with_extension_manager() -> Agent { // Add the TODO extension to the config so it can be discovered by search_available_extensions @@ -468,15 +488,17 @@ mod tests { }; set_extension(todo_extension_entry); - let agent = Agent::new(); + // Create agent with session_id from the start + let temp_dir = tempfile::tempdir().unwrap(); + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let config = AgentConfig::new( + session_manager, + PermissionManager::instance(), + GooseMode::Auto, + unavailable_scheduler(), + ); - agent - .extension_manager - .set_context(PlatformExtensionContext { - session_id: Some("test_session".to_string()), - extension_manager: Some(Arc::downgrade(&agent.extension_manager)), - }) - .await; + let agent = Agent::with_config(config); // Now add the extension manager platform extension let ext_config = ExtensionConfig::Platform { @@ -496,7 +518,7 @@ mod tests { #[tokio::test] async fn test_extension_manager_tools_available() { let agent = setup_agent_with_extension_manager().await; - let tools = agent.list_tools(None).await; + let tools = agent.list_tools("test-session-id", None).await; // Note: Tool names are prefixed with the normalized extension name "extensionmanager" // not the display name "Extension Manager" diff --git a/crates/goose/tests/common.rs b/crates/goose/tests/common.rs deleted file mode 100644 index 137179319dae..000000000000 --- a/crates/goose/tests/common.rs +++ /dev/null @@ -1,43 +0,0 @@ -use std::path::PathBuf; -use std::process::Command; -use std::sync::LazyLock; - -/// Build a binary from a package and return its path. -pub fn build_binary(package: &str, bin_name: &str) -> PathBuf { - let output = Command::new("cargo") - .args([ - "build", - "-p", - package, - "--bin", - bin_name, - "--message-format=json", - ]) - .output() - .expect("failed to build binary"); - - if !output.status.success() { - panic!("build failed: {}", String::from_utf8_lossy(&output.stderr)); - } - - String::from_utf8_lossy(&output.stdout) - .lines() - .filter_map(|line| serde_json::from_str::(line).ok()) - .filter(|msg| msg["reason"] == "compiler-artifact") - .filter(|msg| msg["target"]["name"] == bin_name) - .filter(|msg| { - msg["target"]["kind"] - .as_array() - .map(|k| k.iter().any(|v| v == "bin")) - .unwrap_or(false) - }) - .filter_map(|msg| msg["executable"].as_str().map(PathBuf::from)) - .next() - .expect("failed to find binary path in cargo output") -} - -#[allow(dead_code)] -pub static GOOSE_BINARY: LazyLock = LazyLock::new(|| build_binary("goose-cli", "goose")); -#[allow(dead_code)] -pub static CAPTURE_BINARY: LazyLock = - LazyLock::new(|| build_binary("goose-test", "capture")); diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 08849135815f..01bfbf9a5540 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -1,4 +1,4 @@ -mod common; +use serde::Deserialize; use std::collections::HashMap; use std::fs::File; @@ -20,6 +20,21 @@ use async_trait::async_trait; use goose::conversation::message::Message; use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; +use once_cell::sync::Lazy; +use std::process::Command; + +#[derive(Deserialize)] +struct CargoBuildMessage { + reason: String, + target: Target, + executable: String, +} + +#[derive(Deserialize)] +struct Target { + name: String, + kind: Vec, +} #[derive(Clone)] pub struct MockProvider { @@ -60,6 +75,44 @@ impl Provider for MockProvider { } } +fn build_and_get_binary_path() -> PathBuf { + let output = Command::new("cargo") + .args([ + "build", + "--frozen", + "-p", + "goose-test", + "--bin", + "capture", + "--message-format=json", + ]) + .output() + .expect("failed to build binary"); + + if !output.status.success() { + panic!("build failed: {}", String::from_utf8_lossy(&output.stderr)); + } + + String::from_utf8_lossy(&output.stdout) + .lines() + .map(serde_json::from_str::) + .filter_map(Result::ok) + .filter(|message| message.reason == "compiler-artifact") + .filter_map(|message| { + if message.target.name == "capture" + && message.target.kind.contains(&String::from("bin")) + { + Some(PathBuf::from(message.executable)) + } else { + None + } + }) + .next() + .expect("failed to parse binary path") +} + +static REPLAY_BINARY_PATH: Lazy = Lazy::new(build_and_get_binary_path); + enum TestMode { Record, Playback, @@ -161,7 +214,7 @@ async fn test_replayed_session( TestMode::Record => "record", TestMode::Playback => "playback", }; - let cmd = common::CAPTURE_BINARY.to_string_lossy().to_string(); + let cmd = REPLAY_BINARY_PATH.to_string_lossy().to_string(); let mut args = vec!["stdio", mode_arg] .into_iter() .map(str::to_string) @@ -203,7 +256,11 @@ async fn test_replayed_session( let provider = Arc::new(tokio::sync::Mutex::new(Some(Arc::new(MockProvider { model_config: ModelConfig::new("test-model").unwrap(), }) as Arc))); - let extension_manager = ExtensionManager::new(provider); + let temp_dir = tempfile::tempdir().unwrap(); + let session_manager = Arc::new(goose::session::SessionManager::new( + temp_dir.path().to_path_buf(), + )); + let extension_manager = Arc::new(ExtensionManager::new(provider, session_manager)); #[allow(clippy::redundant_closure_call)] let result = (async || -> Result<(), Box> { @@ -215,7 +272,7 @@ async fn test_replayed_session( arguments: tool_call.arguments, }; let result = extension_manager - .dispatch_tool_call(tool_call, CancellationToken::default()) + .dispatch_tool_call("test-session-id", tool_call, CancellationToken::default()) .await; let tool_result = result?; diff --git a/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper b/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper index 39b1f2561c32..e723208464cc 100644 --- a/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper +++ b/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper @@ -10,17 +10,17 @@ STDERR: at crates/goose-mcp/src/developer/analyze/cache.rs:26 STDERR: STDOUT: {"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-03-26","capabilities":{"prompts":{},"tools":{}},"serverInfo":{"name":"goose-developer","version":"1.16.0"},"instructions":" The developer extension gives you the capabilities to edit code files and run shell commands,\n and can be used to solve a wide range of problems.\n\nYou can use the shell tool to run any command that would work on the relevant operating system.\nUse the shell tool as needed to locate files or interact with the project.\n\nLeverage `analyze` through `return_last_only=true` subagents for deep codebase understanding with lean context\n- delegate analysis, retain summaries\n\nYour windows/screen tools can be used for visual debugging. You should not use these tools unless\nprompted to, but you can mention they are available if they are relevant.\n\nAlways prefer ripgrep (rg -C 3) to grep.\n\noperating system: macos\ncurrent directory: /Users/douwe/proj/goose/crates/goose\nshell: /bin/zsh\n\n \nAdditional Text Editor Tool Instructions:\n\nPerform text editing operations on files.\n\nThe `command` parameter specifies the operation to perform. Allowed options are:\n- `view`: View the content of a file.\n- `write`: Create or overwrite a file with the given content\n- `str_replace`: Replace text in one or more files.\n- `insert`: Insert text at a specific line location in the file.\n- `undo_edit`: Undo the last edit made to a file.\n\nTo use the write command, you must specify `file_text` which will become the new content of the file. Be careful with\nexisting files! This is a full overwrite, so you must include everything - not just sections you are modifying.\n\nTo use the str_replace command to edit multiple files, use the `diff` parameter with a unified diff.\nTo use the str_replace command to edit one file, you must specify both `old_str` and `new_str` - the `old_str` needs to exactly match one\nunique section of the original file, including any whitespace. Make sure to include enough context that the match is not\nambiguous. The entire original string will be replaced with `new_str`\n\nWhen possible, batch file edits together by using a multi-file unified `diff` within a single str_replace tool call.\n\nTo use the insert command, you must specify both `insert_line` (the line number after which to insert, 0 for beginning, -1 for end)\nand `new_str` (the text to insert).\n\n\n\nAdditional Shell Tool Instructions:\nExecute a command in the shell.\n\nThis will return the output and error concatenated into a single string, as\nyou would see from running on the command line. There will also be an indication\nof if the command succeeded or failed.\n\nAvoid commands that produce a large amount of output, and consider piping those outputs to files.\n\n**Important**: Each shell command runs in its own process. Things like directory changes or\nsourcing files do not persist between tool calls. So you may need to repeat them each time by\nstringing together commands.\n\nIf fetching web content, consider adding Accept: text/markdown header\nIf you need to run a long lived command, background it - e.g. `uvicorn main:app &` so that\nthis tool does not run indefinitely.\n\n**Important**: Use ripgrep - `rg` - exclusively when you need to locate a file or a code reference,\nother solutions may produce too large output because of hidden files! For example *do not* use `find` or `ls -r`\n - List files by name: `rg --files | rg `\n - List files that contain a regex: `rg '' -l`\n\n - Multiple commands: Use && to chain commands, avoid newlines\n - Example: `cd example && ls` or `source env/bin/activate && pip install numpy`\n"}} STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"} -STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"progressToken":0},"name":"text_editor","arguments":{"command":"view","path":"/tmp/goose_test/goose.txt"}}} +STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":0},"name":"text_editor","arguments":{"command":"view","path":"/tmp/goose_test/goose.txt"}}} STDOUT: {"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"resource","resource":{"uri":"file:///tmp/goose_test/goose.txt","mimeType":"text","text":"# goose\n"},"annotations":{"audience":["assistant"]}},{"type":"text","text":"### /tmp/goose_test/goose.txt\n```\n1: # goose\n```\n","annotations":{"audience":["user"],"priority":0.0}}],"isError":false}} -STDIN: {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"_meta":{"progressToken":1},"name":"text_editor","arguments":{"command":"str_replace","new_str":"# goose (modified by test)","old_str":"# goose","path":"/tmp/goose_test/goose.txt"}}} +STDIN: {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":1},"name":"text_editor","arguments":{"command":"str_replace","new_str":"# goose (modified by test)","old_str":"# goose","path":"/tmp/goose_test/goose.txt"}}} STDOUT: {"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"The file /tmp/goose_test/goose.txt has been edited, and the section now reads:\n```\n# goose (modified by test)\n```\n\nReview the changes above for errors. Undo and edit the file again if necessary!\n","annotations":{"audience":["assistant"]}},{"type":"text","text":"```\n# goose (modified by test)\n```\n","annotations":{"audience":["user"],"priority":0.2}}],"isError":false}} -STDIN: {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"_meta":{"progressToken":2},"name":"shell","arguments":{"command":"cat /tmp/goose_test/goose.txt"}}} +STDIN: {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":2},"name":"shell","arguments":{"command":"cat /tmp/goose_test/goose.txt"}}} STDERR: 2025-12-11T19:43:39.019022Z DEBUG goose_mcp::developer::rmcp_developer: Shell process spawned with PID: 78321 STDERR: at crates/goose-mcp/src/developer/rmcp_developer.rs:997 STDERR: STDOUT: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","logger":"shell_tool","data":{"type":"shell_output","stream":"stdout","output":"# goose (modified by test)"}}} STDOUT: {"jsonrpc":"2.0","id":3,"result":{"content":[{"type":"text","text":"# goose (modified by test)\n","annotations":{"audience":["assistant"]}},{"type":"text","text":"# goose (modified by test)\n","annotations":{"audience":["user"],"priority":0.0}}],"isError":false}} -STDIN: {"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"_meta":{"progressToken":3},"name":"text_editor","arguments":{"command":"str_replace","new_str":"# goose","old_str":"# goose (modified by test)","path":"/tmp/goose_test/goose.txt"}}} +STDIN: {"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":3},"name":"text_editor","arguments":{"command":"str_replace","new_str":"# goose","old_str":"# goose (modified by test)","path":"/tmp/goose_test/goose.txt"}}} STDOUT: {"jsonrpc":"2.0","id":4,"result":{"content":[{"type":"text","text":"The file /tmp/goose_test/goose.txt has been edited, and the section now reads:\n```\n# goose\n```\n\nReview the changes above for errors. Undo and edit the file again if necessary!\n","annotations":{"audience":["assistant"]}},{"type":"text","text":"```\n# goose\n```\n","annotations":{"audience":["user"],"priority":0.2}}],"isError":false}} -STDIN: {"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"_meta":{"progressToken":4},"name":"list_windows","arguments":{}}} +STDIN: {"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":4},"name":"list_windows","arguments":{}}} STDOUT: {"jsonrpc":"2.0","id":5,"result":{"content":[{"type":"text","text":"Available windows:\nMenubar","annotations":{"audience":["assistant"]}},{"type":"text","text":"Available windows:\nMenubar","annotations":{"audience":["user"],"priority":0.0}}],"isError":false}} diff --git a/crates/goose/tests/mcp_replays/github-mcp-serverstdio b/crates/goose/tests/mcp_replays/github-mcp-serverstdio index 0acc0f27379a..99c92d895a5d 100644 --- a/crates/goose/tests/mcp_replays/github-mcp-serverstdio +++ b/crates/goose/tests/mcp_replays/github-mcp-serverstdio @@ -7,6 +7,6 @@ STDERR: time=2025-12-11T17:58:47.640-05:00 level=INFO msg="server session connec STDOUT: {"jsonrpc":"2.0","id":0,"result":{"capabilities":{"completions":{},"logging":{},"prompts":{"listChanged":true},"resources":{"listChanged":true},"tools":{"listChanged":true}},"instructions":"The GitHub MCP Server provides tools to interact with GitHub platform.\n\nTool selection guidance:\n\t1. Use 'list_*' tools for broad, simple retrieval and pagination of all items of a type (e.g., all issues, all PRs, all branches) with basic filtering.\n\t2. Use 'search_*' tools for targeted queries with specific criteria, keywords, or complex filters (e.g., issues with certain text, PRs by author, code containing functions).\n\nContext management:\n\t1. Use pagination whenever possible with batches of 5-10 items.\n\t2. Use minimal_output parameter set to true if the full information is not needed to accomplish a task.\n\nTool usage guidance:\n\t1. For 'search_*' tools: Use separate 'sort' and 'order' parameters if available for sorting results - do not include 'sort:' syntax in query strings. Query strings should contain only search criteria (e.g., 'org:google language:python'), not sorting instructions. Always call 'get_me' first to understand current user permissions and context. ## Issues\n\nCheck 'list_issue_types' first for organizations to use proper issue types. Use 'search_issues' before creating new issues to avoid duplicates. Always set 'state_reason' when closing issues. ## Pull Requests\n\nPR review workflow: Always use 'pull_request_review_write' with method 'create' to create a pending review, then 'add_comment_to_pending_review' to add comments, and finally 'pull_request_review_write' with method 'submit_pending' to submit the review for complex reviews with line-specific comments.\n\nBefore creating a pull request, search for pull request templates in the repository. Template files are called pull_request_template.md or they're located in '.github/PULL_REQUEST_TEMPLATE' directory. Use the template content to structure the PR description and then call create_pull_request tool.","protocolVersion":"2025-03-26","serverInfo":{"name":"github-mcp-server","title":"GitHub MCP Server","version":"0.24.1"}}} STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"} STDERR: time=2025-12-11T17:58:47.642-05:00 level=INFO msg="session initialized" -STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"progressToken":0},"name":"get_file_contents","arguments":{"owner":"block","path":"README.md","repo":"goose","sha":"ab62b863c1666232a67048b6c4e10007a2a5b83c"}}} +STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":0},"name":"get_file_contents","arguments":{"owner":"block","path":"README.md","repo":"goose","sha":"ab62b863c1666232a67048b6c4e10007a2a5b83c"}}} STDOUT: {"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"successfully downloaded text file (SHA: de9bdde7f260549bf3a083651842f30ab29cf4e9)"},{"type":"resource","resource":{"uri":"repo://block/goose/sha/ab62b863c1666232a67048b6c4e10007a2a5b83c/contents/README.md","mimeType":"text/plain; charset=utf-8","text":"\u003cdiv align=\"center\"\u003e\n\n# goose\n\n_a local, extensible, open source AI agent that automates engineering tasks_\n\n\u003cp align=\"center\"\u003e\n \u003ca href=\"https://opensource.org/licenses/Apache-2.0\"\u003e\n \u003cimg src=\"https://img.shields.io/badge/License-Apache_2.0-blue.svg\"\u003e\n \u003c/a\u003e\n \u003ca href=\"https://discord.gg/7GaTvbDwga\"\u003e\n \u003cimg src=\"https://img.shields.io/discord/1287729918100246654?logo=discord\u0026logoColor=white\u0026label=Join+Us\u0026color=blueviolet\" alt=\"Discord\"\u003e\n \u003c/a\u003e\n \u003ca href=\"https://github.com/block/goose/actions/workflows/ci.yml\"\u003e\n \u003cimg src=\"https://img.shields.io/github/actions/workflow/status/block/goose/ci.yml?branch=main\" alt=\"CI\"\u003e\n \u003c/a\u003e\n\u003c/p\u003e\n\u003c/div\u003e\n\ngoose is your on-machine AI agent, capable of automating complex development tasks from start to finish. More than just code suggestions, goose can build entire projects from scratch, write and execute code, debug failures, orchestrate workflows, and interact with external APIs - _autonomously_.\n\nWhether you're prototyping an idea, refining existing code, or managing intricate engineering pipelines, goose adapts to your workflow and executes tasks with precision.\n\nDesigned for maximum flexibility, goose works with any LLM and supports multi-model configuration to optimize performance and cost, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.\n\n[![Watch the video](https://github.com/user-attachments/assets/ddc71240-3928-41b5-8210-626dfb28af7a)](https://youtu.be/D-DpDunrbpo)\n\n# Quick Links\n- [Quickstart](https://block.github.io/goose/docs/quickstart)\n- [Installation](https://block.github.io/goose/docs/getting-started/installation)\n- [Tutorials](https://block.github.io/goose/docs/category/tutorials)\n- [Documentation](https://block.github.io/goose/docs/category/getting-started)\n\n\n# a little goose humor ðŸĶĒ\n\n\u003e Why did the developer choose goose as their AI agent?\n\u003e \n\u003e Because it always helps them \"migrate\" their code to production! 🚀\n\n# goose around with us\n- [Discord](https://discord.gg/block-opensource)\n- [YouTube](https://www.youtube.com/@goose-oss)\n- [LinkedIn](https://www.linkedin.com/company/goose-oss)\n- [Twitter/X](https://x.com/goose_oss)\n- [Bluesky](https://bsky.app/profile/opensource.block.xyz)\n- [Nostr](https://njump.me/opensource@block.xyz)\n"}}]}} STDERR: time=2025-12-11T17:58:48.133-05:00 level=INFO msg="server session disconnected" session_id="" diff --git a/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything b/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything index 811a47d3a862..eedc241097a7 100644 --- a/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything +++ b/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything @@ -3,21 +3,21 @@ STDERR: Starting default (STDIO) server... STDERR: Starting logs update interval STDOUT: {"result":{"protocolVersion":"2025-03-26","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{},"completions":{}},"serverInfo":{"name":"example-servers/everything","title":"Everything Example Server","version":"1.0.0"},"instructions":"Testing and demonstration server for MCP protocol features.\n\n## Resources\n\nResources 1-100 follow pattern: even IDs contain text, odd IDs contain binary data. Resources are paginated at 10 items per page with cursor-based navigation.\n\n## Key dependencies\n\nProgress notifications require `_meta.progressToken` in tool calls. Resource subscriptions generate updates every 10 seconds.\n\n## Performance characteristics\n\nServer generates automatic log messages every 20 seconds (filtered by current log level), stderr notifications every 30 seconds, and resource update notifications every 10 seconds for active subscriptions. Log level changes affect message filtering in real-time.\n\n## Multi-modal testing\n\n`complex_prompt` includes both text arguments and image content for testing client multi-modal handling. `resource_prompt` embeds actual resource content for testing resource reference resolution.\n\nArgument completion is available for prompt parameters and resource IDs. Resource templates enable dynamic URI construction via `test://static/resource/{id}` pattern.\n\n## Easter egg\n\nIf asked about server instructions, respond with \"🎉 Server instructions are working! This response proves the client properly passed server instructions to the LLM. This demonstrates MCP's instructions feature in action.\"\n"},"jsonrpc":"2.0","id":0} STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"} -STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"progressToken":0},"name":"echo","arguments":{"message":"Hello, world!"}}} +STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":0},"name":"echo","arguments":{"message":"Hello, world!"}}} STDOUT: {"method":"notifications/message","params":{"level":"info","logger":"everything-server","data":"Client does not support MCP roots protocol"},"jsonrpc":"2.0"} STDOUT: {"result":{"content":[{"type":"text","text":"Echo: Hello, world!"}]},"jsonrpc":"2.0","id":1} -STDIN: {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"_meta":{"progressToken":1},"name":"add","arguments":{"a":1,"b":2}}} +STDIN: {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":1},"name":"add","arguments":{"a":1,"b":2}}} STDOUT: {"result":{"content":[{"type":"text","text":"The sum of 1 and 2 is 3."}]},"jsonrpc":"2.0","id":2} -STDIN: {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"_meta":{"progressToken":2},"name":"longRunningOperation","arguments":{"duration":1,"steps":5}}} +STDIN: {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":2},"name":"longRunningOperation","arguments":{"duration":1,"steps":5}}} STDOUT: {"method":"notifications/progress","params":{"progress":1,"total":5,"progressToken":2},"jsonrpc":"2.0"} STDOUT: {"method":"notifications/progress","params":{"progress":2,"total":5,"progressToken":2},"jsonrpc":"2.0"} STDOUT: {"method":"notifications/progress","params":{"progress":3,"total":5,"progressToken":2},"jsonrpc":"2.0"} STDOUT: {"method":"notifications/progress","params":{"progress":4,"total":5,"progressToken":2},"jsonrpc":"2.0"} STDOUT: {"method":"notifications/progress","params":{"progress":5,"total":5,"progressToken":2},"jsonrpc":"2.0"} STDOUT: {"result":{"content":[{"type":"text","text":"Long running operation completed. Duration: 1 seconds, Steps: 5."}]},"jsonrpc":"2.0","id":3} -STDIN: {"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"_meta":{"progressToken":3},"name":"structuredContent","arguments":{"location":"11238"}}} +STDIN: {"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":3},"name":"structuredContent","arguments":{"location":"11238"}}} STDOUT: {"result":{"content":[{"type":"text","text":"{\"temperature\":22.5,\"conditions\":\"Partly cloudy\",\"humidity\":65}"}],"structuredContent":{"temperature":22.5,"conditions":"Partly cloudy","humidity":65}},"jsonrpc":"2.0","id":4} -STDIN: {"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"_meta":{"progressToken":4},"name":"sampleLLM","arguments":{"maxTokens":100,"prompt":"Please provide a quote from The Great Gatsby"}}} +STDIN: {"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":4},"name":"sampleLLM","arguments":{"maxTokens":100,"prompt":"Please provide a quote from The Great Gatsby"}}} STDOUT: {"method":"sampling/createMessage","params":{"messages":[{"role":"user","content":{"type":"text","text":"Resource sampleLLM context: Please provide a quote from The Great Gatsby"}}],"systemPrompt":"You are a helpful test server.","maxTokens":100,"temperature":0.7,"includeContext":"thisServer"},"jsonrpc":"2.0","id":0} STDIN: {"jsonrpc":"2.0","id":0,"result":{"model":"mock","stopReason":"endTurn","role":"assistant","content":{"type":"text","text":"\"So we beat on, boats against the current, borne back ceaselessly into the past.\" — F. Scott Fitzgerald, The Great Gatsby (1925)"}}} STDOUT: {"result":{"content":[{"type":"text","text":"LLM sampling result: \"So we beat on, boats against the current, borne back ceaselessly into the past.\" — F. Scott Fitzgerald, The Great Gatsby (1925)"}]},"jsonrpc":"2.0","id":5} diff --git a/crates/goose/tests/mcp_replays/uvxmcp-server-fetch b/crates/goose/tests/mcp_replays/uvxmcp-server-fetch index 913f5603f807..7f8f86427766 100644 --- a/crates/goose/tests/mcp_replays/uvxmcp-server-fetch +++ b/crates/goose/tests/mcp_replays/uvxmcp-server-fetch @@ -1,5 +1,5 @@ STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{},"elicitation":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}} STDOUT: {"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-03-26","capabilities":{"experimental":{},"prompts":{"listChanged":false},"tools":{"listChanged":false}},"serverInfo":{"name":"mcp-fetch","version":"1.23.3"}}} STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"} -STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"progressToken":0},"name":"fetch","arguments":{"url":"https://example.com"}}} +STDIN: {"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"_meta":{"goose-session-id":"test-session-id","progressToken":0},"name":"fetch","arguments":{"url":"https://example.com"}}} STDOUT: {"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"Contents of https://example.com/:\nThis domain is for use in documentation examples without needing permission. Avoid use in operations.\n\n[Learn more](https://iana.org/domains/example)"}],"isError":false}} diff --git a/crates/goose/tests/tool_inspection_manager_tests.rs b/crates/goose/tests/tool_inspection_manager_tests.rs index 6701b1c3684f..af832d7afb8b 100644 --- a/crates/goose/tests/tool_inspection_manager_tests.rs +++ b/crates/goose/tests/tool_inspection_manager_tests.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; +use goose::config::GooseMode; use goose::conversation::message::{Message, ToolRequest}; use goose::tool_inspection::{ InspectionAction, InspectionResult, ToolInspectionManager, ToolInspector, @@ -26,6 +27,7 @@ impl ToolInspector for MockInspectorOk { &self, _tool_requests: &[ToolRequest], _messages: &[Message], + _goose_mode: GooseMode, ) -> Result> { Ok(self.results.clone()) } @@ -43,6 +45,7 @@ impl ToolInspector for MockInspectorErr { &self, _tool_requests: &[ToolRequest], _messages: &[Message], + _goose_mode: GooseMode, ) -> Result> { Err(anyhow!("simulated failure")) } @@ -83,7 +86,7 @@ async fn test_inspect_tools_aggregates_and_handles_errors() { // Act let results = manager - .inspect_tools(&tool_requests, &messages) + .inspect_tools(&tool_requests, &messages, GooseMode::Approve) .await .expect("inspect_tools should not fail when one inspector errors");