diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 59bdbc23bf7c..43715089f832 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -39,7 +39,10 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; use url::Url; +// Agent binds provider, extensions, and permission channels to a single session. +// ACP has no session/close, so sessions accumulate until transport closes. struct GooseAcpSession { + agent: Arc, messages: Conversation, tool_requests: HashMap, cancel_token: Option, @@ -47,10 +50,13 @@ struct GooseAcpSession { pub struct GooseAcpAgent { sessions: Arc>>, - agent: Arc, provider_factory: ProviderConstructor, config_dir: std::path::PathBuf, - provider_initialized: tokio::sync::OnceCell>, + session_manager: Arc, + permission_manager: Arc, + goose_mode: goose::config::GooseMode, + disable_session_naming: bool, + builtins: Vec, } fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result { @@ -286,7 +292,7 @@ async fn build_model_state( impl GooseAcpAgent { pub fn permission_manager(&self) -> Arc { - Arc::clone(&self.agent.config.permission_manager) + Arc::clone(&self.permission_manager) } pub async fn new( @@ -300,60 +306,36 @@ impl GooseAcpAgent { let session_manager = Arc::new(SessionManager::new(data_dir)); let permission_manager = Arc::new(PermissionManager::new(config_dir.clone())); - let agent = Agent::with_config(AgentConfig::new( - Arc::clone(&session_manager), - permission_manager, - None, - goose_mode, - disable_session_naming, - )); - - let agent_ptr = Arc::new(agent); - - let config_path = config_dir.join(CONFIG_YAML_NAME); - let config_file = Config::new(&config_path, "goose")?; - let extensions = get_enabled_extensions_with_config(&config_file); - - add_builtins(&agent_ptr, builtins).await; - add_extensions(&agent_ptr, extensions).await; - Ok(Self { sessions: Arc::new(Mutex::new(HashMap::new())), - agent: agent_ptr, provider_factory, config_dir, - provider_initialized: tokio::sync::OnceCell::new(), + session_manager, + permission_manager, + goose_mode, + disable_session_naming, + builtins, }) } - pub async fn create_session(&self) -> Result { - let manager = self.agent.config.session_manager.clone(); - let goose_session = manager - .create_session( - std::env::current_dir().unwrap_or_default(), - "ACP Session".to_string(), - SessionType::User, - ) - .await?; - - self.ensure_provider(&goose_session).await?; - - let session = GooseAcpSession { - messages: Conversation::new_unvalidated(Vec::new()), - tool_requests: HashMap::new(), - cancel_token: None, - }; - - let mut sessions = self.sessions.lock().await; - sessions.insert(goose_session.id.clone(), session); + async fn create_agent_for_session(&self) -> Arc { + let agent = Agent::with_config(AgentConfig::new( + Arc::clone(&self.session_manager), + Arc::clone(&self.permission_manager), + None, + self.goose_mode, + self.disable_session_naming, + )); + let agent = Arc::new(agent); - info!( - session_id = %goose_session.id, - session_type = "acp", - "Session created" - ); + let config_path = self.config_dir.join(CONFIG_YAML_NAME); + if let Ok(config_file) = Config::new(&config_path, "goose") { + let extensions = get_enabled_extensions_with_config(&config_file); + add_extensions(&agent, extensions).await; + } + add_builtins(&agent, self.builtins.clone()).await; - Ok(goose_session.id) + agent } pub async fn has_session(&self, session_id: &str) -> bool { @@ -433,12 +415,13 @@ impl GooseAcpAgent { } = &action_required.data { self.handle_tool_permission_request( + cx, + &session.agent, + session_id, id.clone(), tool_name.clone(), arguments.clone(), prompt.clone(), - session_id, - cx, )?; } } @@ -513,17 +496,19 @@ impl GooseAcpAgent { Ok(()) } + #[allow(clippy::too_many_arguments)] fn handle_tool_permission_request( &self, + cx: &JrConnectionCx, + agent: &Arc, + session_id: &SessionId, request_id: String, tool_name: String, arguments: serde_json::Map, prompt: Option, - session_id: &SessionId, - cx: &JrConnectionCx, ) -> Result<(), sacp::Error> { let cx = cx.clone(); - let agent = self.agent.clone(); + let agent = agent.clone(); let session_id = session_id.clone(); let formatted_name = format_tool_name(&tool_name); @@ -689,8 +674,8 @@ impl GooseAcpAgent { ) -> Result { debug!(?args, "new session request"); - let manager = self.agent.config.session_manager.clone(); - let goose_session = manager + let goose_session = self + .session_manager .create_session( args.cwd.clone(), "ACP Session".to_string(), @@ -700,9 +685,14 @@ impl GooseAcpAgent { .map_err(|e| { sacp::Error::internal_error().data(format!("Failed to create session: {}", e)) })?; - let provider = self.ensure_provider(&goose_session).await.map_err(|e| { - sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) - })?; + + let agent = self.create_agent_for_session().await; + let provider = self + .init_provider(&agent, &goose_session) + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) + })?; for mcp_server in args.mcp_servers { let config = match mcp_server_to_extension_config(mcp_server) { @@ -712,13 +702,14 @@ impl GooseAcpAgent { } }; let name = config.name().to_string(); - if let Err(e) = self.agent.add_extension(config, &goose_session.id).await { + if let Err(e) = agent.add_extension(config, &goose_session.id).await { return Err(sacp::Error::internal_error() .data(format!("Failed to add MCP server '{}': {}", name, e))); } } let session = GooseAcpSession { + agent, messages: Conversation::new_unvalidated(Vec::new()), tool_requests: HashMap::new(), cancel_token: None, @@ -734,29 +725,26 @@ impl GooseAcpAgent { ); let model_state = - build_model_state(&**provider, &provider.get_model_config().model_name).await?; + build_model_state(&*provider, &provider.get_model_config().model_name).await?; Ok(NewSessionResponse::new(SessionId::new(goose_session.id)).models(model_state)) } - async fn create_provider(&self, session: &Session) -> Result> { - let config_path = self.config_dir.join(CONFIG_YAML_NAME); - let config = Config::new(&config_path, "goose")?; - let model_id = config.get_goose_model()?; - let model_config = goose::model::ModelConfig::new(&model_id)?; + async fn init_provider(&self, agent: &Agent, session: &Session) -> Result> { + let model_config = match &session.model_config { + Some(config) => config.clone(), + None => { + let config_path = self.config_dir.join(CONFIG_YAML_NAME); + let config = Config::new(&config_path, "goose")?; + let model_id = config.get_goose_model()?; + goose::model::ModelConfig::new(&model_id)? + } + }; let provider = (self.provider_factory)(model_config).await?; - self.agent - .update_provider(provider.clone(), &session.id) - .await?; + agent.update_provider(provider.clone(), &session.id).await?; Ok(provider) } - async fn ensure_provider(&self, session: &Session) -> Result<&Arc> { - self.provider_initialized - .get_or_try_init(|| self.create_provider(session)) - .await - } - async fn on_load_session( &self, args: LoadSessionRequest, @@ -766,21 +754,29 @@ impl GooseAcpAgent { let session_id = args.session_id.0.to_string(); - let manager = self.agent.config.session_manager.clone(); - let goose_session = manager.get_session(&session_id, true).await.map_err(|e| { - sacp::Error::invalid_params() - .data(format!("Failed to load session {}: {}", session_id, e)) - })?; - let provider = self.ensure_provider(&goose_session).await.map_err(|e| { - sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) - })?; + let goose_session = self + .session_manager + .get_session(&session_id, true) + .await + .map_err(|e| { + sacp::Error::invalid_params() + .data(format!("Failed to load session {}: {}", session_id, e)) + })?; + + let agent = self.create_agent_for_session().await; + let provider = self + .init_provider(&agent, &goose_session) + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to set provider: {}", e)) + })?; let conversation = goose_session.conversation.ok_or_else(|| { sacp::Error::internal_error() .data(format!("Session {} has no conversation data", session_id)) })?; - manager + self.session_manager .update(&session_id) .working_dir(args.cwd.clone()) .apply() @@ -791,6 +787,7 @@ impl GooseAcpAgent { })?; let mut session = GooseAcpSession { + agent, messages: conversation.clone(), tool_requests: HashMap::new(), cancel_token: None, @@ -852,7 +849,7 @@ impl GooseAcpAgent { ); let model_state = - build_model_state(&**provider, &provider.get_model_config().model_name).await?; + build_model_state(&*provider, &provider.get_model_config().model_name).await?; Ok(LoadSessionResponse::new().models(model_state)) } @@ -865,13 +862,14 @@ impl GooseAcpAgent { let session_id = args.session_id.0.to_string(); let cancel_token = CancellationToken::new(); - { + let agent = { let mut sessions = self.sessions.lock().await; let session = sessions.get_mut(&session_id).ok_or_else(|| { sacp::Error::invalid_params().data(format!("Session not found: {}", session_id)) })?; session.cancel_token = Some(cancel_token.clone()); - } + session.agent.clone() + }; let user_message = self.convert_acp_prompt_to_message(args.prompt); @@ -882,8 +880,7 @@ impl GooseAcpAgent { retry_config: None, }; - let mut stream = self - .agent + let mut stream = agent .reply(user_message, session_config, Some(cancel_token.clone())) .await .map_err(|e| { @@ -959,12 +956,20 @@ impl GooseAcpAgent { model_id: &str, ) -> Result { let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| { - sacp::Error::internal_error().data(format!("Invalid model config: {}", e)) + sacp::Error::invalid_params().data(format!("Invalid model config: {}", e)) })?; let provider = (self.provider_factory)(model_config).await.map_err(|e| { sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) })?; - self.agent + + let agent = { + let sessions = self.sessions.lock().await; + let session = sessions.get(session_id).ok_or_else(|| { + sacp::Error::invalid_params().data(format!("Session not found: {}", session_id)) + })?; + session.agent.clone() + }; + agent .update_provider(provider, session_id) .await .map_err(|e| { diff --git a/crates/goose-acp/src/transport/http.rs b/crates/goose-acp/src/transport/http.rs index 07555a3fdbc1..0c1e7f28cca0 100644 --- a/crates/goose-acp/src/transport/http.rs +++ b/crates/goose-acp/src/transport/http.rs @@ -18,6 +18,7 @@ use crate::server_factory::AcpServer; pub(crate) struct HttpState { server: Arc, + // Keyed by acp_session_id: a connection-scoped UUID serving many Goose sessions. sessions: RwLock>, } @@ -38,10 +39,7 @@ impl HttpState { StatusCode::INTERNAL_SERVER_ERROR })?; - let session_id = agent.create_session().await.map_err(|e| { - error!("Failed to create ACP session: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; + let acp_session_id = uuid::Uuid::new_v4().to_string(); let handle = tokio::spawn(async move { let read_stream = ReceiverToAsyncRead::new(to_agent_rx); @@ -55,7 +53,7 @@ impl HttpState { }); self.sessions.write().await.insert( - session_id.clone(), + acp_session_id.clone(), TransportSession { to_agent_tx, from_agent_rx: Arc::new(Mutex::new(from_agent_rx)), @@ -63,24 +61,24 @@ impl HttpState { }, ); - info!(session_id = %session_id, "Session created"); - Ok(session_id) + info!(acp_session_id = %acp_session_id, "Session created"); + Ok(acp_session_id) } - async fn has_session(&self, session_id: &str) -> bool { - self.sessions.read().await.contains_key(session_id) + async fn has_session(&self, acp_session_id: &str) -> bool { + self.sessions.read().await.contains_key(acp_session_id) } - async fn remove_session(&self, session_id: &str) { - if let Some(session) = self.sessions.write().await.remove(session_id) { + async fn remove_session(&self, acp_session_id: &str) { + if let Some(session) = self.sessions.write().await.remove(acp_session_id) { session.handle.abort(); - info!(session_id = %session_id, "Session removed"); + info!(acp_session_id = %acp_session_id, "Session removed"); } } - async fn send_message(&self, session_id: &str, message: String) -> Result<(), StatusCode> { + async fn send_message(&self, acp_session_id: &str, message: String) -> Result<(), StatusCode> { let sessions = self.sessions.read().await; - let session = sessions.get(session_id).ok_or(StatusCode::NOT_FOUND)?; + let session = sessions.get(acp_session_id).ok_or(StatusCode::NOT_FOUND)?; session .to_agent_tx .send(message) @@ -90,10 +88,10 @@ impl HttpState { async fn get_receiver( &self, - session_id: &str, + acp_session_id: &str, ) -> Result>>, StatusCode> { let sessions = self.sessions.read().await; - let session = sessions.get(session_id).ok_or(StatusCode::NOT_FOUND)?; + let session = sessions.get(acp_session_id).ok_or(StatusCode::NOT_FOUND)?; Ok(session.from_agent_rx.clone()) } } @@ -107,8 +105,8 @@ fn create_sse_stream( while let Some(msg) = rx.recv().await { yield Ok::<_, Infallible>(axum::response::sse::Event::default().data(msg)); } - if let Some((state, session_id)) = cleanup { - state.remove_session(&session_id).await; + if let Some((state, acp_session_id)) = cleanup { + state.remove_session(&acp_session_id).await; } }; @@ -120,48 +118,48 @@ fn create_sse_stream( } async fn handle_initialize(state: Arc, json_message: &Value) -> Response { - let new_session_id = match state.create_session().await { + let acp_session_id = match state.create_session().await { Ok(id) => id, Err(status) => return status.into_response(), }; let message_str = serde_json::to_string(json_message).unwrap(); - if let Err(status) = state.send_message(&new_session_id, message_str).await { - state.remove_session(&new_session_id).await; + if let Err(status) = state.send_message(&acp_session_id, message_str).await { + state.remove_session(&acp_session_id).await; return status.into_response(); } - let receiver = match state.get_receiver(&new_session_id).await { + let receiver = match state.get_receiver(&acp_session_id).await { Ok(r) => r, Err(status) => { - state.remove_session(&new_session_id).await; + state.remove_session(&acp_session_id).await; return status.into_response(); } }; - let sse = create_sse_stream(receiver, Some((state.clone(), new_session_id.clone()))); + let sse = create_sse_stream(receiver, Some((state.clone(), acp_session_id.clone()))); let mut response = sse.into_response(); response .headers_mut() - .insert(HEADER_SESSION_ID, new_session_id.parse().unwrap()); + .insert(HEADER_SESSION_ID, acp_session_id.parse().unwrap()); response } async fn handle_request( state: Arc, - session_id: String, + acp_session_id: String, json_message: &Value, ) -> Response { - if !state.has_session(&session_id).await { + if !state.has_session(&acp_session_id).await { return (StatusCode::NOT_FOUND, "Session not found").into_response(); } let message_str = serde_json::to_string(json_message).unwrap(); - if let Err(status) = state.send_message(&session_id, message_str).await { + if let Err(status) = state.send_message(&acp_session_id, message_str).await { return status.into_response(); } - let receiver = match state.get_receiver(&session_id).await { + let receiver = match state.get_receiver(&acp_session_id).await { Ok(r) => r, Err(status) => return status.into_response(), }; @@ -171,15 +169,15 @@ async fn handle_request( async fn handle_notification_or_response( state: Arc, - session_id: String, + acp_session_id: String, json_message: &Value, ) -> Response { - if !state.has_session(&session_id).await { + if !state.has_session(&acp_session_id).await { return (StatusCode::NOT_FOUND, "Session not found").into_response(); } let message_str = serde_json::to_string(json_message).unwrap(); - if let Err(status) = state.send_message(&session_id, message_str).await { + if let Err(status) = state.send_message(&acp_session_id, message_str).await { return status.into_response(); } @@ -206,7 +204,7 @@ pub(crate) async fn handle_post( .into_response(); } - let session_id = get_session_id(&request); + let acp_session_id = get_session_id(&request); let body_bytes = match request.into_body().collect().await { Ok(collected) => collected.to_bytes(), @@ -235,7 +233,7 @@ pub(crate) async fn handle_post( if is_initialize_request(&json_message) { handle_initialize(state.clone(), &json_message).await } else if is_jsonrpc_request(&json_message) { - let Some(id) = session_id else { + let Some(id) = acp_session_id else { return ( StatusCode::BAD_REQUEST, "Bad Request: Acp-Session-Id header required", @@ -244,7 +242,7 @@ pub(crate) async fn handle_post( }; handle_request(state.clone(), id, &json_message).await } else if is_jsonrpc_notification(&json_message) || is_jsonrpc_response(&json_message) { - let Some(id) = session_id else { + let Some(id) = acp_session_id else { return ( StatusCode::BAD_REQUEST, "Bad Request: Acp-Session-Id header required", @@ -266,7 +264,7 @@ pub(crate) async fn handle_get(state: Arc, request: Request) -> .into_response(); } - let session_id = match get_session_id(&request) { + let acp_session_id = match get_session_id(&request) { Some(id) => id, None => { return ( @@ -277,11 +275,11 @@ pub(crate) async fn handle_get(state: Arc, request: Request) -> } }; - if !state.has_session(&session_id).await { + if !state.has_session(&acp_session_id).await { return (StatusCode::NOT_FOUND, "Session not found").into_response(); } - let receiver = match state.get_receiver(&session_id).await { + let receiver = match state.get_receiver(&acp_session_id).await { Ok(r) => r, Err(status) => return status.into_response(), }; @@ -306,7 +304,7 @@ pub(crate) async fn handle_delete( State(state): State>, request: Request, ) -> Response { - let session_id = match get_session_id(&request) { + let acp_session_id = match get_session_id(&request) { Some(id) => id, None => { return ( @@ -317,10 +315,10 @@ pub(crate) async fn handle_delete( } }; - if !state.has_session(&session_id).await { + if !state.has_session(&acp_session_id).await { return (StatusCode::NOT_FOUND, "Session not found").into_response(); } - state.remove_session(&session_id).await; + state.remove_session(&acp_session_id).await; StatusCode::ACCEPTED.into_response() } diff --git a/crates/goose-acp/src/transport/websocket.rs b/crates/goose-acp/src/transport/websocket.rs index a9c6a15edbcd..559375507e4a 100644 --- a/crates/goose-acp/src/transport/websocket.rs +++ b/crates/goose-acp/src/transport/websocket.rs @@ -16,6 +16,7 @@ use crate::server_factory::AcpServer; pub(crate) struct WsState { server: Arc, + // Keyed by acp_session_id: a connection-scoped UUID serving many Goose sessions. sessions: RwLock>, } @@ -33,8 +34,7 @@ impl WsState { let agent = self.server.create_agent().await?; - // Create a Goose ACP session (not just the transport connection) - let session_id = agent.create_session().await?; + let acp_session_id = uuid::Uuid::new_v4().to_string(); let handle = tokio::spawn(async move { let read_stream = ReceiverToAsyncRead::new(to_agent_rx); @@ -48,7 +48,7 @@ impl WsState { }); self.sessions.write().await.insert( - session_id.clone(), + acp_session_id.clone(), TransportSession { to_agent_tx, from_agent_rx: Arc::new(Mutex::new(from_agent_rx)), @@ -56,20 +56,20 @@ impl WsState { }, ); - info!(session_id = %session_id, "WebSocket connection created"); - Ok(session_id) + info!(acp_session_id = %acp_session_id, "WebSocket connection created"); + Ok(acp_session_id) } - async fn remove_connection(&self, session_id: &str) { - if let Some(session) = self.sessions.write().await.remove(session_id) { + async fn remove_connection(&self, acp_session_id: &str) { + if let Some(session) = self.sessions.write().await.remove(acp_session_id) { session.handle.abort(); - info!(session_id = %session_id, "WebSocket connection removed"); + info!(acp_session_id = %acp_session_id, "WebSocket connection removed"); } } } pub(crate) async fn handle_get(state: Arc, ws: WebSocketUpgrade) -> Response { - let session_id = match state.create_connection().await { + let acp_session_id = match state.create_connection().await { Ok(id) => id, Err(e) => { error!("Failed to create WebSocket connection: {}", e); @@ -82,30 +82,30 @@ pub(crate) async fn handle_get(state: Arc, ws: WebSocketUpgrade) -> Res }; let mut response = ws.on_upgrade({ - let session_id = session_id.clone(); - move |socket| handle_ws(socket, state, session_id) + let acp_session_id = acp_session_id.clone(); + move |socket| handle_ws(socket, state, acp_session_id) }); response .headers_mut() - .insert(HEADER_SESSION_ID, session_id.parse().unwrap()); + .insert(HEADER_SESSION_ID, acp_session_id.parse().unwrap()); response } -pub(crate) async fn handle_ws(socket: WebSocket, state: Arc, session_id: String) { +pub(crate) async fn handle_ws(socket: WebSocket, state: Arc, acp_session_id: String) { let (mut ws_tx, mut ws_rx) = socket.split(); let (to_agent, from_agent) = { let sessions = state.sessions.read().await; - match sessions.get(&session_id) { + match sessions.get(&acp_session_id) { Some(session) => (session.to_agent_tx.clone(), session.from_agent_rx.clone()), None => { - error!(session_id = %session_id, "Session not found after creation"); + error!(acp_session_id = %acp_session_id, "Session not found after creation"); return; } } }; - debug!(session_id = %session_id, "Starting bidirectional message loop"); + debug!(acp_session_id = %acp_session_id, "Starting bidirectional message loop"); let mut from_agent_rx = from_agent.lock().await; @@ -115,14 +115,14 @@ pub(crate) async fn handle_ws(socket: WebSocket, state: Arc, session_id match msg_result { Ok(Message::Text(text)) => { let text_str = text.to_string(); - debug!(session_id = %session_id, "Client → Agent: {} bytes", text_str.len()); + debug!(acp_session_id = %acp_session_id, "Client → Agent: {} bytes", text_str.len()); if let Err(e) = to_agent.send(text_str).await { - error!(session_id = %session_id, "Failed to send to agent: {}", e); + error!(acp_session_id = %acp_session_id, "Failed to send to agent: {}", e); break; } } Ok(Message::Close(frame)) => { - debug!(session_id = %session_id, "Client closed connection: {:?}", frame); + debug!(acp_session_id = %acp_session_id, "Client closed connection: {:?}", frame); break; } Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => { @@ -130,31 +130,31 @@ pub(crate) async fn handle_ws(socket: WebSocket, state: Arc, session_id continue; } Ok(Message::Binary(_)) => { - warn!(session_id = %session_id, "Ignoring binary message (ACP uses text)"); + warn!(acp_session_id = %acp_session_id, "Ignoring binary message (ACP uses text)"); continue; } Err(e) => { - error!(session_id = %session_id, "WebSocket error: {}", e); + error!(acp_session_id = %acp_session_id, "WebSocket error: {}", e); break; } } } Some(text) = from_agent_rx.recv() => { - debug!(session_id = %session_id, "Agent → Client: {} bytes", text.len()); + debug!(acp_session_id = %acp_session_id, "Agent → Client: {} bytes", text.len()); if let Err(e) = ws_tx.send(Message::Text(text.into())).await { - error!(session_id = %session_id, "Failed to send to client: {}", e); + error!(acp_session_id = %acp_session_id, "Failed to send to client: {}", e); break; } } else => { - debug!(session_id = %session_id, "Both channels closed"); + debug!(acp_session_id = %acp_session_id, "Both channels closed"); break; } } } - debug!(session_id = %session_id, "Cleaning up connection"); - state.remove_connection(&session_id).await; + debug!(acp_session_id = %acp_session_id, "Cleaning up connection"); + state.remove_connection(&acp_session_id).await; } diff --git a/crates/goose-acp/tests/common_tests/mod.rs b/crates/goose-acp/tests/common_tests/mod.rs index e2e11304218d..92ff4de11787 100644 --- a/crates/goose-acp/tests/common_tests/mod.rs +++ b/crates/goose-acp/tests/common_tests/mod.rs @@ -4,16 +4,21 @@ #[path = "../fixtures/mod.rs"] pub mod fixtures; -use fixtures::{OpenAiFixture, PermissionDecision, Session, TestSessionConfig}; +use fixtures::{ + initialize_agent, Connection, OpenAiFixture, PermissionDecision, Session, TestConnectionConfig, +}; use fs_err as fs; use goose::config::base::CONFIG_YAML_NAME; use goose::config::GooseMode; +use goose::providers::provider_registry::ProviderConstructor; +use goose_acp::server::GooseAcpAgent; use goose_test_support::{ExpectedSessionId, McpFixture, FAKE_CODE, TEST_MODEL}; use sacp::schema::{ McpServer, McpServerHttp, ModelId, ModelInfo, SessionModelState, ToolCallStatus, }; +use std::sync::Arc; -pub async fn run_config_mcp() { +pub async fn run_config_mcp() { let temp_dir = tempfile::tempdir().unwrap(); let expected_session_id = ExpectedSessionId::default(); let prompt = "Use the get_code tool and output only its result."; @@ -40,20 +45,195 @@ pub async fn run_config_mcp() { ) .await; - let config = TestSessionConfig { + let config = TestConnectionConfig { data_root: temp_dir.path().to_path_buf(), ..Default::default() }; - let mut session = S::new(config, openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(config, openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); let output = session.prompt(prompt, PermissionDecision::Cancel).await; assert_eq!(output.text, FAKE_CODE); - expected_session_id.assert_matches(&session.id().0); + expected_session_id.assert_matches(&session.session_id().0); +} + +pub async fn run_initialize_without_provider() { + let temp_dir = tempfile::tempdir().unwrap(); + + let provider_factory: ProviderConstructor = + Arc::new(|_| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) })); + + let agent = Arc::new( + GooseAcpAgent::new( + provider_factory, + vec![], + temp_dir.path().to_path_buf(), + temp_dir.path().to_path_buf(), + GooseMode::Auto, + false, + ) + .await + .unwrap(), + ); + + let resp = initialize_agent(agent).await; + assert!(!resp.auth_methods.is_empty()); + assert!(resp + .auth_methods + .iter() + .any(|m| &*m.id.0 == "goose-provider")); +} + +pub async fn run_load_model() { + let expected_session_id = ExpectedSessionId::default(); + let openai = OpenAiFixture::new( + vec![( + r#""model":"o4-mini""#.into(), + include_str!("../test_data/openai_basic.txt"), + )], + expected_session_id.clone(), + ) + .await; + + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); + + session.set_model("o4-mini").await; + + let output = session + .prompt("what is 1+1", PermissionDecision::Cancel) + .await; + assert_eq!(output.text, "2"); + + let session_id = session.session_id().0.to_string(); + let (_, models) = conn.load_session(&session_id).await; + assert_eq!(&*models.unwrap().current_model_id.0, "o4-mini"); +} + +pub async fn run_model_list() { + let expected_session_id = ExpectedSessionId::default(); + let openai = OpenAiFixture::new(vec![], expected_session_id.clone()).await; + + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let (session, models) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); + + let models = models.unwrap(); + let expected = SessionModelState::new( + ModelId::new(TEST_MODEL), + [ + "gpt-5.2", + "gpt-5.2-2025-12-11", + "gpt-5.2-chat-latest", + "gpt-5.2-codex", + "gpt-5.2-pro", + "gpt-5.2-pro-2025-12-11", + "gpt-5.1", + "gpt-5.1-2025-11-13", + "gpt-5.1-chat-latest", + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5-pro", + "gpt-5-pro-2025-10-06", + "gpt-5-codex", + "gpt-5", + "gpt-5-2025-08-07", + "gpt-5-chat-latest", + "gpt-5-mini", + "gpt-5-mini-2025-08-07", + TEST_MODEL, + "gpt-5-nano-2025-08-07", + "codex-mini-latest", + "o3", + "o3-2025-04-16", + "o4-mini", + "o4-mini-2025-04-16", + "gpt-4.1", + "gpt-4.1-2025-04-14", + "gpt-4.1-mini", + "gpt-4.1-mini-2025-04-14", + "gpt-4.1-nano", + "gpt-4.1-nano-2025-04-14", + "o1-pro", + "o1-pro-2025-03-19", + "o3-mini", + "o3-mini-2025-01-31", + "o1", + "o1-2024-12-17", + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-2024-11-20", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "o4-mini-deep-research", + "o4-mini-deep-research-2025-06-26", + "text-embedding-3-large", + "text-embedding-3-small", + "gpt-4", + "gpt-4-0613", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-3.5-turbo", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", + "text-embedding-ada-002", + ] + .iter() + .map(|id| ModelInfo::new(ModelId::new(*id), *id)) + .collect(), + ); + assert_eq!(models, expected); +} + +pub async fn run_model_set() { + let expected_session_id = ExpectedSessionId::default(); + let openai = OpenAiFixture::new( + vec![ + // Session B prompt with switched model + ( + r#""model":"o4-mini""#.into(), + include_str!("../test_data/openai_basic.txt"), + ), + // Session A prompt with default model + ( + format!(r#""model":"{TEST_MODEL}""#), + include_str!("../test_data/openai_basic.txt"), + ), + ], + expected_session_id.clone(), + ) + .await; + + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + + // Session A: default model + let (mut session_a, _) = conn.new_session().await; + + // Session B: switch to o4-mini + let (mut session_b, _) = conn.new_session().await; + session_b.set_model("o4-mini").await; + + // Prompt B — expects o4-mini + expected_session_id.set(session_b.session_id().0.to_string()); + let output = session_b + .prompt("what is 1+1", PermissionDecision::Cancel) + .await; + assert_eq!(output.text, "2"); + + // Prompt A — expects default TEST_MODEL (proves sessions are independent) + expected_session_id.set(session_a.session_id().0.to_string()); + let output = session_a + .prompt("what is 1+1", PermissionDecision::Cancel) + .await; + assert_eq!(output.text, "2"); } -pub async fn run_permission_persistence() { +pub async fn run_permission_persistence() { let cases = vec![ ( PermissionDecision::AllowAlways, @@ -89,39 +269,33 @@ pub async fn run_permission_persistence() { ) .await; - let config = TestSessionConfig { + let config = TestConnectionConfig { mcp_servers: vec![McpServer::Http(McpServerHttp::new("mcp-fixture", &mcp.url))], goose_mode: GooseMode::Approve, data_root: temp_dir.path().to_path_buf(), ..Default::default() }; - let mut session = S::new(config, openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(config, openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); for (decision, expected_status, expected_yaml) in cases { - session.reset_openai(); - session.reset_permissions(); + conn.reset_openai(); + conn.reset_permissions(); let _ = fs::remove_file(temp_dir.path().join("permission.yaml")); let output = session.prompt(prompt, decision).await; - assert_eq!( - output.tool_status.unwrap(), - expected_status, - "permission decision {:?}", - decision - ); + assert_eq!(output.tool_status.unwrap(), expected_status); assert_eq!( fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), expected_yaml, - "permission decision {:?}", - decision ); } - expected_session_id.assert_matches(&session.id().0); + expected_session_id.assert_matches(&session.session_id().0); } -pub async fn run_prompt_basic() { +pub async fn run_prompt_basic() { let expected_session_id = ExpectedSessionId::default(); let openai = OpenAiFixture::new( vec![( @@ -132,17 +306,18 @@ pub async fn run_prompt_basic() { ) .await; - let mut session = S::new(TestSessionConfig::default(), openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); let output = session .prompt("what is 1+1", PermissionDecision::Cancel) .await; assert_eq!(output.text, "2"); - expected_session_id.assert_matches(&session.id().0); + expected_session_id.assert_matches(&session.session_id().0); } -pub async fn run_prompt_codemode() { +pub async fn run_prompt_codemode() { let expected_session_id = ExpectedSessionId::default(); let prompt = "Search for getCode and textEditor tools. Use them to save the code to /tmp/result.txt."; @@ -166,7 +341,7 @@ pub async fn run_prompt_codemode() { ) .await; - let config = TestSessionConfig { + let config = TestConnectionConfig { builtins: vec!["code_execution".to_string(), "developer".to_string()], mcp_servers: vec![McpServer::Http(McpServerHttp::new("mcp-fixture", &mcp.url))], ..Default::default() @@ -174,8 +349,9 @@ pub async fn run_prompt_codemode() { let _ = fs::remove_file("/tmp/result.txt"); - let mut session = S::new(config, openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(config, openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); let output = session.prompt(prompt, PermissionDecision::Cancel).await; if matches!(output.tool_status, Some(ToolCallStatus::Failed)) || output.text.contains("error") { @@ -184,10 +360,10 @@ pub async fn run_prompt_codemode() { let result = fs::read_to_string("/tmp/result.txt").unwrap_or_default(); assert_eq!(result, format!("{FAKE_CODE}\n")); - expected_session_id.assert_matches(&session.id().0); + expected_session_id.assert_matches(&session.session_id().0); } -pub async fn run_prompt_image() { +pub async fn run_prompt_image() { let expected_session_id = ExpectedSessionId::default(); let mcp = McpFixture::new(Some(expected_session_id.clone())).await; let openai = OpenAiFixture::new( @@ -206,12 +382,13 @@ pub async fn run_prompt_image() { ) .await; - let config = TestSessionConfig { + let config = TestConnectionConfig { mcp_servers: vec![McpServer::Http(McpServerHttp::new("mcp-fixture", &mcp.url))], ..Default::default() }; - let mut session = S::new(config, openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(config, openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); let output = session .prompt( @@ -220,10 +397,10 @@ pub async fn run_prompt_image() { ) .await; assert_eq!(output.text, "Hello Goose!\nThis is a test image."); - expected_session_id.assert_matches(&session.id().0); + expected_session_id.assert_matches(&session.session_id().0); } -pub async fn run_prompt_mcp() { +pub async fn run_prompt_mcp() { let expected_session_id = ExpectedSessionId::default(); let mcp = McpFixture::new(Some(expected_session_id.clone())).await; let openai = OpenAiFixture::new( @@ -241,12 +418,13 @@ pub async fn run_prompt_mcp() { ) .await; - let config = TestSessionConfig { + let config = TestConnectionConfig { mcp_servers: vec![McpServer::Http(McpServerHttp::new("mcp-fixture", &mcp.url))], ..Default::default() }; - let mut session = S::new(config, openai).await; - expected_session_id.set(session.id().0.to_string()); + let mut conn = C::new(config, openai).await; + let (mut session, _) = conn.new_session().await; + expected_session_id.set(session.session_id().0.to_string()); let output = session .prompt( @@ -255,103 +433,5 @@ pub async fn run_prompt_mcp() { ) .await; assert_eq!(output.text, FAKE_CODE); - expected_session_id.assert_matches(&session.id().0); -} - -pub async fn run_model_list() { - let expected_session_id = ExpectedSessionId::default(); - let openai = OpenAiFixture::new(vec![], expected_session_id.clone()).await; - - let session = S::new(TestSessionConfig::default(), openai).await; - expected_session_id.set(session.id().0.to_string()); - - let models = session.models().unwrap(); - let expected = SessionModelState::new( - ModelId::new(TEST_MODEL), - [ - "gpt-5.2", - "gpt-5.2-2025-12-11", - "gpt-5.2-chat-latest", - "gpt-5.2-codex", - "gpt-5.2-pro", - "gpt-5.2-pro-2025-12-11", - "gpt-5.1", - "gpt-5.1-2025-11-13", - "gpt-5.1-chat-latest", - "gpt-5.1-codex", - "gpt-5.1-codex-max", - "gpt-5.1-codex-mini", - "gpt-5-pro", - "gpt-5-pro-2025-10-06", - "gpt-5-codex", - "gpt-5", - "gpt-5-2025-08-07", - "gpt-5-chat-latest", - "gpt-5-mini", - "gpt-5-mini-2025-08-07", - TEST_MODEL, - "gpt-5-nano-2025-08-07", - "codex-mini-latest", - "o3", - "o3-2025-04-16", - "o4-mini", - "o4-mini-2025-04-16", - "gpt-4.1", - "gpt-4.1-2025-04-14", - "gpt-4.1-mini", - "gpt-4.1-mini-2025-04-14", - "gpt-4.1-nano", - "gpt-4.1-nano-2025-04-14", - "o1-pro", - "o1-pro-2025-03-19", - "o3-mini", - "o3-mini-2025-01-31", - "o1", - "o1-2024-12-17", - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-2024-08-06", - "gpt-4o-2024-11-20", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "o4-mini-deep-research", - "o4-mini-deep-research-2025-06-26", - "text-embedding-3-large", - "text-embedding-3-small", - "gpt-4", - "gpt-4-0613", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-3.5-turbo", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - "text-embedding-ada-002", - ] - .iter() - .map(|id| ModelInfo::new(ModelId::new(*id), *id)) - .collect(), - ); - assert_eq!(*models, expected); -} - -pub async fn run_set_model() { - let expected_session_id = ExpectedSessionId::default(); - let openai = OpenAiFixture::new( - vec![( - r#""model":"o4-mini""#.into(), - include_str!("../test_data/openai_basic.txt"), - )], - expected_session_id.clone(), - ) - .await; - - let mut session = S::new(TestSessionConfig::default(), openai).await; - expected_session_id.set(session.id().0.to_string()); - - session.set_model("o4-mini").await; - - let output = session - .prompt("what is 1+1", PermissionDecision::Cancel) - .await; - assert_eq!(output.text, "2"); + expected_session_id.assert_matches(&session.session_id().0); } diff --git a/crates/goose-acp/tests/fixtures/mod.rs b/crates/goose-acp/tests/fixtures/mod.rs index c5bcdbee6b3f..2f35c42f050c 100644 --- a/crates/goose-acp/tests/fixtures/mod.rs +++ b/crates/goose-acp/tests/fixtures/mod.rs @@ -194,22 +194,25 @@ pub async fn spawn_acp_server_in_process( builtins: &[String], data_root: &Path, goose_mode: GooseMode, + provider_factory: Option, ) -> (DuplexTransport, JoinHandle<()>, Arc) { fs::create_dir_all(data_root).unwrap(); - // ensure_provider reads the model from config lazily, so tests need a config.yaml. let config_path = data_root.join(goose::config::base::CONFIG_YAML_NAME); if !config_path.exists() { fs::write(&config_path, format!("GOOSE_MODEL: {TEST_MODEL}\n")).unwrap(); } - let base_url = openai_base_url.to_string(); - let provider_factory: ProviderConstructor = Arc::new(move |model_config| { - let base_url = base_url.clone(); - Box::pin(async move { - let api_client = - ApiClient::new(base_url, AuthMethod::BearerToken("test-key".to_string())).unwrap(); - let provider: Arc = - Arc::new(OpenAiProvider::new(api_client, model_config)); - Ok(provider) + let provider_factory = provider_factory.unwrap_or_else(|| { + let base_url = openai_base_url.to_string(); + Arc::new(move |model_config| { + let base_url = base_url.clone(); + Box::pin(async move { + let api_client = + ApiClient::new(base_url, AuthMethod::BearerToken("test-key".to_string())) + .unwrap(); + let provider: Arc = + Arc::new(OpenAiProvider::new(api_client, model_config)); + Ok(provider) + }) }) }); @@ -236,33 +239,43 @@ pub struct TestOutput { pub tool_status: Option, } -pub struct TestSessionConfig { +pub struct TestConnectionConfig { pub mcp_servers: Vec, pub builtins: Vec, pub goose_mode: GooseMode, pub data_root: PathBuf, + pub provider_factory: Option, } -impl Default for TestSessionConfig { +impl Default for TestConnectionConfig { fn default() -> Self { Self { mcp_servers: Vec::new(), builtins: Vec::new(), goose_mode: GooseMode::Auto, data_root: PathBuf::new(), + provider_factory: None, } } } #[async_trait] -pub trait Session { - async fn new(config: TestSessionConfig, openai: OpenAiFixture) -> Self - where - Self: Sized; - fn id(&self) -> &sacp::schema::SessionId; - fn models(&self) -> Option<&SessionModelState>; +pub trait Connection: Sized { + type Session: Session; + + async fn new(config: TestConnectionConfig, openai: OpenAiFixture) -> Self; + async fn new_session(&mut self) -> (Self::Session, Option); + async fn load_session( + &mut self, + session_id: &str, + ) -> (Self::Session, Option); fn reset_openai(&self); fn reset_permissions(&self); +} + +#[async_trait] +pub trait Session { + fn session_id(&self) -> &sacp::schema::SessionId; async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput; async fn set_model(&self, model_id: &str); } diff --git a/crates/goose-acp/tests/fixtures/server.rs b/crates/goose-acp/tests/fixtures/server.rs index 4d6a7c841d65..1432e3431da5 100644 --- a/crates/goose-acp/tests/fixtures/server.rs +++ b/crates/goose-acp/tests/fixtures/server.rs @@ -1,36 +1,44 @@ use super::{ - map_permission_response, spawn_acp_server_in_process, PermissionDecision, PermissionMapping, - Session, TestOutput, TestSessionConfig, + map_permission_response, spawn_acp_server_in_process, Connection, PermissionDecision, + PermissionMapping, Session, TestConnectionConfig, TestOutput, }; use async_trait::async_trait; use goose::config::PermissionManager; use sacp::schema::{ - ContentBlock, InitializeRequest, NewSessionRequest, NewSessionResponse, PromptRequest, - ProtocolVersion, RequestPermissionRequest, SessionModelState, SessionNotification, - SessionUpdate, StopReason, TextContent, ToolCallStatus, + ContentBlock, InitializeRequest, LoadSessionRequest, McpServer, NewSessionRequest, + PromptRequest, ProtocolVersion, RequestPermissionRequest, SessionModelState, + SessionNotification, SessionUpdate, StopReason, TextContent, ToolCallStatus, }; use sacp::{ClientToAgent, JrConnectionCx}; use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::sync::Notify; -pub struct ClientToAgentSession { +pub struct ClientToAgentConnection { cx: JrConnectionCx, - session_id: sacp::schema::SessionId, - new_session_response: NewSessionResponse, + // MCP servers from config, consumed by the first new_session call. + pending_mcp_servers: Vec, updates: Arc>>, permission: Arc>, notify: Arc, permission_manager: Arc, - // Keep the OpenAI mock server alive for the lifetime of the session. _openai: super::OpenAiFixture, - // Keep the temp dir alive so test data/permissions persist during the session. _temp_dir: Option, } +pub struct ClientToAgentSession { + cx: JrConnectionCx, + session_id: sacp::schema::SessionId, + updates: Arc>>, + permission: Arc>, + notify: Arc, +} + #[async_trait] -impl Session for ClientToAgentSession { - async fn new(config: TestSessionConfig, openai: super::OpenAiFixture) -> Self { +impl Connection for ClientToAgentConnection { + type Session = ClientToAgentSession; + + async fn new(config: TestConnectionConfig, openai: super::OpenAiFixture) -> Self { let (data_root, temp_dir) = match config.data_root.as_os_str().is_empty() { true => { let temp_dir = tempfile::tempdir().unwrap(); @@ -44,6 +52,7 @@ impl Session for ClientToAgentSession { &config.builtins, data_root.as_path(), config.goose_mode, + config.provider_factory, ) .await; @@ -51,22 +60,14 @@ impl Session for ClientToAgentSession { let notify = Arc::new(Notify::new()); let permission = Arc::new(Mutex::new(PermissionDecision::Cancel)); - let (cx, session_id, new_session_response) = { + let cx = { let updates_clone = updates.clone(); let notify_clone = notify.clone(); let permission_clone = permission.clone(); - let mcp_servers_clone = config.mcp_servers.clone(); let cx_holder: Arc>>> = Arc::new(Mutex::new(None)); - let session_id_holder: Arc>> = - Arc::new(Mutex::new(None)); - let response_holder: Arc>> = - Arc::new(Mutex::new(None)); - let cx_holder_clone = cx_holder.clone(); - let session_id_holder_clone = session_id_holder.clone(); - let response_holder_clone = response_holder.clone(); let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); @@ -103,28 +104,14 @@ impl Session for ClientToAgentSession { .connect_to(transport) .unwrap() .run_until({ - let mcp_servers = mcp_servers_clone; let cx_holder = cx_holder_clone; - let session_id_holder = session_id_holder_clone; move |cx: JrConnectionCx| async move { cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) .block_task() .await .unwrap(); - let work_dir = tempfile::tempdir().unwrap(); - let response = cx - .send_request( - NewSessionRequest::new(work_dir.path()) - .mcp_servers(mcp_servers), - ) - .block_task() - .await - .unwrap(); - *cx_holder.lock().unwrap() = Some(cx.clone()); - *session_id_holder.lock().unwrap() = Some(response.session_id.clone()); - *response_holder_clone.lock().unwrap() = Some(response); let _ = ready_tx.send(()); std::future::pending::>().await @@ -138,17 +125,13 @@ impl Session for ClientToAgentSession { }); ready_rx.await.unwrap(); - let cx = cx_holder.lock().unwrap().take().unwrap(); - let session_id = session_id_holder.lock().unwrap().take().unwrap(); - let new_session_response = response_holder.lock().unwrap().take().unwrap(); - (cx, session_id, new_session_response) + cx }; Self { cx, - session_id, - new_session_response, + pending_mcp_servers: config.mcp_servers, updates, permission, notify, @@ -158,12 +141,46 @@ impl Session for ClientToAgentSession { } } - fn id(&self) -> &sacp::schema::SessionId { - &self.session_id + async fn new_session(&mut self) -> (ClientToAgentSession, Option) { + let work_dir = tempfile::tempdir().unwrap(); + let mcp_servers = std::mem::take(&mut self.pending_mcp_servers); + let response = self + .cx + .send_request(NewSessionRequest::new(work_dir.path()).mcp_servers(mcp_servers)) + .block_task() + .await + .unwrap(); + let session = ClientToAgentSession { + cx: self.cx.clone(), + session_id: response.session_id.clone(), + updates: self.updates.clone(), + permission: self.permission.clone(), + notify: self.notify.clone(), + }; + (session, response.models) } - fn models(&self) -> Option<&SessionModelState> { - self.new_session_response.models.as_ref() + async fn load_session( + &mut self, + session_id: &str, + ) -> (ClientToAgentSession, Option) { + self.updates.lock().unwrap().clear(); + let work_dir = tempfile::tempdir().unwrap(); + let session_id = sacp::schema::SessionId::new(session_id.to_string()); + let response = self + .cx + .send_request(LoadSessionRequest::new(session_id.clone(), work_dir.path())) + .block_task() + .await + .unwrap(); + let session = ClientToAgentSession { + cx: self.cx.clone(), + session_id, + updates: self.updates.clone(), + permission: self.permission.clone(), + notify: self.notify.clone(), + }; + (session, response.models) } fn reset_openai(&self) { @@ -173,6 +190,13 @@ impl Session for ClientToAgentSession { fn reset_permissions(&self) { self.permission_manager.remove_extension(""); } +} + +#[async_trait] +impl Session for ClientToAgentSession { + fn session_id(&self) -> &sacp::schema::SessionId { + &self.session_id + } async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput { *self.permission.lock().unwrap() = decision; @@ -181,7 +205,7 @@ impl Session for ClientToAgentSession { let response = self .cx .send_request(PromptRequest::new( - self.id().clone(), + self.session_id.clone(), vec![ContentBlock::Text(TextContent::new(text))], )) .block_task() diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs index c20a7074a163..a143d15356e3 100644 --- a/crates/goose-acp/tests/server_test.rs +++ b/crates/goose-acp/tests/server_test.rs @@ -1,83 +1,58 @@ mod common_tests; -use common_tests::fixtures::initialize_agent; use common_tests::fixtures::run_test; -use common_tests::fixtures::server::ClientToAgentSession; +use common_tests::fixtures::server::ClientToAgentConnection; use common_tests::{ - run_config_mcp, run_model_list, run_permission_persistence, run_prompt_basic, - run_prompt_codemode, run_prompt_image, run_prompt_mcp, run_set_model, + run_config_mcp, run_initialize_without_provider, run_load_model, run_model_list, run_model_set, + run_permission_persistence, run_prompt_basic, run_prompt_codemode, run_prompt_image, + run_prompt_mcp, }; -use goose::config::GooseMode; -use goose::providers::provider_registry::ProviderConstructor; -use goose_acp::server::GooseAcpAgent; -use std::sync::Arc; #[test] fn test_config_mcp() { - run_test(async { run_config_mcp::().await }); + run_test(async { run_config_mcp::().await }); +} + +#[test] +fn test_initialize_without_provider() { + run_test(async { run_initialize_without_provider().await }); +} + +#[test] +fn test_load_model() { + run_test(async { run_load_model::().await }); } #[test] fn test_model_list() { - run_test(async { run_model_list::().await }); + run_test(async { run_model_list::().await }); } #[test] -fn test_set_model() { - run_test(async { run_set_model::().await }); +fn test_model_set() { + run_test(async { run_model_set::().await }); } #[test] fn test_permission_persistence() { - run_test(async { run_permission_persistence::().await }); + run_test(async { run_permission_persistence::().await }); } #[test] fn test_prompt_basic() { - run_test(async { run_prompt_basic::().await }); + run_test(async { run_prompt_basic::().await }); } #[test] fn test_prompt_codemode() { - run_test(async { run_prompt_codemode::().await }); + run_test(async { run_prompt_codemode::().await }); } #[test] fn test_prompt_image() { - run_test(async { run_prompt_image::().await }); + run_test(async { run_prompt_image::().await }); } #[test] fn test_prompt_mcp() { - run_test(async { run_prompt_mcp::().await }); -} - -#[test] -fn test_initialize_without_provider() { - run_test(async { - let temp_dir = tempfile::tempdir().unwrap(); - - let provider_factory: ProviderConstructor = - Arc::new(|_| Box::pin(async { Err(anyhow::anyhow!("no provider configured")) })); - - let agent = Arc::new( - GooseAcpAgent::new( - provider_factory, - vec![], - temp_dir.path().to_path_buf(), - temp_dir.path().to_path_buf(), - GooseMode::Auto, - false, - ) - .await - .unwrap(), - ); - - // Initialization shouldn't fail even though we have a crashing provider factory. - let resp = initialize_agent(agent).await; - assert!(!resp.auth_methods.is_empty()); - assert!(resp - .auth_methods - .iter() - .any(|m| &*m.id.0 == "goose-provider")); - }); + run_test(async { run_prompt_mcp::().await }); }