diff --git a/Cargo.lock b/Cargo.lock index 39ed1a471481..61380241397b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2986,6 +2986,7 @@ version = "1.21.0" dependencies = [ "anyhow", "assert-json-diff", + "async-trait", "axum 0.8.8", "fs-err", "futures", diff --git a/clippy-baselines/too_many_lines.txt b/clippy-baselines/too_many_lines.txt index ae7ca874e58c..e69d0ef3a1ce 100644 --- a/clippy-baselines/too_many_lines.txt +++ b/clippy-baselines/too_many_lines.txt @@ -23,3 +23,4 @@ crates/goose/src/providers/formats/google.rs::format_messages crates/goose/src/providers/formats/openai.rs::format_messages crates/goose/src/providers/formats/openai.rs::response_to_streaming_message crates/goose/src/providers/snowflake.rs::post +crates/goose/src/security/mod.rs::analyze_tool_requests diff --git a/crates/goose-acp/.gooseignore b/crates/goose-acp/.gooseignore new file mode 100644 index 000000000000..550684df254e --- /dev/null +++ b/crates/goose-acp/.gooseignore @@ -0,0 +1,18 @@ +# This file is created automatically if no .gooseignore exists. +# Customize or uncomment the patterns below instead of deleting the file. +# Removing it will simply cause goose to recreate it on the next start. 
+# +# Suggested patterns you can uncomment if desired: +# **/.ssh/** # block SSH keys and configs +# **/*.key # block loose private keys +# **/*.pem # block certificates/private keys +# **/.git/** # block git metadata entirely +# **/target/** # block Rust build artifacts +# **/node_modules/** # block JS/TS dependencies +# **/*.db # block local database files +# **/*.sqlite # block SQLite databases +# + +**/.env +**/.env.* +**/secrets.* diff --git a/crates/goose-acp/Cargo.toml b/crates/goose-acp/Cargo.toml index bce7afd9451c..ab8f3ce64cc0 100644 --- a/crates/goose-acp/Cargo.toml +++ b/crates/goose-acp/Cargo.toml @@ -26,6 +26,7 @@ url = { workspace = true } [dev-dependencies] assert-json-diff = "2.0.2" +async-trait = "0.1.89" wiremock = { workspace = true } tempfile = "3" test-case = { workspace = true } diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 5552b09f03ca..d3a184423f52 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -2,6 +2,8 @@ use anyhow::Result; use fs_err as fs; use goose::agents::extension::{Envs, PLATFORM_EXTENSIONS}; use goose::agents::{Agent, AgentConfig, ExtensionConfig, SessionConfig}; +use goose::config::base::CONFIG_YAML_NAME; +use goose::config::extensions::get_enabled_extensions_with_config; use goose::config::paths::Paths; use goose::config::permission::PermissionManager; use goose::config::Config; @@ -46,7 +48,7 @@ pub struct GooseAcpAgent { provider: Arc, } -pub struct GooseAcpConfig { +pub struct AcpServerConfig { pub provider: Arc, pub builtins: Vec, pub data_dir: std::path::PathBuf, @@ -276,8 +278,21 @@ async fn add_builtins(agent: &Agent, builtins: Vec) { } } } +async fn add_extensions(agent: &Agent, extensions: Vec) { + for extension in extensions { + let name = extension.name().to_string(); + match agent.add_extension(extension).await { + Ok(_) => info!(extension = %name, "extension loaded"), + Err(e) => warn!(extension = %name, error = %e, "extension load failed"), 
+ } + } +} impl GooseAcpAgent { + pub fn permission_manager(&self) -> Arc { + Arc::clone(&self.agent.config.permission_manager) + } + pub async fn new(builtins: Vec) -> Result { let config = Config::global(); @@ -304,7 +319,7 @@ impl GooseAcpAgent { .get_goose_mode() .unwrap_or(goose::config::GooseMode::Auto); - Self::with_config(GooseAcpConfig { + Self::with_config(AcpServerConfig { provider, builtins, data_dir: Paths::data_dir(), @@ -314,8 +329,9 @@ impl GooseAcpAgent { .await } - pub async fn with_config(config: GooseAcpConfig) -> Result { + pub async fn with_config(config: AcpServerConfig) -> Result { let session_manager = Arc::new(SessionManager::new(config.data_dir)); + let config_dir = config.config_dir.clone(); let permission_manager = Arc::new(PermissionManager::new(config.config_dir)); let agent = Agent::with_config(AgentConfig::new( @@ -327,7 +343,12 @@ impl GooseAcpAgent { let agent_ptr = Arc::new(agent); + let config_path = config_dir.join(CONFIG_YAML_NAME); + let config_file = Config::new(&config_path, "goose")?; + let extensions = get_enabled_extensions_with_config(&config_file); + add_builtins(&agent_ptr, config.builtins).await; + add_extensions(&agent_ptr, extensions).await; Ok(Self { provider: config.provider.clone(), diff --git a/crates/goose-acp/tests/common_tests/mod.rs b/crates/goose-acp/tests/common_tests/mod.rs new file mode 100644 index 000000000000..2917e22d1bc4 --- /dev/null +++ b/crates/goose-acp/tests/common_tests/mod.rs @@ -0,0 +1,224 @@ +// Required when compiled as standalone test "common"; harmless warning when included as module. 
+#![recursion_limit = "256"] +#![allow(unused_attributes)] + +#[path = "../fixtures/mod.rs"] +pub mod fixtures; +use fixtures::{ + ExpectedSessionId, McpFixture, OpenAiFixture, PermissionDecision, Session, TestSessionConfig, + FAKE_CODE, +}; +use fs_err as fs; +use goose::config::base::CONFIG_YAML_NAME; +use goose::config::GooseMode; +use sacp::schema::{McpServer, McpServerHttp, ToolCallStatus}; + +pub async fn run_basic_completion() { + let expected_session_id = ExpectedSessionId::default(); + let openai = OpenAiFixture::new( + vec![( + r#"\nwhat is 1+1""#.into(), + include_str!("../test_data/openai_basic_response.txt"), + )], + expected_session_id.clone(), + ) + .await; + + let mut session = S::new(TestSessionConfig::default(), openai).await; + expected_session_id.set(session.id()); + + let output = session + .prompt("what is 1+1", PermissionDecision::Cancel) + .await; + assert!(output.text.contains("2")); + expected_session_id.assert_matches(&session.id().0); +} + +pub async fn run_mcp_http_server() { + let expected_session_id = ExpectedSessionId::default(); + let mcp = McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + r#"\nUse the get_code tool and output only its result.""#.into(), + include_str!("../test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("../test_data/openai_tool_result_response.txt"), + ), + ], + expected_session_id.clone(), + ) + .await; + + let config = TestSessionConfig { + mcp_servers: vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], + ..Default::default() + }; + let mut session = S::new(config, openai).await; + expected_session_id.set(session.id()); + + let output = session + .prompt( + "Use the get_code tool and output only its result.", + PermissionDecision::Cancel, + ) + .await; + assert!(output.text.contains(FAKE_CODE)); + expected_session_id.assert_matches(&session.id().0); +} + +pub async fn run_builtin_and_mcp() { 
+ let expected_session_id = ExpectedSessionId::default(); + let prompt = + "Search for get_code and text_editor tools. Use them to save the code to /tmp/result.txt."; + let mcp = McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("../test_data/openai_builtin_search.txt"), + ), + ( + r#"lookup/get_code: Get the code"#.into(), + include_str!("../test_data/openai_builtin_read_modules.txt"), + ), + ( + r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), + include_str!("../test_data/openai_builtin_execute.txt"), + ), + ( + r#"Successfully wrote to /tmp/result.txt"#.into(), + include_str!("../test_data/openai_builtin_final.txt"), + ), + ], + expected_session_id.clone(), + ) + .await; + + let config = TestSessionConfig { + builtins: vec!["code_execution".to_string(), "developer".to_string()], + mcp_servers: vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], + ..Default::default() + }; + + let _ = fs::remove_file("/tmp/result.txt"); + + let mut session = S::new(config, openai).await; + expected_session_id.set(session.id()); + + let _ = session.prompt(prompt, PermissionDecision::Cancel).await; + + let result = fs::read_to_string("/tmp/result.txt").unwrap_or_default(); + assert!(result.contains(FAKE_CODE)); + expected_session_id.assert_matches(&session.id().0); +} + +pub async fn run_permission_persistence() { + let cases = vec![ + ( + PermissionDecision::AllowAlways, + ToolCallStatus::Completed, + "user:\n always_allow:\n - lookup__get_code\n ask_before: []\n never_allow: []\n", + ), + (PermissionDecision::AllowOnce, ToolCallStatus::Completed, ""), + ( + PermissionDecision::RejectAlways, + ToolCallStatus::Failed, + "user:\n always_allow: []\n ask_before: []\n never_allow:\n - lookup__get_code\n", + ), + (PermissionDecision::RejectOnce, ToolCallStatus::Failed, ""), + (PermissionDecision::Cancel, ToolCallStatus::Failed, ""), + ]; + + let temp_dir = 
tempfile::tempdir().unwrap(); + let prompt = "Use the get_code tool and output only its result."; + let expected_session_id = ExpectedSessionId::default(); + let mcp = McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + prompt.to_string(), + include_str!("../test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("../test_data/openai_tool_result_response.txt"), + ), + ], + expected_session_id.clone(), + ) + .await; + + let config = TestSessionConfig { + mcp_servers: vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], + goose_mode: GooseMode::Approve, + data_root: temp_dir.path().to_path_buf(), + ..Default::default() + }; + + let mut session = S::new(config, openai).await; + expected_session_id.set(session.id()); + + for (decision, expected_status, expected_yaml) in cases { + session.reset_openai(); + session.reset_permissions(); + let _ = fs::remove_file(temp_dir.path().join("permission.yaml")); + let output = session.prompt(prompt, decision).await; + + assert_eq!( + output.tool_status.unwrap(), + expected_status, + "permission decision {:?}", + decision + ); + assert_eq!( + fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), + expected_yaml, + "permission decision {:?}", + decision + ); + } + expected_session_id.assert_matches(&session.id().0); +} + +pub async fn run_configured_extension() { + let temp_dir = tempfile::tempdir().unwrap(); + let expected_session_id = ExpectedSessionId::default(); + let prompt = "Use the get_code tool and output only its result."; + let mcp = McpFixture::new(expected_session_id.clone()).await; + + let config_yaml = format!( + "extensions:\n lookup:\n enabled: true\n type: streamable_http\n name: lookup\n description: Lookup server\n uri: \"{}\"\n", + mcp.url + ); + fs::write(temp_dir.path().join(CONFIG_YAML_NAME), config_yaml).unwrap(); + + let openai = OpenAiFixture::new( + vec![ + ( + 
prompt.to_string(), + include_str!("../test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("../test_data/openai_tool_result_response.txt"), + ), + ], + expected_session_id.clone(), + ) + .await; + + let config = TestSessionConfig { + data_root: temp_dir.path().to_path_buf(), + ..Default::default() + }; + + let mut session = S::new(config, openai).await; + expected_session_id.set(session.id()); + + let output = session.prompt(prompt, PermissionDecision::Cancel).await; + assert!(output.text.contains(FAKE_CODE)); + expected_session_id.assert_matches(&session.id().0); +} diff --git a/crates/goose-acp/tests/common.rs b/crates/goose-acp/tests/fixtures/mod.rs similarity index 61% rename from crates/goose-acp/tests/common.rs rename to crates/goose-acp/tests/fixtures/mod.rs index 620d52199547..f79cbf8d16fa 100644 --- a/crates/goose-acp/tests/common.rs +++ b/crates/goose-acp/tests/fixtures/mod.rs @@ -1,5 +1,12 @@ use assert_json_diff::{assert_json_matches_no_panic, CompareMode, Config}; +use async_trait::async_trait; +use fs_err as fs; +use goose::config::{GooseMode, PermissionManager}; +use goose::model::ModelConfig; +use goose::providers::api_client::{ApiClient, AuthMethod}; +use goose::providers::openai::OpenAiProvider; use goose::session_context::SESSION_ID_HEADER; +use goose_acp::server::{serve, AcpServerConfig, GooseAcpAgent}; use rmcp::model::{ClientNotification, ClientRequest, Meta, ServerResult}; use rmcp::service::{NotificationContext, RequestContext, ServiceRole}; use rmcp::transport::streamable_http_server::{ @@ -9,9 +16,17 @@ use rmcp::{ handler::server::router::tool::ToolRouter, model::*, tool, tool_handler, tool_router, ErrorData as McpError, RoleServer, ServerHandler, Service, }; +use sacp::schema::{ + McpServer, PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest, + RequestPermissionResponse, SelectedPermissionOutcome, ToolCallStatus, +}; use std::collections::VecDeque; +use 
std::future::Future; +use std::path::Path; +use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tokio::task::JoinHandle; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -19,6 +34,49 @@ pub const FAKE_CODE: &str = "test-uuid-12345-67890"; const NOT_YET_SET: &str = "session-id-not-yet-set"; +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum PermissionDecision { + AllowAlways, + AllowOnce, + RejectOnce, + RejectAlways, + Cancel, +} + +#[derive(Default)] +pub struct PermissionMapping; + +pub fn map_permission_response( + _mapping: &PermissionMapping, + req: &RequestPermissionRequest, + decision: PermissionDecision, +) -> RequestPermissionResponse { + let outcome = match decision { + PermissionDecision::Cancel => RequestPermissionOutcome::Cancelled, + PermissionDecision::AllowAlways => select_option(req, PermissionOptionKind::AllowAlways), + PermissionDecision::AllowOnce => select_option(req, PermissionOptionKind::AllowOnce), + PermissionDecision::RejectOnce => select_option(req, PermissionOptionKind::RejectOnce), + PermissionDecision::RejectAlways => select_option(req, PermissionOptionKind::RejectAlways), + }; + + RequestPermissionResponse::new(outcome) +} + +fn select_option( + req: &RequestPermissionRequest, + kind: PermissionOptionKind, +) -> RequestPermissionOutcome { + req.options + .iter() + .find(|opt| opt.kind == kind) + .map(|opt| { + RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new( + opt.option_id.clone(), + )) + }) + .unwrap_or(RequestPermissionOutcome::Cancelled) +} + #[derive(Clone)] pub struct ExpectedSessionId { value: Arc>, @@ -60,14 +118,19 @@ impl ExpectedSessionId { /// Calling this ensures incidental requests that might error asynchronously, such as /// session rename have coherent session IDs. 
- pub fn assert_no_errors(&self) { + pub fn assert_matches(&self, actual: &str) { + let result = self.validate(Some(actual)); + assert!(result.is_ok(), "{}", result.unwrap_err()); let e = self.errors.lock().unwrap(); assert!(e.is_empty(), "Session ID validation errors: {:?}", *e); } } pub struct OpenAiFixture { - pub server: MockServer, + _server: MockServer, + base_url: String, + exchanges: Vec<(String, &'static str)>, + queue: Arc>>, } impl OpenAiFixture { @@ -78,8 +141,7 @@ impl OpenAiFixture { expected_session_id: ExpectedSessionId, ) -> Self { let mock_server = MockServer::start().await; - let queue: VecDeque<(String, &'static str)> = exchanges.into_iter().collect(); - let queue = Arc::new(Mutex::new(queue)); + let queue = Arc::new(Mutex::new(VecDeque::from(exchanges.clone()))); Mock::given(method("POST")) .and(path("/v1/chat/completions")) @@ -104,7 +166,7 @@ impl OpenAiFixture { return ResponseTemplate::new(200) .insert_header("content-type", "application/json") .set_body_string(include_str!( - "./test_data/openai_session_description.json" + "../test_data/openai_session_description.json" )); } @@ -135,14 +197,27 @@ impl OpenAiFixture { .mount(&mock_server) .await; + let base_url = mock_server.uri(); Self { - server: mock_server, + _server: mock_server, + base_url, + exchanges, + queue, } } + + pub fn uri(&self) -> &str { + &self.base_url + } + + pub fn reset(&self) { + let mut queue = self.queue.lock().unwrap(); + *queue = VecDeque::from(self.exchanges.clone()); + } } #[derive(Clone)] -pub struct Lookup { +struct Lookup { tool_router: ToolRouter, } @@ -198,13 +273,13 @@ impl HasMeta for NotificationContext { } } -pub struct ValidatingService { +struct ValidatingService { inner: S, expected_session_id: ExpectedSessionId, } impl ValidatingService { - pub fn new(inner: S, expected_session_id: ExpectedSessionId) -> Self { + fn new(inner: S, expected_session_id: ExpectedSessionId) -> Self { Self { inner, expected_session_id, @@ -287,3 +362,103 @@ impl McpFixture 
{ } } } + +#[allow(dead_code)] +pub async fn spawn_acp_server_in_process( + openai_base_url: &str, + builtins: &[String], + data_root: &Path, + goose_mode: GooseMode, +) -> ( + tokio::io::DuplexStream, + tokio::io::DuplexStream, + JoinHandle<()>, + Arc, +) { + fs::create_dir_all(data_root).unwrap(); + let api_client = ApiClient::new( + openai_base_url.to_string(), + AuthMethod::BearerToken("test-key".to_string()), + ) + .unwrap(); + let model_config = ModelConfig::new("gpt-5-nano").unwrap(); + let provider = OpenAiProvider::new(api_client, model_config); + + let config = AcpServerConfig { + provider: Arc::new(provider), + builtins: builtins.to_vec(), + data_dir: data_root.to_path_buf(), + config_dir: data_root.to_path_buf(), + goose_mode, + }; + + let (client_read, server_write) = tokio::io::duplex(64 * 1024); + let (server_read, client_write) = tokio::io::duplex(64 * 1024); + + let agent = Arc::new(GooseAcpAgent::with_config(config).await.unwrap()); + let permission_manager = agent.permission_manager(); + let handle = tokio::spawn(async move { + if let Err(e) = serve(agent, server_read.compat(), server_write.compat_write()).await { + tracing::error!("ACP server error: {e}"); + } + }); + + (client_read, client_write, handle, permission_manager) +} + +pub struct TestOutput { + pub text: String, + pub tool_status: Option, +} + +pub struct TestSessionConfig { + pub mcp_servers: Vec, + pub builtins: Vec, + pub goose_mode: GooseMode, + pub data_root: PathBuf, +} + +impl Default for TestSessionConfig { + fn default() -> Self { + Self { + mcp_servers: Vec::new(), + builtins: Vec::new(), + goose_mode: GooseMode::Auto, + data_root: PathBuf::new(), + } + } +} + +#[async_trait] +pub trait Session { + async fn new(config: TestSessionConfig, openai: OpenAiFixture) -> Self + where + Self: Sized; + fn id(&self) -> &sacp::schema::SessionId; + fn reset_openai(&self); + fn reset_permissions(&self); + async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput; 
+} + +#[allow(dead_code)] +pub fn run_test(fut: F) +where + F: Future + Send + 'static, +{ + let handle = std::thread::Builder::new() + .name("acp-test".to_string()) + .stack_size(8 * 1024 * 1024) + .spawn(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .thread_stack_size(8 * 1024 * 1024) + .enable_all() + .build() + .unwrap(); + runtime.block_on(fut); + }) + .unwrap(); + handle.join().unwrap(); +} + +pub mod server; diff --git a/crates/goose-acp/tests/fixtures/server.rs b/crates/goose-acp/tests/fixtures/server.rs new file mode 100644 index 000000000000..bc81f2b2662c --- /dev/null +++ b/crates/goose-acp/tests/fixtures/server.rs @@ -0,0 +1,226 @@ +use super::{ + map_permission_response, spawn_acp_server_in_process, PermissionDecision, PermissionMapping, + Session, TestOutput, TestSessionConfig, +}; +use async_trait::async_trait; +use goose::config::PermissionManager; +use sacp::schema::{ + ContentBlock, InitializeRequest, NewSessionRequest, PromptRequest, ProtocolVersion, + RequestPermissionRequest, SessionNotification, SessionUpdate, StopReason, TextContent, + ToolCallStatus, +}; +use sacp::{ClientToAgent, JrConnectionCx}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::sync::Notify; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +pub struct ClientToAgentSession { + cx: JrConnectionCx, + session_id: sacp::schema::SessionId, + updates: Arc>>, + permission: Arc>, + notify: Arc, + permission_manager: Arc, + // Keep the OpenAI mock server alive for the lifetime of the session. + _openai: super::OpenAiFixture, + // Keep the temp dir alive so test data/permissions persist during the session. 
+ _temp_dir: Option, +} + +#[async_trait] +impl Session for ClientToAgentSession { + async fn new(config: TestSessionConfig, openai: super::OpenAiFixture) -> Self { + let (data_root, temp_dir) = match config.data_root.as_os_str().is_empty() { + true => { + let temp_dir = tempfile::tempdir().unwrap(); + (temp_dir.path().to_path_buf(), Some(temp_dir)) + } + false => (config.data_root.clone(), None), + }; + + let (client_read, client_write, _handle, permission_manager) = spawn_acp_server_in_process( + openai.uri(), + &config.builtins, + data_root.as_path(), + config.goose_mode, + ) + .await; + + let updates = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + let permission = Arc::new(Mutex::new(PermissionDecision::Cancel)); + + let transport = sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); + + let (cx, session_id) = { + let updates_clone = updates.clone(); + let notify_clone = notify.clone(); + let permission_clone = permission.clone(); + let mcp_servers_clone = config.mcp_servers.clone(); + + let cx_holder: Arc>>> = + Arc::new(Mutex::new(None)); + let session_id_holder: Arc>> = + Arc::new(Mutex::new(None)); + + let cx_holder_clone = cx_holder.clone(); + let session_id_holder_clone = session_id_holder.clone(); + + let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + let permission_mapping = PermissionMapping; + + let result = ClientToAgent::builder() + .on_receive_notification( + { + let updates = updates_clone.clone(); + let notify = notify_clone.clone(); + async move |notification: SessionNotification, _cx| { + updates.lock().unwrap().push(notification); + notify.notify_waiters(); + Ok(()) + } + }, + sacp::on_receive_notification!(), + ) + .on_receive_request( + { + let permission = permission_clone.clone(); + async move |req: RequestPermissionRequest, + request_cx, + _connection_cx| { + let decision = *permission.lock().unwrap(); + let response = + 
map_permission_response(&permission_mapping, &req, decision); + request_cx.respond(response) + } + }, + sacp::on_receive_request!(), + ) + .connect_to(transport) + .unwrap() + .run_until({ + let mcp_servers = mcp_servers_clone; + let cx_holder = cx_holder_clone; + let session_id_holder = session_id_holder_clone; + move |cx: JrConnectionCx| async move { + cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) + .block_task() + .await + .unwrap(); + + let work_dir = tempfile::tempdir().unwrap(); + let session = cx + .send_request( + NewSessionRequest::new(work_dir.path()) + .mcp_servers(mcp_servers), + ) + .block_task() + .await + .unwrap(); + + *cx_holder.lock().unwrap() = Some(cx.clone()); + *session_id_holder.lock().unwrap() = Some(session.session_id); + let _ = ready_tx.send(()); + + std::future::pending::>().await + } + }) + .await; + + if let Err(e) = result { + tracing::error!("SACP client error: {e}"); + } + }); + + ready_rx.await.unwrap(); + + let cx = cx_holder.lock().unwrap().take().unwrap(); + let session_id = session_id_holder.lock().unwrap().take().unwrap(); + (cx, session_id) + }; + + Self { + cx, + session_id, + updates, + permission, + notify, + permission_manager, + _openai: openai, + _temp_dir: temp_dir, + } + } + + fn id(&self) -> &sacp::schema::SessionId { + &self.session_id + } + + fn reset_openai(&self) { + self._openai.reset(); + } + + fn reset_permissions(&self) { + self.permission_manager.remove_extension(""); + } + + async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput { + *self.permission.lock().unwrap() = decision; + self.updates.lock().unwrap().clear(); + + let response = self + .cx + .send_request(PromptRequest::new( + self.id().clone(), + vec![ContentBlock::Text(TextContent::new(text))], + )) + .block_task() + .await + .unwrap(); + + assert_eq!(response.stop_reason, StopReason::EndTurn); + + let mut updates_len = self.updates.lock().unwrap().len(); + while updates_len == 0 { + 
+ self.notify.notified().await; + updates_len = self.updates.lock().unwrap().len(); + } + + let text = collect_agent_text(&self.updates); + let deadline = tokio::time::Instant::now() + Duration::from_millis(500); + let mut tool_status = extract_tool_status(&self.updates); + while tool_status.is_none() && tokio::time::Instant::now() < deadline { + tokio::task::yield_now().await; + tool_status = extract_tool_status(&self.updates); + } + + TestOutput { text, tool_status } + } +} + +fn collect_agent_text(updates: &Arc>>) -> String { + let guard = updates.lock().unwrap(); + let mut text = String::new(); + + for notification in guard.iter() { + if let SessionUpdate::AgentMessageChunk(chunk) = &notification.update { + if let ContentBlock::Text(t) = &chunk.content { + text.push_str(&t.text); + } + } + } + + text +} + +fn extract_tool_status(updates: &Arc>>) -> Option { + let guard = updates.lock().unwrap(); + guard.iter().find_map(|notification| { + if let SessionUpdate::ToolCallUpdate(update) = &notification.update { + return update.fields.status; + } + None + }) +} diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs index e04e01f7726d..c9241b2a90eb 100644 --- a/crates/goose-acp/tests/server_test.rs +++ b/crates/goose-acp/tests/server_test.rs @@ -1,429 +1,32 @@ -mod common; - -use common::{ExpectedSessionId, McpFixture, OpenAiFixture, FAKE_CODE}; -use fs_err as fs; -use goose::config::GooseMode; -use goose::model::ModelConfig; -use goose::providers::api_client::{ApiClient, AuthMethod}; -use goose::providers::openai::OpenAiProvider; -use goose_acp::server::{serve, GooseAcpAgent, GooseAcpConfig}; -use sacp::schema::{ - ContentBlock, ContentChunk, InitializeRequest, McpServer, McpServerHttp, NewSessionRequest, - PermissionOptionKind, PromptRequest, ProtocolVersion, RequestPermissionOutcome, - RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, - SessionNotification, SessionUpdate, StopReason, TextContent, ToolCallId, 
ToolCallStatus, - ToolCallUpdate, ToolCallUpdateFields, +mod common_tests; +use common_tests::fixtures::run_test; +use common_tests::fixtures::server::ClientToAgentSession; +use common_tests::{ + run_basic_completion, run_builtin_and_mcp, run_configured_extension, run_mcp_http_server, + run_permission_persistence, }; -use sacp::{ClientToAgent, JrConnectionCx}; -use std::path::Path; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use test_case::test_case; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use wiremock::MockServer; - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_acp_basic_completion() { - let temp_dir = tempfile::tempdir().unwrap(); - let prompt = "what is 1+1"; - let expected_session_id = ExpectedSessionId::default(); - let openai = OpenAiFixture::new( - vec![( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_basic_response.txt"), - )], - expected_session_id.clone(), - ) - .await; - - run_acp_session( - &openai.server, - vec![], - &[], - temp_dir.path(), - GooseMode::Auto, - None, - expected_session_id.clone(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for( - &updates, - &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new("2"), - ))), - ) - .await; - }, - ) - .await; - expected_session_id.assert_no_errors(); +#[test] +fn test_acp_basic_completion() { + run_test(async { run_basic_completion::().await }); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_acp_with_mcp_http_server() { - let temp_dir = tempfile::tempdir().unwrap(); - let prompt = "Use the get_code tool and output only its result."; - let expected_session_id = ExpectedSessionId::default(); - let mcp = 
McpFixture::new(expected_session_id.clone()).await; - let openai = OpenAiFixture::new( - vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_tool_call_response.txt"), - ), - ( - format!(r#""content":"{FAKE_CODE}""#), - include_str!("./test_data/openai_tool_result_response.txt"), - ), - ], - expected_session_id.clone(), - ) - .await; - - run_acp_session( - &openai.server, - vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], - &[], - temp_dir.path(), - GooseMode::Auto, - None, - expected_session_id.clone(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for( - &updates, - &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(FAKE_CODE), - ))), - ) - .await; - }, - ) - .await; - - expected_session_id.assert_no_errors(); +#[test] +fn test_acp_with_mcp_http_server() { + run_test(async { run_mcp_http_server::().await }); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_acp_with_builtin_and_mcp() { - let temp_dir = tempfile::tempdir().unwrap(); - let prompt = - "Search for get_code and text_editor tools. 
Use them to save the code to /tmp/result.txt."; - let expected_session_id = ExpectedSessionId::default(); - let mcp = McpFixture::new(expected_session_id.clone()).await; - let openai = OpenAiFixture::new( - vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_builtin_search.txt"), - ), - ( - r#"lookup/get_code: Get the code"#.into(), - include_str!("./test_data/openai_builtin_read_modules.txt"), - ), - ( - r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), - include_str!("./test_data/openai_builtin_execute.txt"), - ), - ( - r#"Successfully wrote to /tmp/result.txt"#.into(), - include_str!("./test_data/openai_builtin_final.txt"), - ), - ], - expected_session_id.clone(), - ) - .await; - - run_acp_session( - &openai.server, - vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], - &["code_execution", "developer"], - temp_dir.path(), - GooseMode::Auto, - None, - expected_session_id.clone(), - |cx, session_id, updates| async move { - let response = cx - .send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - - assert_eq!(response.stop_reason, StopReason::EndTurn); - wait_for( - &updates, - &SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text( - TextContent::new(FAKE_CODE), - ))), - ) - .await; - }, - ) - .await; - - expected_session_id.assert_no_errors(); +#[test] +fn test_acp_with_builtin_and_mcp() { + run_test(async { run_builtin_and_mcp::().await }); } -async fn wait_for(updates: &Arc>>, expected: &SessionUpdate) { - let deadline = tokio::time::Instant::now() + Duration::from_millis(500); - let mut context = String::new(); - - loop { - let matched = { - let guard = updates.lock().unwrap(); - context.clear(); - - match expected { - SessionUpdate::AgentMessageChunk(chunk) => { - let expected_text = match &chunk.content { - ContentBlock::Text(t) => &t.text, - other => panic!("wait_for: unhandled content {:?}", other), - }; 
- for n in guard.iter() { - if let SessionUpdate::AgentMessageChunk(c) = &n.update { - if let ContentBlock::Text(t) = &c.content { - if t.text.is_empty() { - context.clear(); - } else { - context.push_str(&t.text); - } - } - } - } - context.contains(expected_text) - } - SessionUpdate::ToolCallUpdate(expected_update) => { - for n in guard.iter() { - if let SessionUpdate::ToolCallUpdate(u) = &n.update { - context.push_str(&format!("{:?}\n", u)); - if u.fields.status == expected_update.fields.status { - return; - } - } - } - false - } - other => panic!("wait_for: unhandled update {:?}", other), - } - }; - - if matched { - return; - } - if tokio::time::Instant::now() > deadline { - panic!("Timeout waiting for {:?}\n\n{}", expected, context); - } - tokio::task::yield_now().await; - } +#[test] +fn test_permission_persistence() { + run_test(async { run_permission_persistence::().await }); } -async fn spawn_server_in_process( - mock_server: &MockServer, - builtins: &[&str], - data_root: &Path, - goose_mode: GooseMode, -) -> ( - tokio::io::DuplexStream, - tokio::io::DuplexStream, - tokio::task::JoinHandle<()>, -) { - let api_client = ApiClient::new( - mock_server.uri(), - AuthMethod::BearerToken("test-key".to_string()), - ) - .unwrap(); - let model_config = ModelConfig::new("gpt-5-nano").unwrap(); - let provider = OpenAiProvider::new(api_client, model_config); - - let config = GooseAcpConfig { - provider: Arc::new(provider), - builtins: builtins.iter().map(|s| s.to_string()).collect(), - data_dir: data_root.to_path_buf(), - config_dir: data_root.to_path_buf(), - goose_mode, - }; - - let (client_read, server_write) = tokio::io::duplex(64 * 1024); - let (server_read, client_write) = tokio::io::duplex(64 * 1024); - - let agent = Arc::new(GooseAcpAgent::with_config(config).await.unwrap()); - let handle = tokio::spawn(async move { - if let Err(e) = serve(agent, server_read.compat(), server_write.compat_write()).await { - tracing::error!("ACP server error: {e}"); - } - }); - - 
(client_read, client_write, handle) -} - -#[allow(clippy::too_many_arguments)] -async fn run_acp_session( - mock_server: &MockServer, - mcp_servers: Vec, - builtins: &[&str], - data_root: &Path, - mode: GooseMode, - select: Option, - expected_session_id: ExpectedSessionId, - test_fn: F, -) where - F: FnOnce( - JrConnectionCx, - sacp::schema::SessionId, - Arc>>, - ) -> Fut, - Fut: std::future::Future, -{ - let (client_read, client_write, _handle) = - spawn_server_in_process(mock_server, builtins, data_root, mode).await; - let work_dir = tempfile::tempdir().unwrap(); - let updates = Arc::new(Mutex::new(Vec::new())); - - let transport = sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); - - ClientToAgent::builder() - .on_receive_notification( - { - let updates = updates.clone(); - async move |notification: SessionNotification, _cx| { - updates.lock().unwrap().push(notification); - Ok(()) - } - }, - sacp::on_receive_notification!(), - ) - .on_receive_request( - async move |req: RequestPermissionRequest, request_cx, _connection_cx| { - let response = match select { - Some(kind) => { - let id = req - .options - .iter() - .find(|o| o.kind == kind) - .unwrap() - .option_id - .clone(); - RequestPermissionResponse::new(RequestPermissionOutcome::Selected( - SelectedPermissionOutcome::new(id), - )) - } - None => RequestPermissionResponse::new(RequestPermissionOutcome::Cancelled), - }; - request_cx.respond(response) - }, - sacp::on_receive_request!(), - ) - .connect_to(transport) - .unwrap() - .run_until({ - let updates = updates.clone(); - let expected_session_id = expected_session_id.clone(); - move |cx: JrConnectionCx| async move { - cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) - .block_task() - .await - .unwrap(); - - let session = cx - .send_request(NewSessionRequest::new(work_dir.path()).mcp_servers(mcp_servers)) - .block_task() - .await - .unwrap(); - - expected_session_id.set(&session.session_id); - - test_fn(cx.clone(), 
session.session_id, updates).await; - Ok(()) - } - }) - .await - .unwrap(); -} - -#[test_case(Some(PermissionOptionKind::AllowAlways), ToolCallStatus::Completed, "user:\n always_allow:\n - lookup__get_code\n ask_before: []\n never_allow: []\n"; "allow_always")] -#[test_case(Some(PermissionOptionKind::AllowOnce), ToolCallStatus::Completed, ""; "allow_once")] -#[test_case(Some(PermissionOptionKind::RejectAlways), ToolCallStatus::Failed, "user:\n always_allow: []\n ask_before: []\n never_allow:\n - lookup__get_code\n"; "reject_always")] -#[test_case(Some(PermissionOptionKind::RejectOnce), ToolCallStatus::Failed, ""; "reject_once")] -#[test_case(None, ToolCallStatus::Failed, ""; "cancelled")] -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_permission_persistence( - kind: Option, - expected_status: ToolCallStatus, - expected_yaml: &str, -) { - let temp_dir = tempfile::tempdir().unwrap(); - let prompt = "Use the get_code tool and output only its result."; - let expected_session_id = ExpectedSessionId::default(); - let mcp = McpFixture::new(expected_session_id.clone()).await; - let openai = OpenAiFixture::new( - vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_tool_call_response.txt"), - ), - ( - format!(r#""content":"{FAKE_CODE}""#), - include_str!("./test_data/openai_tool_result_response.txt"), - ), - ], - expected_session_id.clone(), - ) - .await; - - run_acp_session( - &openai.server, - vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], - &[], - temp_dir.path(), - GooseMode::Approve, - kind, - expected_session_id.clone(), - |cx, session_id, updates| async move { - cx.send_request(PromptRequest::new( - session_id, - vec![ContentBlock::Text(TextContent::new(prompt))], - )) - .block_task() - .await - .unwrap(); - wait_for( - &updates, - &SessionUpdate::ToolCallUpdate(ToolCallUpdate::new( - ToolCallId::new(""), - ToolCallUpdateFields::new().status(Some(expected_status)), - )), - ) - .await; - }, - ) - 
.await; - - expected_session_id.assert_no_errors(); - - assert_eq!( - fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), - expected_yaml - ); +#[test] +fn test_configured_extension() { + run_test(async { run_configured_extension::().await }); } diff --git a/crates/goose-mcp/.gooseignore b/crates/goose-mcp/.gooseignore new file mode 100644 index 000000000000..550684df254e --- /dev/null +++ b/crates/goose-mcp/.gooseignore @@ -0,0 +1,18 @@ +# This file is created automatically if no .gooseignore exists. +# Customize or uncomment the patterns below instead of deleting the file. +# Removing it will simply cause goose to recreate it on the next start. +# +# Suggested patterns you can uncomment if desired: +# **/.ssh/** # block SSH keys and configs +# **/*.key # block loose private keys +# **/*.pem # block certificates/private keys +# **/.git/** # block git metadata entirely +# **/target/** # block Rust build artifacts +# **/node_modules/** # block JS/TS dependencies +# **/*.db # block local database files +# **/*.sqlite # block SQLite databases +# + +**/.env +**/.env.* +**/secrets.* diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index c3f4b6cd3e3d..c71a7e70874c 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -27,8 +27,8 @@ pub fn name_to_key(name: &str) -> String { .to_lowercase() } -fn get_extensions_map() -> IndexMap { - let raw: Mapping = Config::global() +fn get_extensions_map_with_config(config: &Config) -> IndexMap { + let raw: Mapping = config .get_param(EXTENSIONS_CONFIG_KEY) .unwrap_or_else(|err| { warn!( @@ -75,6 +75,10 @@ fn get_extensions_map() -> IndexMap { extensions_map } +fn get_extensions_map() -> IndexMap { + get_extensions_map_with_config(Config::global()) +} + fn save_extensions_map(extensions: IndexMap) { let config = Config::global(); if let Err(e) = config.set_param(EXTENSIONS_CONFIG_KEY, &extensions) { @@ -135,6 +139,14 @@ 
pub fn get_enabled_extensions() -> Vec { .collect() } +pub fn get_enabled_extensions_with_config(config: &Config) -> Vec { + get_extensions_map_with_config(config) + .into_values() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect() +} + pub fn get_warnings() -> Vec { let raw: Mapping = Config::global() .get_param(EXTENSIONS_CONFIG_KEY)