diff --git a/Cargo.lock b/Cargo.lock index d845f889650b..252e22320c11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,16 +40,16 @@ dependencies = [ [[package]] name = "agent-client-protocol-schema" -version = "0.10.8" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bc1fef9c32f03bce2ab44af35b6f483bfd169bf55cc59beeb2e3b1a00ae4d1" +checksum = "96daddd0d00f2eab88f8099d38190881bf8d6c5e46b6fa21751f482775d0dba7" dependencies = [ "anyhow", "derive_more", "schemars 1.2.1", "serde", "serde_json", - "strum", + "strum 0.28.0", ] [[package]] @@ -225,9 +225,9 @@ dependencies = [ [[package]] name = "arc-swap" -version = "1.8.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" dependencies = [ "rustversion", ] @@ -443,9 +443,9 @@ dependencies = [ [[package]] name = "aws-lc-fips-sys" -version = "0.13.12" +version = "0.13.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ed8cd42adddefbdb8507fb7443fa9b666631078616b78f70ed22117b5c27d90" +checksum = "f8bce4948d2520386c6d92a6ea2d472300257702242e5a1d01d6add52bd2e7c1" dependencies = [ "bindgen", "cc", @@ -457,9 +457,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.16.1" +version = "1.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" dependencies = [ "aws-lc-fips-sys", "aws-lc-sys", @@ -469,9 +469,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.38.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a" dependencies = [ "bindgen", "cc", @@ -2983,8 +2983,8 @@ dependencies = [ "proc-macro2", "quote", "stringcase", - "strum", - "strum_macros", + "strum 0.27.2", + "strum_macros 0.27.2", "syn 2.0.117", "syn-match", "thiserror 2.0.18", @@ -4139,15 +4139,6 @@ dependencies = [ "slab", ] -[[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - [[package]] name = "gemm" version = "0.18.2" @@ -4533,6 +4524,7 @@ dependencies = [ "encoding_rs", "env-lock", "etcetera 0.11.0", + "fs-err", "fs2", "futures", "goose-test-support", @@ -4564,7 +4556,7 @@ dependencies = [ "rayon", "regex", "reqwest 0.13.2", - "rmcp 1.2.0", + "rmcp", "rubato", "sacp", "schemars 1.2.1", @@ -4577,7 +4569,7 @@ dependencies = [ "shell-words", "shellexpand", "sqlx", - "strum", + "strum 0.27.2", "symphonia", "sys-info", "tempfile", @@ -4634,12 +4626,13 @@ dependencies = [ "goose-test-support", "http-body-util", "regex", - "rmcp 1.2.0", + "rmcp", "sacp", "schemars 1.2.1", "serde", "serde_json", - "strum", + "sqlx", + "strum 0.27.2", "tempfile", "test-case", "tokio", @@ -4688,7 +4681,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest 0.13.2", - "rmcp 1.2.0", + "rmcp", "rustyline", "serde", "serde_json", @@ -4696,7 +4689,7 @@ dependencies = [ "sha2", "shlex", "sigstore-verification", - "strum", + "strum 0.27.2", "tar", "tempfile", "test-case", @@ -4709,7 +4702,7 @@ dependencies = [ "urlencoding", "webbrowser", "winapi", - "zip 8.2.0", + "zip 8.3.1", ] [[package]] @@ -4727,7 +4720,7 @@ dependencies = [ "lopdf", "once_cell", "reqwest 0.13.2", - "rmcp 1.2.0", + "rmcp", "schemars 1.2.1", "serde", "serde_json", @@ -4763,7 +4756,7 @@ dependencies = [ "rand 0.9.2", "rcgen", "reqwest 0.13.2", - "rmcp 1.2.0", + "rmcp", "rustls 0.23.37", "serde", "serde_json", @@ -4802,7 +4795,7 @@ dependencies = [ "axum 0.7.9", "env-lock", "opentelemetry", - "rmcp 1.2.0", + "rmcp", "serde_json", "tokio", ] @@ -5660,9 +5653,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "ixdtf" @@ -6125,9 +6118,9 @@ dependencies = [ [[package]] name = "linux-keyutils" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "761e49ec5fd8a5a463f9b84e877c373d888935b71c6be78f3767fe2ae6bed18e" +checksum = "83270a18e9f90d0707c41e9f35efada77b64c0e6f3f1810e71c8368a864d5590" dependencies = [ "bitflags 2.11.0", "libc", @@ -6882,8 +6875,8 @@ dependencies = [ "regex", "serde", "serde_json", - "strum", - "strum_macros", + "strum 0.27.2", + "strum_macros 0.27.2", "thiserror 2.0.18", ] @@ -7316,7 +7309,7 @@ dependencies = [ "deno_error", "pctx_config", "pctx_registry", - "rmcp 1.2.0", + "rmcp", "serde", "serde_json", "thiserror 2.0.18", @@ -7382,7 +7375,7 @@ dependencies = [ "opentelemetry-otlp", "opentelemetry_sdk", "reqwest 0.13.2", - "rmcp 1.2.0", + "rmcp", "serde", "serde_json", "shlex", @@ -7435,7 +7428,7 @@ checksum = "9a49c948ffc8c07357e76b2e0008503c9fafaf49d91fd7e00e672bbd7aabd157" dependencies = [ "deno_error", "pctx_config", - "rmcp 1.2.0", + "rmcp", "serde_json", "thiserror 2.0.18", "tokio", @@ -8053,9 +8046,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" +checksum = "14104c5a24d9bcf7eb2c24753e0f49fe14555d8bd565ea3d38e4b4303267259d" dependencies = [ "bitflags 2.11.0", "getopts", @@ -8643,28 +8636,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "rmcp" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528d42f8176e6e5e71ea69182b17d1d0a19a6b3b894b564678b74cd7cab13cfa" -dependencies = [ - "async-trait", - "base64 0.22.1", - "chrono", - "futures", - "pastey", - "pin-project-lite", - "rmcp-macros 0.12.0", - "schemars 1.2.1", - "serde", - "serde_json", - "thiserror 2.0.18", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "rmcp" version = "1.2.0" @@ -8685,7 +8656,7 @@ dependencies = [ "process-wrap", "rand 0.10.0", "reqwest 0.13.2", - "rmcp-macros 1.2.0", + "rmcp-macros", "schemars 1.2.1", "serde", "serde_json", @@ -8700,19 +8671,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "rmcp-macros" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3f81daaa494eb8e985c9462f7d6ce1ab05e5299f48aafd76cdd3d8b060e6f59" -dependencies = [ - "darling 0.23.0", - "proc-macro2", - "quote", - "serde_json", - "syn 2.0.117", -] - [[package]] name = "rmcp-macros" version = "1.2.0" @@ -8887,7 +8845,7 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.9", + "rustls-webpki 0.103.10", "subtle", "zeroize", ] @@ -8928,7 +8886,7 @@ dependencies = [ "rustls 0.23.37", "rustls-native-certs", "rustls-platform-verifier-android", - "rustls-webpki 0.103.9", + "rustls-webpki 0.103.10", "security-framework 3.7.0", "security-framework-sys", "webpki-root-certs", @@ -8953,9 +8911,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "aws-lc-rs", "ring", @@ -9011,18 +8969,17 @@ checksum = "dd29631678d6fb0903b69223673e122c32e9ae559d0960a38d574695ebc0ea15" [[package]] name = "sacp" -version = "10.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704f40d3c269b30229c34093b658ec80c4fac103281654b3965249c592dd6fa6" +version = "11.0.0" +source = "git+https://github.com/agentclientprotocol/symposium-acp?rev=14086c9#14086c93482ea0618ea6ca7a4d1288953569147e" dependencies = [ "agent-client-protocol-schema", "anyhow", "boxfnonce", "futures", "futures-concurrency", - "fxhash", "jsonrpcmsg", - "rmcp 0.12.0", + "rmcp", + "rustc-hash 2.1.1", "sacp-derive", "schemars 1.2.1", "serde", @@ -9036,9 +8993,8 @@ dependencies = [ [[package]] name = "sacp-derive" -version = "10.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92150f9246c01d501855e34469810a82adc27c416c8d8e21665567f8cd966f29" +version = "11.0.0" +source = "git+https://github.com/agentclientprotocol/symposium-acp?rev=14086c9#14086c93482ea0618ea6ca7a4d1288953569147e" dependencies = [ "proc-macro2", "quote", @@ -9598,7 +9554,7 @@ dependencies = [ "ring", "rsa", "rustls-pki-types", - "rustls-webpki 0.103.9", + "rustls-webpki 0.103.10", "scrypt", "serde", "serde_json", @@ -10167,7 +10123,16 @@ version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" dependencies = [ - "strum_macros", + "strum_macros 0.27.2", +] + +[[package]] +name = "strum" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" +dependencies = [ + "strum_macros 0.28.0", ] [[package]] @@ -10182,6 +10147,18 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "strum_macros" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "subtle" version = "2.6.1" @@ -10908,9 +10885,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tar" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" dependencies = [ "filetime", "libc", @@ -11805,9 +11782,9 @@ dependencies = [ [[package]] name = "tree-sitter-rust" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b9b18034c684a2420722be8b2a91c9c44f2546b631c039edf575ccba8c61be1" +checksum = "f715f73a0687261ddb686f0d64a1e5af57bd199c4d12be5fdda6676ce1885bf9" dependencies = [ "cc", "tree-sitter-language", @@ -13355,18 +13332,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.42" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" +checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.42" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" +checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" dependencies = [ "proc-macro2", "quote", @@ -13499,9 +13476,9 @@ dependencies = [ [[package]] name = "zip" -version = "8.2.0" +version = "8.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b680f2a0cd479b4cff6e1233c483fdead418106eae419dc60200ae9850f6d004" +checksum = "5c546feb4481b0fbafb4ef0d79b6204fc41c6f9884b1b73b1d73f82442fc0845" dependencies = [ "crc32fast", "flate2", @@ -13612,9 +13589,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c" +checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index a50cc28528cc..4b98f7776853 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,8 @@ string_slice = "warn" [workspace.dependencies] rmcp = { version = "1.2.0", features = ["schemars", "auth"] } -sacp = "10.1.0" +agent-client-protocol-schema = { version = "0.11", features = ["unstable"] } +sacp = "11.0.0" anyhow = "1.0" async-stream = "0.3" async-trait = "0.1" @@ -90,6 +91,9 @@ tree-sitter-typescript = "0.23" [patch.crates-io] v8 = { path = "vendor/v8" } +# TODO: switch to crates.io release once it includes the unstable feature and updates org from symposium-dev to agentclientprotocol +sacp = { git = "https://github.com/agentclientprotocol/symposium-acp", rev = "14086c9" } +sacp-derive = { git = "https://github.com/agentclientprotocol/symposium-acp", rev = "14086c9" } # TODO: switch to released version in opentelemetry 0.32.0 # https://github.com/open-telemetry/opentelemetry-rust/issues/3408 opentelemetry = { git = "https://github.com/open-telemetry/opentelemetry-rust", rev = "345cd74a" } diff --git a/crates/goose-acp-macros/src/lib.rs b/crates/goose-acp-macros/src/lib.rs index 48fa06f99e51..a15213f6536a 100644 --- a/crates/goose-acp-macros/src/lib.rs +++ b/crates/goose-acp-macros/src/lib.rs @@ -10,7 +10,7 @@ use syn::{ /// Generates two methods on the impl: /// /// 1. `handle_custom_request` — a dispatcher that: -/// - Prefixes each method name with `_goose/` +/// - Uses each annotation string as the method name (include `_goose/` for goose-only methods) /// - Parses JSON params into the handler's typed parameter (if any) /// - Serializes the handler's return value to JSON /// @@ -79,13 +79,13 @@ pub fn custom_methods(_attr: TokenStream, item: TokenStream) -> TokenStream { let arms: Vec<_> = routes .iter() .map(|route| { - let full_method = format!("_goose/{}", route.method_name); + let method = &route.method_name; let fn_ident = &route.fn_ident; match &route.param_type { Some(_) => { quote! { - #full_method => { + #method => { let req = serde_json::from_value(params) .map_err(|e| sacp::Error::invalid_params().data(e.to_string()))?; let result = self.#fn_ident(req).await?; @@ -96,7 +96,7 @@ pub fn custom_methods(_attr: TokenStream, item: TokenStream) -> TokenStream { } None => { quote! { - #full_method => { + #method => { let result = self.#fn_ident().await?; serde_json::to_value(&result) .map_err(|e| sacp::Error::internal_error().data(e.to_string())) @@ -111,7 +111,7 @@ pub fn custom_methods(_attr: TokenStream, item: TokenStream) -> TokenStream { let schema_entries: Vec<_> = routes .iter() .map(|route| { - let full_method = format!("_goose/{}", route.method_name); + let method = &route.method_name; let params_expr = if let Some(pt) = &route.param_type { if is_json_value(pt) { @@ -157,7 +157,7 @@ pub fn custom_methods(_attr: TokenStream, item: TokenStream) -> TokenStream { quote! { crate::custom_requests::CustomMethodSchema { - method: #full_method.to_string(), + method: #method.to_string(), params_schema: #params_expr, params_type_name: #params_name_expr, response_schema: #response_expr, diff --git a/crates/goose-acp/Cargo.toml b/crates/goose-acp/Cargo.toml index 04aac3578a69..bb5e0983befb 100644 --- a/crates/goose-acp/Cargo.toml +++ b/crates/goose-acp/Cargo.toml @@ -26,8 +26,8 @@ workspace = true goose = { path = "../goose", default-features = false } goose-mcp = { path = "../goose-mcp" } rmcp = { workspace = true } -sacp = { workspace = true } -agent-client-protocol-schema = { version = "0.10", features = ["unstable"] } +sacp = { workspace = true, features = ["unstable"] } +agent-client-protocol-schema = { workspace = true } async-trait = { workspace = true } anyhow = { workspace = true } tokio = { workspace = true } @@ -60,6 +60,7 @@ tempfile = { workspace = true } test-case = { workspace = true } axum = { workspace = true } rmcp = { workspace = true, features = ["transport-streamable-http-server"] } +sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio-rustls", "sqlite"] } [package.metadata.cargo-machete] # Used to provide extras imports for sacp diff --git a/crates/goose-acp/acp-meta.json b/crates/goose-acp/acp-meta.json index cbd9dd62f111..5050c86e6225 100644 --- a/crates/goose-acp/acp-meta.json +++ b/crates/goose-acp/acp-meta.json @@ -1,35 +1,30 @@ { "methods": [ { - "method": "extensions/add", + "method": "_goose/extensions/add", "requestType": "AddExtensionRequest", "responseType": "EmptyResponse" }, { - "method": "extensions/remove", + "method": "_goose/extensions/remove", "requestType": "RemoveExtensionRequest", "responseType": "EmptyResponse" }, { - "method": "tools", + "method": "_goose/tools", "requestType": "GetToolsRequest", "responseType": "GetToolsResponse" }, { - "method": "resource/read", + "method": "_goose/resource/read", "requestType": "ReadResourceRequest", "responseType": "ReadResourceResponse" }, { - "method": "working_dir/update", + "method": "_goose/working_dir/update", "requestType": "UpdateWorkingDirRequest", "responseType": "EmptyResponse" }, - { - "method": "session/list", - "requestType": null, - "responseType": "ListSessionsResponse" - }, { "method": "session/get", "requestType": "GetSessionRequest", @@ -41,17 +36,17 @@ "responseType": "EmptyResponse" }, { - "method": "session/export", + "method": "_goose/session/export", "requestType": "ExportSessionRequest", "responseType": "ExportSessionResponse" }, { - "method": "session/import", + "method": "_goose/session/import", "requestType": "ImportSessionRequest", "responseType": "ImportSessionResponse" }, { - "method": "config/extensions", + "method": "_goose/config/extensions", "requestType": null, "responseType": "GetExtensionsResponse" } diff --git a/crates/goose-acp/acp-schema.json b/crates/goose-acp/acp-schema.json index f090ccf919cf..f6db67e3fac3 100644 --- a/crates/goose-acp/acp-schema.json +++ b/crates/goose-acp/acp-schema.json @@ -5,7 +5,7 @@ "AddExtensionRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" }, "config": { @@ -13,12 +13,12 @@ } }, "required": [ - "session_id", + "sessionId", "config" ], - "description": "Add an extension to an active session.\nMethod: `_agent/extensions/add`", + "description": "Add an extension to an active session.", "x-side": "agent", - "x-method": "extensions/add" + "x-method": "_goose/extensions/add" }, "EmptyResponse": { "type": "object", @@ -28,7 +28,7 @@ "RemoveExtensionRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" }, "name": { @@ -36,26 +36,26 @@ } }, "required": [ - "session_id", + "sessionId", "name" ], - "description": "Remove an extension from an active session.\nMethod: `_agent/extensions/remove`", + "description": "Remove an extension from an active session.", "x-side": "agent", - "x-method": "extensions/remove" + "x-method": "_goose/extensions/remove" }, "GetToolsRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" } }, "required": [ - "session_id" + "sessionId" ], - "description": "List all tools available in a session.\nMethod: `_agent/tools`", + "description": "List all tools available in a session.", "x-side": "agent", - "x-method": "tools" + "x-method": "_goose/tools" }, "GetToolsResponse": { "type": "object", @@ -70,29 +70,29 @@ "tools" ], "x-side": "agent", - "x-method": "tools" + "x-method": "_goose/tools" }, "ReadResourceRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" }, "uri": { "type": "string" }, - "extension_name": { + "extensionName": { "type": "string" } }, "required": [ - "session_id", + "sessionId", "uri", - "extension_name" + "extensionName" ], - "description": "Read a resource from an extension.\nMethod: `_agent/resource/read`", + "description": "Read a resource from an extension.", "x-side": "agent", - "x-method": "resource/read" + "x-method": "_goose/resource/read" }, "ReadResourceResponse": { "type": "object", @@ -105,56 +105,41 @@ "result" ], "x-side": "agent", - "x-method": "resource/read" + "x-method": "_goose/resource/read" }, "UpdateWorkingDirRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" }, - "working_dir": { + "workingDir": { "type": "string" } }, "required": [ - "session_id", - "working_dir" + "sessionId", + "workingDir" ], - "description": "Update the working directory for a session.\nMethod: `_agent/working_dir/update`", + "description": "Update the working directory for a session.", "x-side": "agent", - "x-method": "working_dir/update" - }, - "ListSessionsResponse": { - "type": "object", - "properties": { - "sessions": { - "type": "array", - "items": true - } - }, - "required": [ - "sessions" - ], - "description": "List all sessions.\nMethod: `_session/list`", - "x-side": "agent", - "x-method": "session/list" + "x-method": "_goose/working_dir/update" }, "GetSessionRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" }, - "include_messages": { + "includeMessages": { "type": "boolean", "default": false } }, "required": [ - "session_id" + "sessionId" ], - "description": "Get a session by ID.\nMethod: `_session/get`", + "description": "Get a session by ID.", "x-side": "agent", "x-method": "session/get" }, @@ -175,30 +160,30 @@ "DeleteSessionRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" } }, "required": [ - "session_id" + "sessionId" ], - "description": "Delete a session.\nMethod: `_session/delete`", + "description": "Delete a session.", "x-side": "agent", "x-method": "session/delete" }, "ExportSessionRequest": { "type": "object", "properties": { - "session_id": { + "sessionId": { "type": "string" } }, "required": [ - "session_id" + "sessionId" ], - "description": "Export a session as a JSON string.\nMethod: `_session/export`", + "description": "Export a session as a JSON string.", "x-side": "agent", - "x-method": "session/export" + "x-method": "_goose/session/export" }, "ExportSessionResponse": { "type": "object", @@ -211,7 +196,7 @@ "data" ], "x-side": "agent", - "x-method": "session/export" + "x-method": "_goose/session/export" }, "ImportSessionRequest": { "type": "object", @@ -223,9 +208,9 @@ "required": [ "data" ], - "description": "Import a session from a JSON string.\nMethod: `_session/import`", + "description": "Import a session from a JSON string.", "x-side": "agent", - "x-method": "session/import" + "x-method": "_goose/session/import" }, "ImportSessionResponse": { "type": "object", @@ -238,7 +223,7 @@ "session" ], "x-side": "agent", - "x-method": "session/import" + "x-method": "_goose/session/import" }, "GetExtensionsResponse": { "type": "object", @@ -259,9 +244,9 @@ "extensions", "warnings" ], - "description": "List configured extensions and any warnings.\nMethod: `_config/extensions`", + "description": "List configured extensions and any warnings.", "x-side": "agent", - "x-method": "config/extensions" + "x-method": "_goose/config/extensions" }, "ExtRequest": { "properties": { @@ -326,7 +311,7 @@ "$ref": "#/$defs/GetSessionRequest" } ], - "description": "Params for _goose/session/get", + "description": "Params for session/get", "title": "GetSessionRequest" }, { @@ -335,7 +320,7 @@ "$ref": "#/$defs/DeleteSessionRequest" } ], - "description": "Params for _goose/session/delete", + "description": "Params for session/delete", "title": "DeleteSessionRequest" }, { @@ -410,14 +395,6 @@ ], "title": "ReadResourceResponse" }, - { - "allOf": [ - { - "$ref": "#/$defs/ListSessionsResponse" - } - ], - "title": "ListSessionsResponse" - }, { "allOf": [ { diff --git a/crates/goose-acp/src/bin/generate_acp_schema.rs b/crates/goose-acp/src/bin/generate_acp_schema.rs index 688d4f4a7046..f4e8ac9eb480 100644 --- a/crates/goose-acp/src/bin/generate_acp_schema.rs +++ b/crates/goose-acp/src/bin/generate_acp_schema.rs @@ -17,15 +17,10 @@ fn main() { .map(|(k, v)| (k, serde_json::to_value(v).unwrap_or(json!({})))) .collect(); - // Strip the `_goose/` prefix to get the bare method name for x-method. - fn bare_method(full: &str) -> &str { - full.strip_prefix("_goose/").unwrap_or(full) - } - // Track which types map to which methods so we can detect shared types. let mut type_methods: HashMap> = HashMap::new(); for m in &methods { - let method = bare_method(&m.method).to_string(); + let method = m.method.clone(); if let Some(name) = &m.params_type_name { type_methods .entry(name.clone()) @@ -172,7 +167,7 @@ fn main() { .iter() .map(|m| { json!({ - "method": bare_method(&m.method), + "method": &m.method, "requestType": m.params_type_name, "responseType": m.response_type_name, }) diff --git a/crates/goose-acp/src/custom_requests.rs b/crates/goose-acp/src/custom_requests.rs index 816d7bac9e86..2782e25d7802 100644 --- a/crates/goose-acp/src/custom_requests.rs +++ b/crates/goose-acp/src/custom_requests.rs @@ -21,6 +21,7 @@ pub struct CustomMethodSchema { /// Add an extension to an active session. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct AddExtensionRequest { pub session_id: String, /// Extension configuration (see ExtensionConfig variants: Stdio, StreamableHttp, Builtin, Platform). @@ -29,6 +30,7 @@ pub struct AddExtensionRequest { /// Remove an extension from an active session. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct RemoveExtensionRequest { pub session_id: String, pub name: String, @@ -36,6 +38,7 @@ pub struct RemoveExtensionRequest { /// List all tools available in a session. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct GetToolsRequest { pub session_id: String, } @@ -48,6 +51,7 @@ pub struct GetToolsResponse { /// Read a resource from an extension. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct ReadResourceRequest { pub session_id: String, pub uri: String, @@ -62,6 +66,7 @@ pub struct ReadResourceResponse { /// Update the working directory for a session. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct UpdateWorkingDirRequest { pub session_id: String, pub working_dir: String, @@ -69,6 +74,7 @@ pub struct UpdateWorkingDirRequest { /// Get a session by ID. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct GetSessionRequest { pub session_id: String, #[serde(default)] @@ -84,12 +90,14 @@ pub struct GetSessionResponse { /// Delete a session. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct DeleteSessionRequest { pub session_id: String, } /// Export a session as a JSON string. #[derive(Debug, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct ExportSessionRequest { pub session_id: String, } diff --git a/crates/goose-acp/src/fs.rs b/crates/goose-acp/src/fs.rs index a751aa34cacd..7677deae3f4e 100644 --- a/crates/goose-acp/src/fs.rs +++ b/crates/goose-acp/src/fs.rs @@ -10,12 +10,12 @@ use goose::agents::platform_extensions::developer::shell::{ShellParams, OUTPUT_L use goose::agents::platform_extensions::developer::DeveloperClient; use rmcp::model::{CallToolResult, Content as RmcpContent, Tool, ToolAnnotations}; use sacp::schema::{ - CreateTerminalRequest, Diff, KillTerminalCommandRequest, ReadTextFileRequest, - ReleaseTerminalRequest, SessionId, SessionNotification, SessionUpdate, Terminal, - TerminalOutputRequest, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallUpdate, - ToolCallUpdateFields, ToolKind, WaitForTerminalExitRequest, WriteTextFileRequest, + CreateTerminalRequest, Diff, KillTerminalRequest, ReadTextFileRequest, ReleaseTerminalRequest, + SessionId, SessionNotification, SessionUpdate, Terminal, TerminalOutputRequest, + ToolCallContent, ToolCallId, ToolCallLocation, ToolCallUpdate, ToolCallUpdateFields, ToolKind, + WaitForTerminalExitRequest, WriteTextFileRequest, }; -use sacp::{AgentToClient, JrConnectionCx}; +use sacp::{Client, ConnectionTo}; use schemars::schema_for; use std::path::Path; use std::sync::Arc; @@ -24,7 +24,7 @@ use tokio::time::timeout; use tokio_util::sync::CancellationToken; async fn acp_read_text_file( - cx: &JrConnectionCx, + cx: &ConnectionTo, session_id: &SessionId, path: &Path, line: Option, @@ -46,7 +46,7 @@ async fn acp_read_text_file( } async fn acp_write_text_file( - cx: &JrConnectionCx, + cx: &ConnectionTo, session_id: &SessionId, path: &Path, content: &str, @@ -62,7 +62,7 @@ async fn acp_write_text_file( pub(crate) struct AcpTools { pub(crate) inner: Arc, - pub(crate) cx: JrConnectionCx, + pub(crate) cx: ConnectionTo, pub(crate) session_id: SessionId, pub(crate) fs_read: bool, pub(crate) fs_write: bool, @@ -335,7 +335,7 @@ impl AcpTools { Err(_) => { let _ = self .cx - .send_request(KillTerminalCommandRequest::new( + .send_request(KillTerminalRequest::new( self.session_id.clone(), terminal_id.clone(), )) diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 2a93c9122f9b..730c1fc03b60 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -26,30 +26,37 @@ use goose::session::{Session, SessionManager}; use goose_acp_macros::custom_methods; use rmcp::model::{CallToolResult, RawContent, ResourceContents, Role}; use sacp::schema::{ - AgentCapabilities, AuthMethod, AuthenticateRequest, AuthenticateResponse, BlobResourceContents, - CancelNotification, Content, ContentBlock, ContentChunk, CurrentModeUpdate, EmbeddedResource, - EmbeddedResourceResource, FileSystemCapability, ImageContent, InitializeRequest, - InitializeResponse, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, - McpCapabilities, McpServer, ModelId, ModelInfo, NewSessionRequest, NewSessionResponse, - PermissionOption, PermissionOptionKind, PromptCapabilities, PromptRequest, PromptResponse, - RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, SessionCapabilities, - SessionId, SessionInfo, SessionListCapabilities, SessionMode, SessionModeId, SessionModeState, - SessionModelState, SessionNotification, SessionUpdate, SetSessionModeRequest, - SetSessionModeResponse, SetSessionModelRequest, SetSessionModelResponse, StopReason, - TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, - ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind, + AgentCapabilities, AuthMethod, AuthMethodAgent, AuthenticateRequest, AuthenticateResponse, + BlobResourceContents, CancelNotification, CloseSessionRequest, CloseSessionResponse, + ConfigOptionUpdate, Content, ContentBlock, ContentChunk, CurrentModeUpdate, EmbeddedResource, + EmbeddedResourceResource, FileSystemCapabilities, ImageContent, InitializeRequest, + InitializeResponse, ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, + LoadSessionResponse, McpCapabilities, McpServer, ModelId, ModelInfo, NewSessionRequest, + NewSessionResponse, PermissionOption, PermissionOptionKind, PromptCapabilities, PromptRequest, + PromptResponse, RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, + SessionCapabilities, SessionCloseCapabilities, SessionConfigOption, + SessionConfigOptionCategory, SessionConfigSelectOption, SessionId, SessionInfo, + SessionListCapabilities, SessionMode, SessionModeId, SessionModeState, SessionModelState, + SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, + SetSessionConfigOptionResponse, SetSessionModeRequest, SetSessionModeResponse, + SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent, TextResourceContents, + ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate, + ToolCallUpdateFields, ToolKind, +}; +use sacp::util::MatchDispatchFrom; +use sacp::{ + Agent as SacpAgent, ByteStreams, Client, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, + Responder, }; -use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx}; use std::collections::HashMap; use std::sync::Arc; +use strum::{EnumMessage, VariantNames}; use tokio::sync::{Mutex, OnceCell}; use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; use url::Url; -// Agent binds provider, extensions, and permission channels to a single session. -// ACP has no session/close, so sessions accumulate until transport closes. struct GooseAcpSession { agent: Arc, messages: Conversation, @@ -61,7 +68,7 @@ pub struct GooseAcpAgent { sessions: Arc>>, provider_factory: ProviderConstructor, builtins: Vec, - client_fs_capabilities: OnceCell, + client_fs_capabilities: OnceCell, client_terminal: OnceCell, config_dir: std::path::PathBuf, session_manager: Arc, @@ -306,25 +313,22 @@ fn builtin_to_extension_config(name: &str) -> ExtensionConfig { } } -async fn build_model_state(provider: &dyn Provider, current_model: &str) -> SessionModelState { - let models = match provider.fetch_recommended_models().await { - Ok(models) => models, - Err(e) => { - warn!(error = %e, "failed to fetch models, model selection will be unavailable"); - vec![] - } - }; - SessionModelState::new( - ModelId::new(current_model), +async fn build_model_state(provider: &dyn Provider) -> Result { + let models = provider + .fetch_recommended_models() + .await + .map_err(|e| sacp::Error::internal_error().data(e.to_string()))?; + let current_model = &provider.get_model_config().model_name; + Ok(SessionModelState::new( + ModelId::new(current_model.as_str()), models .iter() .map(|name| ModelInfo::new(ModelId::new(&**name), &**name)) .collect(), - ) + )) } fn build_mode_state(current_mode: GooseMode) -> Result { - use strum::{EnumMessage, VariantNames}; let mut available = Vec::with_capacity(GooseMode::VARIANTS.len()); for &name in GooseMode::VARIANTS { let goose_mode: GooseMode = name.parse().map_err(|_| { @@ -341,11 +345,47 @@ fn build_mode_state(current_mode: GooseMode) -> Result Vec { + let mode_options: Vec = mode_state + .available_modes + .iter() + .map(|m| { + SessionConfigSelectOption::new(m.id.0.clone(), m.name.clone()) + .description(m.description.clone()) + }) + .collect(); + let model_options: Vec = model_state + .available_models + .iter() + .map(|m| SessionConfigSelectOption::new(m.model_id.0.clone(), m.name.clone())) + .collect(); + vec![ + SessionConfigOption::select( + "mode", + "Mode", + mode_state.current_mode_id.0.clone(), + mode_options, + ) + .category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", + "Model", + model_state.current_model_id.0.clone(), + model_options, + ) + .category(SessionConfigOptionCategory::Model), + ] +} + impl GooseAcpAgent { pub fn permission_manager(&self) -> Arc { Arc::clone(&self.permission_manager) } + // TODO: goose reads Paths::in_state_dir globally (e.g. RequestLog), ignoring this data_dir. pub async fn new( provider_factory: ProviderConstructor, builtins: Vec, @@ -373,7 +413,7 @@ impl GooseAcpAgent { async fn create_agent_for_session( &self, - cx: Option<&JrConnectionCx>, + cx: Option<&ConnectionTo>, session_id: Option<&SessionId>, goose_mode: Option, ) -> Result> { @@ -506,7 +546,7 @@ impl GooseAcpAgent { content_item: &MessageContent, session_id: &SessionId, session: &mut GooseAcpSession, - cx: &JrConnectionCx, + cx: &ConnectionTo, ) -> Result<(), sacp::Error> { match content_item { MessageContent::Text(text) => { @@ -562,7 +602,7 @@ impl GooseAcpAgent { tool_request: &goose::conversation::message::ToolRequest, session_id: &SessionId, session: &mut GooseAcpSession, - cx: &JrConnectionCx, + cx: &ConnectionTo, ) -> Result<(), sacp::Error> { session .tool_requests @@ -592,7 +632,7 @@ impl GooseAcpAgent { tool_response: &goose::conversation::message::ToolResponse, session_id: &SessionId, session: &mut GooseAcpSession, - cx: &JrConnectionCx, + cx: &ConnectionTo, ) -> Result<(), sacp::Error> { let status = match &tool_response.tool_result { Ok(result) if result.is_error == Some(true) => ToolCallStatus::Failed, @@ -635,7 +675,7 @@ impl GooseAcpAgent { #[allow(clippy::too_many_arguments)] fn handle_tool_permission_request( &self, - cx: &JrConnectionCx, + cx: &ConnectionTo, agent: &Arc, session_id: &SessionId, request_id: String, @@ -776,7 +816,11 @@ impl GooseAcpAgent { let capabilities = AgentCapabilities::new() .load_session(true) - .session_capabilities(SessionCapabilities::new().list(SessionListCapabilities::new())) + .session_capabilities( + SessionCapabilities::new() + .list(SessionListCapabilities::new()) + .close(SessionCloseCapabilities::new()), + ) .prompt_capabilities( PromptCapabilities::new() .image(true) @@ -786,18 +830,15 @@ impl GooseAcpAgent { .mcp_capabilities(McpCapabilities::new().http(true)); Ok(InitializeResponse::new(args.protocol_version) .agent_capabilities(capabilities) - .auth_methods(vec![AuthMethod::new( - "goose-provider", - "Configure Provider", - ) - .description( - "Run `goose configure` to set up your AI provider and API key", + .auth_methods(vec![AuthMethod::Agent( + AuthMethodAgent::new("goose-provider", "Configure Provider") + .description("Run `goose configure` to set up your AI provider and API key"), )])) } async fn on_new_session( &self, - cx: &JrConnectionCx, + cx: &ConnectionTo, args: NewSessionRequest, ) -> Result { debug!(?args, "new session request"); @@ -849,13 +890,13 @@ impl GooseAcpAgent { "Session started" ); - let model_state = - build_model_state(&*provider, &provider.get_model_config().model_name).await; + let model_state = build_model_state(&*provider).await?; let mode_state = build_mode_state(self.goose_mode)?; Ok(NewSessionResponse::new(SessionId::new(goose_session.id)) - .models(model_state) - .modes(mode_state)) + .models(model_state.clone()) + .modes(mode_state.clone()) + .config_options(build_config_options(&mode_state, &model_state))) } async fn init_provider(&self, agent: &Agent, session: &Session) -> Result> { @@ -881,7 +922,8 @@ impl GooseAcpAgent { ) -> Result, sacp::Error> { let mut sessions = self.sessions.lock().await; let session = sessions.get_mut(session_id).ok_or_else(|| { - sacp::Error::invalid_params().data(format!("Session not found: {}", session_id)) + sacp::Error::resource_not_found(Some(session_id.to_string())) + .data(format!("Session not found: {}", session_id)) })?; if let Some(token) = cancel_token { session.cancel_token = Some(token); @@ -912,7 +954,7 @@ impl GooseAcpAgent { async fn on_load_session( &self, - cx: &JrConnectionCx, + cx: &ConnectionTo, args: LoadSessionRequest, ) -> Result { debug!(?args, "load session request"); @@ -923,9 +965,9 @@ impl GooseAcpAgent { .session_manager .get_session(&session_id, true) .await - .map_err(|e| { - sacp::Error::invalid_params() - .data(format!("Failed to load session {}: {}", session_id, e)) + .map_err(|_| { + sacp::Error::resource_not_found(Some(session_id.clone())) + .data(format!("Session not found: {}", session_id)) })?; let loaded_mode = goose_session.goose_mode; @@ -1027,18 +1069,18 @@ impl GooseAcpAgent { "Session loaded" ); - let model_state = - build_model_state(&*provider, &provider.get_model_config().model_name).await; + let model_state = build_model_state(&*provider).await?; let mode_state = build_mode_state(goose_mode)?; Ok(LoadSessionResponse::new() - .models(model_state) - .modes(mode_state)) + .models(model_state.clone()) + .modes(mode_state.clone()) + .config_options(build_config_options(&mode_state, &model_state))) } async fn on_prompt( &self, - cx: &JrConnectionCx, + cx: &ConnectionTo, args: PromptRequest, ) -> Result { let session_id = args.session_id.0.to_string(); @@ -1162,6 +1204,25 @@ impl GooseAcpAgent { Ok(SetSessionModelResponse::new()) } + async fn build_config_update( + &self, + session_id: &SessionId, + ) -> Result<(SessionNotification, Vec), sacp::Error> { + let agent = self.get_session_agent(&session_id.0, None).await?; + let provider = agent.provider().await.map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to get provider: {}", e)) + })?; + let goose_mode = agent.goose_mode().await; + let model_state = build_model_state(&*provider).await?; + let mode_state = build_mode_state(goose_mode)?; + let config_options = build_config_options(&mode_state, &model_state); + let notification = SessionNotification::new( + session_id.clone(), + SessionUpdate::ConfigOptionUpdate(ConfigOptionUpdate::new(config_options.clone())), + ); + Ok((notification, config_options)) + } + async fn on_set_mode( &self, session_id: &str, @@ -1181,11 +1242,44 @@ impl GooseAcpAgent { Ok(SetSessionModeResponse::new()) } + + async fn on_list_sessions(&self) -> Result { + let sessions = self + .session_manager + .list_sessions() + .await + .map_err(|e| sacp::Error::internal_error().data(e.to_string()))?; + let session_infos: Vec = sessions + .into_iter() + .map(|s| { + SessionInfo::new(SessionId::new(s.id), s.working_dir) + .title(s.name) + .updated_at(s.updated_at.to_rfc3339()) + }) + .collect(); + Ok(ListSessionsResponse::new(session_infos)) + } + + async fn on_close_session( + &self, + session_id: &str, + ) -> Result { + let mut sessions = self.sessions.lock().await; + // Cancel before removing so on_prompt sees cancellation before session disappears. + if let Some(session) = sessions.get(session_id) { + if let Some(ref token) = session.cancel_token { + token.cancel(); + } + } + sessions.remove(session_id); + info!(session_id = %session_id, "session closed"); + Ok(CloseSessionResponse::new()) + } } #[custom_methods] impl GooseAcpAgent { - #[custom_method("extensions/add")] + #[custom_method("_goose/extensions/add")] async fn on_add_extension( &self, req: AddExtensionRequest, @@ -1200,7 +1294,7 @@ impl GooseAcpAgent { Ok(EmptyResponse {}) } - #[custom_method("extensions/remove")] + #[custom_method("_goose/extensions/remove")] async fn on_remove_extension( &self, req: RemoveExtensionRequest, @@ -1213,7 +1307,7 @@ impl GooseAcpAgent { Ok(EmptyResponse {}) } - #[custom_method("tools")] + #[custom_method("_goose/tools")] async fn on_get_tools(&self, req: GetToolsRequest) -> Result { let agent = self.get_session_agent(&req.session_id, None).await?; let tools = agent.list_tools(&req.session_id, None).await; @@ -1225,7 +1319,7 @@ impl GooseAcpAgent { Ok(GetToolsResponse { tools: tools_json }) } - #[custom_method("resource/read")] + #[custom_method("_goose/resource/read")] async fn on_read_resource( &self, req: ReadResourceRequest, @@ -1244,7 +1338,7 @@ impl GooseAcpAgent { }) } - #[custom_method("working_dir/update")] + #[custom_method("_goose/working_dir/update")] async fn on_update_working_dir( &self, req: UpdateWorkingDirRequest, @@ -1276,24 +1370,7 @@ impl GooseAcpAgent { Ok(EmptyResponse {}) } - #[custom_method("session/list")] - async fn on_list_sessions(&self) -> Result { - let sessions = self - .session_manager - .list_sessions() - .await - .map_err(|e| sacp::Error::internal_error().data(e.to_string()))?; - let session_infos: Vec = sessions - .into_iter() - .map(|s| { - SessionInfo::new(SessionId::new(s.id), s.working_dir) - .title(s.name) - .updated_at(s.updated_at.to_rfc3339()) - }) - .collect(); - Ok(ListSessionsResponse::new(session_infos)) - } - + // TODO: use typed GetSessionRequest when agent-client-protocol-schema adds it (Discussion #60) #[custom_method("session/get")] async fn on_get_session( &self, @@ -1311,6 +1388,7 @@ impl GooseAcpAgent { }) } + // TODO: use typed DeleteSessionRequest when agent-client-protocol-schema adds it (RFD #395) #[custom_method("session/delete")] async fn on_delete_session( &self, @@ -1320,10 +1398,11 @@ impl GooseAcpAgent { .delete_session(&req.session_id) .await .map_err(|e| sacp::Error::internal_error().data(e.to_string()))?; + self.sessions.lock().await.remove(&req.session_id); Ok(EmptyResponse {}) } - #[custom_method("session/export")] + #[custom_method("_goose/session/export")] async fn on_export_session( &self, req: ExportSessionRequest, @@ -1336,7 +1415,7 @@ impl GooseAcpAgent { Ok(ExportSessionResponse { data }) } - #[custom_method("session/import")] + #[custom_method("_goose/session/import")] async fn on_import_session( &self, req: ImportSessionRequest, @@ -1353,7 +1432,7 @@ impl GooseAcpAgent { }) } - #[custom_method("config/extensions")] + #[custom_method("_goose/config/extensions")] async fn on_get_extensions(&self) -> Result { let extensions = goose::config::extensions::get_all_extensions(); let warnings = goose::config::extensions::get_warnings(); @@ -1373,62 +1452,57 @@ pub struct GooseAcpHandler { pub agent: Arc, } -impl JrMessageHandler for GooseAcpHandler { - type Link = AgentToClient; - +impl HandleDispatchFrom for GooseAcpHandler { fn describe_chain(&self) -> impl std::fmt::Debug { "goose-acp" } - fn handle_message( + fn handle_dispatch_from( &mut self, - message: MessageCx, - cx: JrConnectionCx, - ) -> impl std::future::Future, sacp::Error>> + Send { - use sacp::util::MatchMessageFrom; - use sacp::JrRequestCx; - + message: Dispatch, + cx: ConnectionTo, + ) -> impl std::future::Future, sacp::Error>> + Send { let agent = self.agent.clone(); - // The MatchMessageFrom chain produces an ~85KB async state machine. + // The MatchDispatchFrom chain produces an ~85KB async state machine. // Box::pin moves it to the heap so it doesn't overflow the tokio worker stack. Box::pin(async move { - MatchMessageFrom::new(message, &cx) + MatchDispatchFrom::new(message, &cx) .if_request( - |req: InitializeRequest, req_cx: JrRequestCx| async { - req_cx.respond(agent.on_initialize(req).await?) + |req: InitializeRequest, responder: Responder| async { + responder.respond_with_result(agent.on_initialize(req).await) }, ) .await .if_request( - |_req: AuthenticateRequest, req_cx: JrRequestCx| async { - req_cx.respond(AuthenticateResponse::new()) + |_req: AuthenticateRequest, responder: Responder| async { + responder.respond(AuthenticateResponse::new()) }, ) .await .if_request( - |req: NewSessionRequest, req_cx: JrRequestCx| async { - req_cx.respond(agent.on_new_session(&cx, req).await?) + |req: NewSessionRequest, responder: Responder| async { + responder.respond_with_result(agent.on_new_session(&cx, req).await) }, ) .await .if_request( - |req: LoadSessionRequest, req_cx: JrRequestCx| async { - req_cx.respond(agent.on_load_session(&cx, req).await?) + |req: LoadSessionRequest, responder: Responder| async { + responder.respond_with_result(agent.on_load_session(&cx, req).await) }, ) .await .if_request( - |req: PromptRequest, req_cx: JrRequestCx| async { + |req: PromptRequest, responder: Responder| async { let agent = agent.clone(); let cx_clone = cx.clone(); cx.spawn(async move { match agent.on_prompt(&cx_clone, req).await { Ok(response) => { - req_cx.respond(response)?; + responder.respond(response)?; } Err(e) => { - req_cx.respond_with_error(e)?; + responder.respond_with_error(e)?; } } Ok(()) @@ -1439,76 +1513,118 @@ impl JrMessageHandler for GooseAcpHandler { .await .if_notification(|notif: CancelNotification| async { agent.on_cancel(notif).await }) .await - // Handle methods not yet in the sacp typed API. - // - session/set_model, session/set_mode: typed support pending in sacp - // - _: custom requests that will eventually route to goose-server - .otherwise({ + // set_config_option (SACP 11) and legacy set_mode/set_model; custom _goose/* in otherwise. + .if_request({ let agent = agent.clone(); let cx = cx.clone(); - |message: MessageCx| async move { - match message { - MessageCx::Request(req, request_cx) - if req.method == "session/set_mode" => - { - let params: SetSessionModeRequest = - serde_json::from_value(req.params).map_err(|e| { - sacp::Error::invalid_params().data(e.to_string()) - })?; - let session_id = params.session_id.clone(); - let mode_id = params.mode_id.clone(); - match agent.on_set_mode(&session_id.0, &mode_id.0).await { - Ok(resp) => { - let json = serde_json::to_value(resp).map_err(|e| { - sacp::Error::internal_error().data(e.to_string()) - })?; - // Notify before responding so clients see the mode - // update before block_task unblocks (serial dispatch). - cx.send_notification(SessionNotification::new( - session_id, - SessionUpdate::CurrentModeUpdate( - CurrentModeUpdate::new(mode_id), - ), - ))?; - request_cx.respond(json)?; - } - Err(e) => { - request_cx.respond_with_error(e)?; - } + |req: SetSessionConfigOptionRequest, responder: Responder| async move { + let value_id = req.value.as_value_id() + .ok_or_else(|| sacp::Error::invalid_params().data("Expected a value ID"))? + .clone(); + let session_id = req.session_id.clone(); + match req.config_id.0.as_ref() { + "mode" => { + match agent.on_set_mode(&session_id.0, &value_id.0).await { + Ok(_) => {} + Err(e) => { responder.respond_with_error(e)?; return Ok(()); } } - Ok(()) } - MessageCx::Request(req, request_cx) - if req.method == "session/set_model" => - { - let params: SetSessionModelRequest = - serde_json::from_value(req.params).map_err(|e| { - sacp::Error::invalid_params().data(e.to_string()) - })?; - let resp = agent - .on_set_model(¶ms.session_id.0, ¶ms.model_id.0) - .await?; - let json = serde_json::to_value(resp).map_err(|e| { - sacp::Error::internal_error().data(e.to_string()) - })?; - request_cx.respond(json)?; - Ok(()) + "model" => { + match agent.on_set_model(&session_id.0, &value_id.0).await { + Ok(_) => {} + Err(e) => { responder.respond_with_error(e)?; return Ok(()); } + } } - MessageCx::Request(req, request_cx) if req.method == "session/list" => { - let resp = agent.on_list_sessions().await?; - let json = serde_json::to_value(resp).map_err(|e| { - sacp::Error::internal_error().data(e.to_string()) - })?; - request_cx.respond(json)?; - Ok(()) + other => { + responder.respond_with_error( + sacp::Error::invalid_params().data(format!("Unsupported config option: {}", other)) + )?; + return Ok(()); + } + } + let (notification, config_options) = agent.build_config_update(&session_id).await?; + cx.send_notification(notification)?; + responder.respond(SetSessionConfigOptionResponse::new(config_options))?; + Ok(()) + } + }) + .await + .if_request({ + let agent = agent.clone(); + let cx = cx.clone(); + |req: SetSessionModeRequest, responder: Responder| async move { + let session_id = req.session_id.clone(); + let mode_id = req.mode_id.clone(); + match agent.on_set_mode(&session_id.0, &mode_id.0).await { + Ok(resp) => { + // Notify before responding so clients see the mode update before block_task unblocks. + cx.send_notification(SessionNotification::new( + session_id, + SessionUpdate::CurrentModeUpdate( + CurrentModeUpdate::new(mode_id), + ), + ))?; + responder.respond(resp)?; + } + Err(e) => { + responder.respond_with_error(e)?; + } + } + Ok(()) + } + }) + .await + .if_request({ + let agent = agent.clone(); + let cx = cx.clone(); + |req: SetSessionModelRequest, responder: Responder| async move { + let session_id = req.session_id.clone(); + match agent.on_set_model(&session_id.0, &req.model_id.0).await { + Ok(resp) => { + let (notification, _) = agent.build_config_update(&session_id).await?; + cx.send_notification(notification)?; + responder.respond(resp)?; } - MessageCx::Request(req, request_cx) if req.method.starts_with('_') => { + Err(e) => responder.respond_with_error(e)?, + } + Ok(()) + } + }) + .await + .if_request({ + let agent = agent.clone(); + |_req: ListSessionsRequest, responder: Responder| async move { + responder.respond(agent.on_list_sessions().await?) + } + }) + .await + .if_request({ + let agent = agent.clone(); + |req: CloseSessionRequest, responder: Responder| async move { + responder.respond(agent.on_close_session(&req.session_id.0).await?) + } + }) + .await + .otherwise({ + let agent = agent.clone(); + |message: Dispatch| async move { + match message { + Dispatch::Request(req, responder) => { match agent.handle_custom_request(&req.method, req.params).await { - Ok(json) => request_cx.respond(json)?, - Err(e) => request_cx.respond_with_error(e)?, + Ok(json) => responder.respond(json)?, + Err(e) => responder.respond_with_error(e)?, } Ok(()) } - _ => Err(sacp::Error::method_not_found()), + Dispatch::Response(result, router) => { + debug!(method = %router.method(), id = %router.id(), ok = result.is_ok(), "routing response"); + router.respond_with_result(result)?; + Ok(()) + } + Dispatch::Notification(notif) => { + debug!(method = %notif.method, "unhandled notification"); + Ok(()) + } } } }) @@ -1530,10 +1646,11 @@ where Box::pin(async move { let handler = GooseAcpHandler { agent }; - AgentToClient::builder() + SacpAgent + .builder() .name("goose-acp") .with_handler(handler) - .serve(ByteStreams::new(write, read)) + .connect_to(ByteStreams::new(write, read)) .await?; Ok(()) @@ -1565,8 +1682,8 @@ mod tests { use rmcp::model::{CallToolRequestParams, Content as RmcpContent}; use sacp::schema::{ EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio, - PermissionOptionId, ResourceLink, SelectedPermissionOutcome, SessionMode, SessionModeId, - SessionModeState, + PermissionOptionId, ResourceLink, SelectedPermissionOutcome, SessionConfigSelectOption, + SessionMode, SessionModeId, SessionModeState, }; use std::io::Write; use std::path::PathBuf; @@ -1720,7 +1837,7 @@ print(\"hello, world\") } #[async_trait::async_trait] - impl goose::providers::base::Provider for MockModelProvider { + impl Provider for MockModelProvider { fn get_name(&self) -> &str { "mock" } @@ -1746,40 +1863,30 @@ print(\"hello, world\") } #[test_case( - "model-a", Ok(vec!["model-a".into(), "model-b".into()]) - => SessionModelState::new( - ModelId::new("model-a"), + Ok(vec!["model-a".into(), "model-b".into()]) + => Ok(SessionModelState::new( + ModelId::new("unused"), vec![ModelInfo::new(ModelId::new("model-a"), "model-a"), ModelInfo::new(ModelId::new("model-b"), "model-b")], - ) + )) ; "returns current and available models" )] #[test_case( - "model-a", Ok(vec![]) - => SessionModelState::new(ModelId::new("model-a"), vec![]) + Ok(vec![]) + => Ok(SessionModelState::new(ModelId::new("unused"), vec![])) ; "empty model list" )] #[test_case( - "model-a", Err(ProviderError::ExecutionError("fail".into())) - => SessionModelState::new(ModelId::new("model-a"), vec![]) - ; "fetch error falls back to current model only" - )] - #[test_case( - "switched-model", Ok(vec!["model-a".into(), "switched-model".into()]) - => SessionModelState::new( - ModelId::new("switched-model"), - vec![ModelInfo::new(ModelId::new("model-a"), "model-a"), - ModelInfo::new(ModelId::new("switched-model"), "switched-model")], - ) - ; "current model reflects switched model" + Err(ProviderError::ExecutionError("fail".into())) + => Err(sacp::Error::internal_error().data("Execution error: fail".to_string())) + ; "fetch error propagates" )] #[tokio::test] async fn test_build_model_state( - current_model: &str, models: Result, ProviderError>, - ) -> SessionModelState { + ) -> Result { let provider = MockModelProvider { models }; - build_model_state(&provider, current_model).await + build_model_state(&provider).await } fn json_object(pairs: Vec<(&str, serde_json::Value)>) -> rmcp::model::JsonObject { @@ -1920,24 +2027,94 @@ print(\"hello, world\") .map(|locs| locs.into_iter().map(|loc| (loc.path, loc.line)).collect()) } - #[test] - fn test_build_mode_state() { - let state = build_mode_state(GooseMode::Auto).unwrap(); - assert_eq!( - state, - SessionModeState::new( - SessionModeId::new("auto"), + #[test_case( + GooseMode::Auto + => Ok(SessionModeState::new( + SessionModeId::new("auto"), + vec![ + SessionMode::new(SessionModeId::new("auto"), "auto") + .description("Automatically approve tool calls"), + SessionMode::new(SessionModeId::new("approve"), "approve") + .description("Ask before every tool call"), + SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") + .description("Ask only for sensitive tool calls"), + SessionMode::new(SessionModeId::new("chat"), "chat") + .description("Chat only, no tool calls"), + ], + )) + ; "auto mode" + )] + #[test_case( + GooseMode::Approve + => Ok(SessionModeState::new( + SessionModeId::new("approve"), + vec![ + SessionMode::new(SessionModeId::new("auto"), "auto") + .description("Automatically approve tool calls"), + SessionMode::new(SessionModeId::new("approve"), "approve") + .description("Ask before every tool call"), + SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") + .description("Ask only for sensitive tool calls"), + SessionMode::new(SessionModeId::new("chat"), "chat") + .description("Chat only, no tool calls"), + ], + )) + ; "approve mode" + )] + fn test_build_mode_state(current_mode: GooseMode) -> Result { + build_mode_state(current_mode) + } + + #[test_case( + build_mode_state(GooseMode::Auto).unwrap(), + SessionModelState::new( + ModelId::new("gpt-4"), + vec![ModelInfo::new(ModelId::new("gpt-4"), "gpt-4"), ModelInfo::new(ModelId::new("gpt-3.5"), "gpt-3.5")], + ) + => vec![ + SessionConfigOption::select( + "mode", "Mode", "auto", vec![ - SessionMode::new(SessionModeId::new("auto"), "auto") - .description("Automatically approve tool calls"), - SessionMode::new(SessionModeId::new("approve"), "approve") - .description("Ask before every tool call"), - SessionMode::new(SessionModeId::new("smart_approve"), "smart_approve") - .description("Ask only for sensitive tool calls"), - SessionMode::new(SessionModeId::new("chat"), "chat") - .description("Chat only, no tool calls"), + SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), + SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), + SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), + SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), ], - ) - ); + ).category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", "Model", "gpt-4", + vec![ + SessionConfigSelectOption::new("gpt-4", "gpt-4"), + SessionConfigSelectOption::new("gpt-3.5", "gpt-3.5"), + ], + ).category(SessionConfigOptionCategory::Model), + ] + ; "auto mode with multiple models" + )] + #[test_case( + build_mode_state(GooseMode::Approve).unwrap(), + SessionModelState::new(ModelId::new("only-model"), vec![ModelInfo::new(ModelId::new("only-model"), "only-model")]) + => vec![ + SessionConfigOption::select( + "mode", "Mode", "approve", + vec![ + SessionConfigSelectOption::new("auto", "auto").description("Automatically approve tool calls"), + SessionConfigSelectOption::new("approve", "approve").description("Ask before every tool call"), + SessionConfigSelectOption::new("smart_approve", "smart_approve").description("Ask only for sensitive tool calls"), + SessionConfigSelectOption::new("chat", "chat").description("Chat only, no tool calls"), + ], + ).category(SessionConfigOptionCategory::Mode), + SessionConfigOption::select( + "model", "Model", "only-model", + vec![SessionConfigSelectOption::new("only-model", "only-model")], + ).category(SessionConfigOptionCategory::Model), + ] + ; "approve mode with single model" + )] + fn test_build_config_options( + mode_state: SessionModeState, + model_state: SessionModelState, + ) -> Vec { + build_config_options(&mode_state, &model_state) } } diff --git a/crates/goose-acp/tests/common_tests/mod.rs b/crates/goose-acp/tests/common_tests/mod.rs index 58dc961ce21c..a3076b002283 100644 --- a/crates/goose-acp/tests/common_tests/mod.rs +++ b/crates/goose-acp/tests/common_tests/mod.rs @@ -6,18 +6,125 @@ pub mod fixtures; use fixtures::{ assert_notifications, Connection, FsFixture, Notification, OpenAiFixture, PermissionDecision, - Session, SessionResult, TerminalCall, TerminalFixture, TestConnectionConfig, + Session, SessionData, TerminalCall, TerminalFixture, TestConnectionConfig, }; use fs_err as fs; use goose::config::base::CONFIG_YAML_NAME; use goose::config::GooseMode; use goose::providers::provider_registry::ProviderConstructor; use goose_test_support::{McpFixture, FAKE_CODE, TEST_IMAGE_B64, TEST_MODEL}; -use sacp::schema::{McpServer, McpServerHttp, ModelId, SessionModeId, ToolCallStatus, ToolKind}; +use sacp::schema::{ + ListSessionsResponse, McpServer, McpServerHttp, ModelId, SessionInfo, SessionModeId, + ToolCallStatus, ToolKind, +}; +use sqlx::sqlite::SqlitePoolOptions; use std::sync::Arc; const SHELL_TEST_CONTENT: &str = "test-shell-content-98765"; +struct BasicSession { + conn: C, + session: C::Session, +} + +async fn new_basic_session(config: TestConnectionConfig) -> BasicSession { + let expected_session_id = C::expected_session_id(); + let openai = OpenAiFixture::new( + vec![( + r#"\nwhat is 1+1""#.into(), + include_str!("../test_data/openai_basic.txt"), + )], + expected_session_id.clone(), + ) + .await; + + let mut conn = C::new(config, openai).await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); + expected_session_id.set(&session.session_id().0); + + let output = session + .prompt("what is 1+1", PermissionDecision::Cancel) + .await + .unwrap(); + assert_eq!(output.text, "2"); + + BasicSession { conn, session } +} + +pub async fn run_list_sessions() { + let BasicSession { conn, session } = + new_basic_session::(TestConnectionConfig::default()).await; + let mut response = conn.list_sessions().await.unwrap(); + for s in &mut response.sessions { + s.updated_at = None; + } + assert_eq!( + response, + ListSessionsResponse::new(vec![SessionInfo::new( + session.session_id().clone(), + session.work_dir() + ) + .title("ACP Session".to_string())]) + ); +} + +pub async fn run_close_session() { + let BasicSession { conn, session } = + new_basic_session::(TestConnectionConfig::default()).await; + let sid = &session.session_id().0; + let data_root = conn.data_root(); + + conn.close_session(sid).await.unwrap(); + + // Provider close drops the connection, so verify via DB not list_sessions. + let db_path = data_root.join("sessions").join("sessions.db"); + let pool = SqlitePoolOptions::new() + .connect(&format!("sqlite:{}?mode=ro", db_path.display())) + .await + .unwrap(); + let db_ids: Vec = sqlx::query_scalar("SELECT id FROM sessions") + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(db_ids.len(), 1); + + let expected_session_id = C::expected_session_id(); + expected_session_id.set(sid); + expected_session_id.assert_matches(&db_ids[0]); +} + +pub async fn run_delete_session() { + let BasicSession { mut conn, session } = + new_basic_session::(TestConnectionConfig::default()).await; + let sid = session.session_id().0.to_string(); + + let before: Vec<_> = conn + .list_sessions() + .await + .unwrap() + .sessions + .iter() + .map(|s| s.session_id.clone()) + .collect(); + assert!(before.contains(session.session_id())); + + conn.delete_session(&sid).await.unwrap(); + + let after: Vec<_> = conn + .list_sessions() + .await + .unwrap() + .sessions + .iter() + .map(|s| s.session_id.clone()) + .collect(); + assert!(!after.contains(session.session_id())); + + let err = conn.load_session(&sid, vec![]).await.unwrap_err(); + let sacp_err = err.downcast::().unwrap(); + assert_eq!(sacp_err.code, sacp::ErrorCode::ResourceNotFound); +} + pub async fn run_config_mcp() { let temp_dir = tempfile::tempdir().unwrap(); let expected_session_id = C::expected_session_id(); @@ -51,10 +158,13 @@ pub async fn run_config_mcp() { }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(prompt, PermissionDecision::Cancel).await; + let output = session + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.text, FAKE_CODE); assert_notifications( &session.notifications(), @@ -100,10 +210,13 @@ pub async fn run_fs_read_text_file_true() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(prompt, PermissionDecision::Cancel).await; + let output = session + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.text, "test-read-content-12345"); assert_notifications( &session.notifications(), @@ -144,10 +257,13 @@ pub async fn run_fs_write_text_file_false() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(prompt, PermissionDecision::AllowOnce).await; + let output = session + .prompt(prompt, PermissionDecision::AllowOnce) + .await + .unwrap(); assert!(!output.text.is_empty()); assert_eq!( fs::read_to_string("/tmp/test_acp_write.txt").unwrap(), @@ -193,10 +309,13 @@ pub async fn run_fs_write_text_file_true() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(prompt, PermissionDecision::AllowOnce).await; + let output = session + .prompt(prompt, PermissionDecision::AllowOnce) + .await + .unwrap(); assert!(!output.text.is_empty()); assert_notifications( &session.notifications(), @@ -227,7 +346,7 @@ pub async fn run_initialize_doesnt_hit_provider() { assert!(conn .auth_methods() .iter() - .any(|m| &*m.id.0 == "goose-provider")); + .any(|m| m.id().0.as_ref() == "goose-provider")); } pub async fn run_load_mode() { @@ -263,7 +382,7 @@ pub async fn run_load_mode() { }; let mut conn = C::new(config, openai).await; - let SessionResult { session, modes, .. } = conn.new_session().await; + let SessionData { session, modes, .. } = conn.new_session().await.unwrap(); assert_eq!( modes.unwrap().current_mode_id, SessionModeId::new(<&str>::from(GooseMode::default())) @@ -273,11 +392,11 @@ pub async fn run_load_mode() { .await .unwrap(); - let SessionResult { + let SessionData { session: mut loaded, modes, .. - } = conn.load_session(&session_id, vec![]).await; + } = conn.load_session(&session_id, vec![]).await.unwrap(); assert_eq!( modes.unwrap().current_mode_id, SessionModeId::new(<&str>::from(GooseMode::Approve)) @@ -285,7 +404,10 @@ pub async fn run_load_mode() { // Approve mode + Cancel = permission denied → tool fails expected_session_id.set(&loaded.session_id().0); - let output = loaded.prompt(prompt, PermissionDecision::Cancel).await; + let output = loaded + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.tool_status.unwrap(), ToolCallStatus::Failed); assert_notifications( &loaded.notifications(), @@ -310,7 +432,7 @@ pub async fn run_load_model() { .await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let session_id = session.session_id().0.to_string(); @@ -318,10 +440,11 @@ pub async fn run_load_model() { let output = session .prompt("what is 1+1", PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, "2"); - let SessionResult { models, .. } = conn.load_session(&session_id, vec![]).await; + let SessionData { models, .. } = conn.load_session(&session_id, vec![]).await.unwrap(); assert_eq!(&*models.unwrap().current_model_id.0, "o4-mini"); } @@ -362,28 +485,107 @@ pub async fn run_load_session_mcp() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); // First prompt: tool should work in the new session. - let output = session.prompt(prompt, PermissionDecision::Cancel).await; + let output = session + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.text, FAKE_CODE, "tool call failed in new session"); // Load the same session with MCP servers re-specified. let session_id = session.session_id().0.to_string(); - let SessionResult { + let SessionData { session: mut loaded_session, .. - } = conn.load_session(&session_id, mcp_servers).await; + } = conn.load_session(&session_id, mcp_servers).await.unwrap(); // Second prompt: tool should work in the loaded session. let output = loaded_session .prompt(prompt, PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, FAKE_CODE, "tool call failed in loaded session"); } +pub async fn run_load_session_error() { + let openai = OpenAiFixture::new(vec![], C::expected_session_id()).await; + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + + let err = conn + .load_session("nonexistent-session-id", vec![]) + .await + .unwrap_err(); + + let sacp_err = err.downcast::().unwrap(); + assert_eq!( + sacp_err, + sacp::Error::resource_not_found(Some("nonexistent-session-id".to_string())) + .data("Session not found: nonexistent-session-id") + ); +} + +pub async fn run_config_option_mode_set() { + run_mode_set_impl::(SetModeVia::ConfigOption).await; +} + +pub async fn run_config_option_set_error( + config_id: &str, + value: &str, + session_id_override: Option<&str>, + expected: sacp::Error, +) { + let openai = OpenAiFixture::new(vec![], C::expected_session_id()).await; + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let SessionData { session, .. } = conn.new_session().await.unwrap(); + + let target_session_id = session_id_override + .map(str::to_string) + .unwrap_or_else(|| session.session_id().0.to_string()); + + let err = conn + .set_config_option(&target_session_id, config_id, value) + .await + .unwrap_err(); + + let sacp_err = err.downcast::().unwrap(); + assert_eq!(sacp_err, expected); +} + +#[macro_export] +macro_rules! tests_config_option_set_error { + ($conn:ty) => { + #[test_case::test_case("mode", "not_a_mode", None, sacp::Error::invalid_params().data("Invalid mode: not_a_mode") ; "invalid mode via config option")] + #[test_case::test_case("mode", "auto", Some("nonexistent-session-id"), sacp::Error::resource_not_found(Some("nonexistent-session-id".to_string())).data("Session not found: nonexistent-session-id") ; "session not found via config option")] + #[test_case::test_case("thought_level", "high", None, sacp::Error::invalid_params().data("Unsupported config option: thought_level") ; "unsupported config option")] + fn test_config_option_set_error( + config_id: &'static str, + value: &'static str, + session_id: Option<&'static str>, + expected: sacp::Error, + ) { + common_tests::fixtures::run_test(async move { + common_tests::run_config_option_set_error::<$conn>( + config_id, value, session_id, expected, + ) + .await + }); + } + }; +} + pub async fn run_mode_set() { + run_mode_set_impl::(SetModeVia::Dedicated).await; +} + +enum SetModeVia { + Dedicated, + ConfigOption, +} + +async fn run_mode_set_impl(via: SetModeVia) { let temp_dir = tempfile::tempdir().unwrap(); let expected_session_id = C::expected_session_id(); let prompt = "Use the get_code tool and output only its result."; @@ -412,26 +614,45 @@ pub async fn run_mode_set() { let config = TestConnectionConfig { data_root: temp_dir.path().to_path_buf(), + strip_config_options: matches!(via, SetModeVia::Dedicated), ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { + let SessionData { session: mut session_a, .. - } = conn.new_session().await; + } = conn.new_session().await.unwrap(); - let SessionResult { + let SessionData { session: mut session_b, .. - } = conn.new_session().await; - conn.set_mode(&session_b.session_id().0, <&str>::from(GooseMode::Approve)) - .await - .unwrap(); + } = conn.new_session().await.unwrap(); + let session_id = &session_b.session_id().0; + let approve = <&str>::from(GooseMode::Approve); + match via { + SetModeVia::Dedicated => conn.set_mode(session_id, approve).await.unwrap(), + SetModeVia::ConfigOption => conn + .set_config_option(session_id, "mode", approve) + .await + .unwrap(), + } - // Approve mode + Cancel = permission denied → tool fails + match via { + SetModeVia::Dedicated => { + assert_notifications(&session_b.notifications(), &[Notification::CurrentMode]) + } + SetModeVia::ConfigOption => { + assert_notifications(&session_b.notifications(), &[Notification::ConfigOption]) + } + } + + // Approve mode + Cancel = permission denied -> tool fails expected_session_id.set(&session_b.session_id().0); - let output = session_b.prompt(prompt, PermissionDecision::Cancel).await; + let output = session_b + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.tool_status.unwrap(), ToolCallStatus::Failed); assert_notifications( &session_b.notifications(), @@ -443,10 +664,13 @@ pub async fn run_mode_set() { ], ); - // Auto mode ignores Cancel — tool succeeds without permission prompt + // Auto mode ignores Cancel -- tool succeeds without permission prompt conn.reset_openai(); expected_session_id.set(&session_a.session_id().0); - let output = session_a.prompt(prompt, PermissionDecision::Cancel).await; + let output = session_a + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); assert_eq!(output.text, FAKE_CODE); assert_notifications( &session_a.notifications(), @@ -466,7 +690,7 @@ pub async fn run_mode_set_error( ) { let openai = OpenAiFixture::new(vec![], C::expected_session_id()).await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; - let SessionResult { session, .. } = conn.new_session().await; + let SessionData { session, .. } = conn.new_session().await.unwrap(); let target_session_id = session_id_override .map(str::to_string) @@ -485,7 +709,7 @@ pub async fn run_mode_set_error( macro_rules! tests_mode_set_error { ($conn:ty) => { #[test_case::test_case("not_a_mode", None, sacp::Error::invalid_params().data("Invalid mode: not_a_mode") ; "invalid mode")] - #[test_case::test_case("auto", Some("nonexistent-session-id"), sacp::Error::invalid_params().data("Session not found: nonexistent-session-id") ; "session not found")] + #[test_case::test_case("auto", Some("nonexistent-session-id"), sacp::Error::resource_not_found(Some("nonexistent-session-id".to_string())).data("Session not found: nonexistent-session-id") ; "session not found")] fn test_mode_set_error( mode_id: &'static str, session_id: Option<&'static str>, @@ -506,9 +730,9 @@ pub async fn run_model_list() { let openai = OpenAiFixture::new(vec![], expected_session_id.clone()).await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; - let SessionResult { + let SessionData { session, models, .. - } = conn.new_session().await; + } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let models = models.unwrap(); @@ -516,7 +740,20 @@ pub async fn run_model_list() { assert_eq!(models.current_model_id, ModelId::new(TEST_MODEL)); } +pub async fn run_config_option_model_set() { + run_model_set_impl::(SetModelVia::ConfigOption).await; +} + pub async fn run_model_set() { + run_model_set_impl::(SetModelVia::Dedicated).await; +} + +enum SetModelVia { + Dedicated, + ConfigOption, +} + +async fn run_model_set_impl(via: SetModelVia) { let expected_session_id = C::expected_session_id(); let openai = OpenAiFixture::new( vec![ @@ -535,36 +772,104 @@ pub async fn run_model_set() { ) .await; - let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let config = TestConnectionConfig { + strip_config_options: matches!(via, SetModelVia::Dedicated), + ..Default::default() + }; + let mut conn = C::new(config, openai).await; // Session A: default model - let SessionResult { + let SessionData { session: mut session_a, .. - } = conn.new_session().await; + } = conn.new_session().await.unwrap(); // Session B: switch to o4-mini - let SessionResult { + let SessionData { session: mut session_b, .. - } = conn.new_session().await; - conn.set_model(&session_b.session_id().0, "o4-mini") - .await - .unwrap(); + } = conn.new_session().await.unwrap(); + let session_id = &session_b.session_id().0; + match via { + SetModelVia::Dedicated => conn.set_model(session_id, "o4-mini").await.unwrap(), + SetModelVia::ConfigOption => conn + .set_config_option(session_id, "model", "o4-mini") + .await + .unwrap(), + } + + let set_model_notifs = session_b.notifications(); // Prompt B — expects o4-mini expected_session_id.set(&session_b.session_id().0); let output = session_b .prompt("what is 1+1", PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, "2"); - // Prompt A — expects default TEST_MODEL (proves sessions are independent) + // Both model paths emit ConfigOption (no CurrentModelUpdate in the schema). + let prompt_notifs = session_b.notifications(); + let mut all = set_model_notifs; + all.extend(prompt_notifs); + assert_eq!( + all, + vec![Notification::ConfigOption, Notification::AgentMessage], + ); + + // Prompt A: expects default TEST_MODEL (proves sessions are independent) expected_session_id.set(&session_a.session_id().0); let output = session_a .prompt("what is 1+1", PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, "2"); + assert_notifications(&session_a.notifications(), &[Notification::AgentMessage]); +} + +pub async fn run_model_set_error_session_not_found() { + let openai = OpenAiFixture::new(vec![], C::expected_session_id()).await; + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let SessionData { .. } = conn.new_session().await.unwrap(); + + let err = conn + .set_model("nonexistent-session-id", "o4-mini") + .await + .unwrap_err(); + + let sacp_err = err.downcast::().unwrap(); + assert_eq!( + sacp_err, + sacp::Error::resource_not_found(Some("nonexistent-session-id".to_string())) + .data("Session not found: nonexistent-session-id") + ); +} + +#[allow(dead_code)] +pub async fn run_new_session_error( + cx: &sacp::ConnectionTo, + params: serde_json::Value, + expected: sacp::Error, +) { + let err = fixtures::send_custom(cx, "session/new", params) + .await + .unwrap_err(); + assert_eq!(err, expected); +} + +pub async fn run_prompt_error() { + let BasicSession { conn, mut session } = + new_basic_session::(TestConnectionConfig::default()).await; + let sid = session.session_id().0.to_string(); + + conn.delete_session(&sid).await.unwrap(); + + let err = session + .prompt("test", PermissionDecision::Cancel) + .await + .unwrap_err(); + let sacp_err = err.downcast::().unwrap(); + assert_eq!(sacp_err.code, sacp::ErrorCode::ResourceNotFound); } pub async fn run_permission_persistence() { @@ -611,14 +916,14 @@ pub async fn run_permission_persistence() { }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); for (decision, expected_status, expected_yaml) in cases { conn.reset_openai(); conn.reset_permissions(); let _ = fs::remove_file(temp_dir.path().join("permission.yaml")); - let output = session.prompt(prompt, decision).await; + let output = session.prompt(prompt, decision).await.unwrap(); assert_eq!(output.tool_status.unwrap(), expected_status); assert_eq!( @@ -641,12 +946,13 @@ pub async fn run_prompt_basic() { .await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let output = session .prompt("what is 1+1", PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, "2"); assert_notifications(&session.notifications(), &[Notification::AgentMessage]); expected_session_id.assert_matches(&session.session_id().0); @@ -685,10 +991,13 @@ pub async fn run_prompt_codemode() { let _ = fs::remove_file("/tmp/result.txt"); let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(prompt, PermissionDecision::Cancel).await; + let output = session + .prompt(prompt, PermissionDecision::Cancel) + .await + .unwrap(); if matches!(output.tool_status, Some(ToolCallStatus::Failed)) || output.text.contains("error") { panic!("{}", output.text); } @@ -722,7 +1031,7 @@ pub async fn run_prompt_image() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let output = session @@ -730,7 +1039,8 @@ pub async fn run_prompt_image() { "Use the get_image tool and describe what you see in its result.", PermissionDecision::Cancel, ) - .await; + .await + .unwrap(); assert_eq!(output.text, "Hello Goose!\nThis is a test image."); assert_notifications( &session.notifications(), @@ -756,7 +1066,7 @@ pub async fn run_prompt_image_attachment() { .await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let output = session @@ -766,7 +1076,8 @@ pub async fn run_prompt_image_attachment() { "image/png", PermissionDecision::Cancel, ) - .await; + .await + .unwrap(); assert!(output.text.contains("Hello Goose!")); assert_notifications(&session.notifications(), &[Notification::AgentMessage]); expected_session_id.assert_matches(&session.session_id().0); @@ -795,7 +1106,7 @@ pub async fn run_prompt_mcp() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let output = session @@ -803,7 +1114,8 @@ pub async fn run_prompt_mcp() { "Use the get_code tool and output only its result.", PermissionDecision::Cancel, ) - .await; + .await + .unwrap(); assert_eq!(output.text, FAKE_CODE); assert_notifications( &session.notifications(), @@ -817,6 +1129,19 @@ pub async fn run_prompt_mcp() { expected_session_id.assert_matches(&session.session_id().0); } +pub async fn run_prompt_model_mismatch() { + // Start the connection where the current model is not desired. + let config = TestConnectionConfig { + current_model: "o4-mini".to_string(), + ..Default::default() + }; + + // Server starts on o4-mini; client is configured with TEST_MODEL. + // If session_model is seeded from the response, stream() detects the + // mismatch and sends set_model(TEST_MODEL) before prompting. + let BasicSession { conn: _, .. } = new_basic_session::(config).await; +} + pub async fn run_prompt_skill() { let cwd = tempfile::tempdir().unwrap(); let skill_dir = cwd.path().join(".agents/skills/test-skill"); @@ -844,12 +1169,13 @@ pub async fn run_prompt_skill() { }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); let output = session .prompt("what is 1+1", PermissionDecision::Cancel) - .await; + .await + .unwrap(); assert_eq!(output.text, "2"); expected_session_id.assert_matches(&session.session_id().0); } @@ -877,10 +1203,13 @@ pub async fn run_shell_terminal_false() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(&prompt, PermissionDecision::AllowOnce).await; + let output = session + .prompt(&prompt, PermissionDecision::AllowOnce) + .await + .unwrap(); assert!(!output.text.is_empty()); assert_notifications( &session.notifications(), @@ -927,10 +1256,13 @@ pub async fn run_shell_terminal_true() { ..Default::default() }; let mut conn = C::new(config, openai).await; - let SessionResult { mut session, .. } = conn.new_session().await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); - let output = session.prompt(&prompt, PermissionDecision::AllowOnce).await; + let output = session + .prompt(&prompt, PermissionDecision::AllowOnce) + .await + .unwrap(); assert_eq!(output.tool_status, Some(ToolCallStatus::Completed)); assert_notifications( &session.notifications(), diff --git a/crates/goose-acp/tests/custom_requests_test.rs b/crates/goose-acp/tests/custom_requests_test.rs index 01606317e92c..c54b96cc1d85 100644 --- a/crates/goose-acp/tests/custom_requests_test.rs +++ b/crates/goose-acp/tests/custom_requests_test.rs @@ -1,78 +1,29 @@ #[allow(dead_code)] mod common_tests; -use common_tests::fixtures::server::ClientToAgentConnection; -use common_tests::fixtures::{run_test, Connection, Session, SessionResult, TestConnectionConfig}; +use common_tests::fixtures::server::AcpServerConnection; +use common_tests::fixtures::{ + run_test, send_custom, Connection, Session, SessionData, TestConnectionConfig, +}; use goose_test_support::EnforceSessionId; use std::sync::Arc; use common_tests::fixtures::OpenAiFixture; -/// Send an untyped custom request and return the result or error. -async fn send_custom( - cx: &sacp::JrConnectionCx, - method: &str, - params: serde_json::Value, -) -> Result { - let msg = sacp::UntypedMessage::new(method, params).unwrap(); - cx.send_request(msg).block_task().await -} - -#[test] -fn test_custom_session_list() { - run_test(async { - let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; - - let SessionResult { session, .. } = conn.new_session().await; - let session_id = session.session_id().0.clone(); - - // Verify the session exists via _session/get - let get_result = send_custom( - conn.cx(), - "_goose/session/get", - serde_json::json!({ "session_id": session_id }), - ) - .await; - assert!( - get_result.is_ok(), - "session should exist via get: {:?}", - get_result - ); - let get_response = get_result.unwrap(); - assert_eq!( - get_response - .get("session") - .and_then(|s| s.get("id")) - .and_then(|v| v.as_str()), - Some(session_id.as_ref()), - ); - - // Verify _session/list returns a valid response - // Note: list_sessions uses INNER JOIN on messages, so a fresh session - // with no messages won't appear. We just verify the call succeeds. - let result = send_custom(conn.cx(), "_goose/session/list", serde_json::json!({})).await; - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let response = result.unwrap(); - let sessions = response.get("sessions").expect("missing 'sessions' field"); - assert!(sessions.is_array(), "sessions should be array"); - }); -} - #[test] fn test_custom_session_get() { run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + let mut conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; - let SessionResult { session, .. } = conn.new_session().await; + let SessionData { session, .. } = conn.new_session().await.unwrap(); let session_id = session.session_id().0.clone(); let result = send_custom( conn.cx(), - "_goose/session/get", + "session/get", serde_json::json!({ - "session_id": session_id, + "sessionId": session_id, }), ) .await; @@ -87,46 +38,19 @@ fn test_custom_session_get() { }); } -#[test] -fn test_custom_session_delete() { - run_test(async { - let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; - - let SessionResult { session, .. } = conn.new_session().await; - let session_id = session.session_id().0.clone(); - - let result = send_custom( - conn.cx(), - "_goose/session/delete", - serde_json::json!({ "session_id": session_id }), - ) - .await; - assert!(result.is_ok(), "delete failed: {:?}", result); - - let result = send_custom( - conn.cx(), - "_goose/session/get", - serde_json::json!({ "session_id": session_id }), - ) - .await; - assert!(result.is_err(), "expected error for deleted session"); - }); -} - #[test] fn test_custom_get_tools() { run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + let mut conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; - let SessionResult { session, .. } = conn.new_session().await; + let SessionData { session, .. } = conn.new_session().await.unwrap(); let session_id = session.session_id().0.clone(); let result = send_custom( conn.cx(), "_goose/tools", - serde_json::json!({ "session_id": session_id }), + serde_json::json!({ "sessionId": session_id }), ) .await; assert!(result.is_ok(), "expected ok, got: {:?}", result); @@ -141,7 +65,7 @@ fn test_custom_get_tools() { fn test_custom_get_extensions() { run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; let result = send_custom(conn.cx(), "_goose/config/extensions", serde_json::json!({})).await; @@ -163,7 +87,7 @@ fn test_custom_get_extensions() { fn test_custom_unknown_method() { run_test(async { let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; let result = send_custom(conn.cx(), "_unknown/method", serde_json::json!({})).await; assert!(result.is_err(), "expected method_not_found error"); diff --git a/crates/goose-acp/tests/fixtures/mod.rs b/crates/goose-acp/tests/fixtures/mod.rs index 0d5c44989bb0..e7fb12dea8b1 100644 --- a/crates/goose-acp/tests/fixtures/mod.rs +++ b/crates/goose-acp/tests/fixtures/mod.rs @@ -5,6 +5,7 @@ use async_trait::async_trait; use fs_err as fs; pub use goose::acp::{map_permission_response, PermissionDecision, PermissionMapping}; use goose::builtin_extension::register_builtin_extensions; +use goose::config::paths::Paths; use goose::config::{GooseMode, PermissionManager}; use goose::providers::api_client::{ApiClient, AuthMethod as ApiAuthMethod}; use goose::providers::base::Provider; @@ -14,7 +15,7 @@ use goose::session_context::SESSION_ID_HEADER; use goose_acp::server::{serve, GooseAcpAgent}; use goose_test_support::{ExpectedSessionId, TEST_MODEL}; use sacp::schema::{ - AuthMethod, CreateTerminalResponse, KillTerminalCommandResponse, McpServer, + AuthMethod, CreateTerminalResponse, KillTerminalResponse, ListSessionsResponse, McpServer, ReadTextFileRequest, ReadTextFileResponse, ReleaseTerminalResponse, SessionModeState, SessionModelState, SessionUpdate, TerminalExitStatus, TerminalId, TerminalOutputResponse, ToolCallContent, ToolCallStatus, ToolKind, WaitForTerminalExitResponse, WriteTextFileRequest, @@ -155,13 +156,16 @@ pub async fn spawn_acp_server_in_process( data_root: &std::path::Path, goose_mode: GooseMode, provider_factory: Option, + current_model: &str, ) -> (DuplexTransport, JoinHandle<()>, Arc) { fs::create_dir_all(data_root).unwrap(); + // TODO: Paths::in_state_dir is global, ignoring per-test data_root + fs::create_dir_all(Paths::in_state_dir("logs")).unwrap(); let config_path = data_root.join(goose::config::base::CONFIG_YAML_NAME); if !config_path.exists() { fs::write( &config_path, - format!("GOOSE_MODEL: {TEST_MODEL}\nGOOSE_PROVIDER: openai\n"), + format!("GOOSE_MODEL: {current_model}\nGOOSE_PROVIDER: openai\n"), ) .unwrap(); } @@ -180,24 +184,24 @@ pub async fn spawn_acp_server_in_process( }) }); - let agent = Arc::new( - GooseAcpAgent::new( - provider_factory, - builtins.to_vec(), - data_root.to_path_buf(), - data_root.to_path_buf(), - goose_mode, - true, - ) - .await - .unwrap(), - ); + let agent = GooseAcpAgent::new( + provider_factory, + builtins.to_vec(), + data_root.to_path_buf(), + data_root.to_path_buf(), + goose_mode, + true, + ) + .await + .unwrap(); + let agent = Arc::new(agent); let permission_manager = agent.permission_manager(); let (transport, handle) = serve_agent_in_process(agent).await; (transport, handle, permission_manager) } +#[derive(Debug)] pub struct TestOutput { pub text: String, pub tool_status: Option, @@ -437,11 +441,11 @@ impl TerminalFixture { ReleaseTerminalResponse::new() } - pub fn on_kill(&self, terminal_id: &TerminalId) -> KillTerminalCommandResponse { + pub fn on_kill(&self, terminal_id: &TerminalId) -> KillTerminalResponse { if let Some(TerminalCall::Kill(expected_id)) = self.pop("kill") { self.validate_terminal_id("kill", &expected_id, terminal_id); } - KillTerminalCommandResponse::new() + KillTerminalResponse::new() } pub fn assert_called(&self) { @@ -455,7 +459,8 @@ impl TerminalFixture { } } -pub struct SessionResult { +#[derive(Debug)] +pub struct SessionData { pub session: S, pub models: Option, pub modes: Option, @@ -471,6 +476,11 @@ pub struct TestConnectionConfig { pub read_text_file: Option, pub write_text_file: Option, pub terminal: Option>, + // When true, strips config_options from responses to test the legacy set_mode/set_model path. + #[allow(dead_code)] + pub strip_config_options: bool, + // The model the server-side provider starts with. Defaults to TEST_MODEL. + pub current_model: String, } impl Default for TestConnectionConfig { @@ -485,6 +495,8 @@ impl Default for TestConnectionConfig { read_text_file: None, write_text_file: None, terminal: None, + strip_config_options: false, + current_model: TEST_MODEL.to_string(), } } } @@ -495,31 +507,46 @@ pub trait Connection: Sized { fn expected_session_id() -> Arc; async fn new(config: TestConnectionConfig, openai: OpenAiFixture) -> Self; - async fn new_session(&mut self) -> SessionResult; + async fn new_session(&mut self) -> anyhow::Result>; async fn load_session( &mut self, session_id: &str, mcp_servers: Vec, - ) -> SessionResult; + ) -> anyhow::Result>; + async fn list_sessions(&self) -> anyhow::Result; + async fn close_session(&self, session_id: &str) -> anyhow::Result<()>; + async fn delete_session(&self, session_id: &str) -> anyhow::Result<()>; async fn set_mode(&self, session_id: &str, mode_id: &str) -> anyhow::Result<()>; async fn set_model(&self, session_id: &str, model_id: &str) -> anyhow::Result<()>; + async fn set_config_option( + &self, + session_id: &str, + config_id: &str, + value: &str, + ) -> anyhow::Result<()>; fn auth_methods(&self) -> &[AuthMethod]; + fn data_root(&self) -> std::path::PathBuf; fn reset_openai(&self); fn reset_permissions(&self); } #[async_trait] -pub trait Session { +pub trait Session: std::fmt::Debug { fn session_id(&self) -> &sacp::schema::SessionId; + fn work_dir(&self) -> std::path::PathBuf; fn notifications(&self) -> Vec; - async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput; + async fn prompt( + &mut self, + text: &str, + decision: PermissionDecision, + ) -> anyhow::Result; async fn prompt_with_image( &mut self, text: &str, image_b64: &str, mime_type: &str, decision: PermissionDecision, - ) -> TestOutput; + ) -> anyhow::Result; } #[allow(dead_code)] @@ -548,5 +575,14 @@ where } } +pub async fn send_custom( + cx: &sacp::ConnectionTo, + method: &str, + params: serde_json::Value, +) -> Result { + let msg = sacp::UntypedMessage::new(method, params).unwrap(); + cx.send_request(msg).block_task().await +} + pub mod provider; pub mod server; diff --git a/crates/goose-acp/tests/fixtures/provider.rs b/crates/goose-acp/tests/fixtures/provider.rs index e5de5896f6da..d9a38d1e8550 100644 --- a/crates/goose-acp/tests/fixtures/provider.rs +++ b/crates/goose-acp/tests/fixtures/provider.rs @@ -1,56 +1,82 @@ use super::{ - spawn_acp_server_in_process, Connection, OpenAiFixture, PermissionDecision, Session, - SessionResult, TestConnectionConfig, TestOutput, + spawn_acp_server_in_process, Connection, DuplexTransport, OpenAiFixture, PermissionDecision, + Session, SessionData, TestConnectionConfig, TestOutput, }; use async_trait::async_trait; use futures::StreamExt; use goose::acp::{AcpProvider, AcpProviderConfig, PermissionMapping}; -use goose::config::goose_mode::GooseMode; -use goose::config::PermissionManager; +use goose::config::{GooseMode, PermissionManager}; use goose::conversation::message::{ActionRequiredData, Message, MessageContent}; use goose::model::ModelConfig; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::{Permission, PermissionConfirmation}; use goose::providers::base::Provider; -use goose::providers::errors::ProviderError; use goose_test_support::{ExpectedSessionId, IgnoreSessionId, TEST_MODEL}; -use sacp::schema::{AuthMethod, McpServer, SessionUpdate, ToolCallStatus}; +use sacp::schema::{AuthMethod, ListSessionsResponse, McpServer, SessionUpdate, ToolCallStatus}; +use sacp::{Channel, Client, ConnectTo, DynConnectTo}; +use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; +use strum::VariantNames; use tokio::sync::Mutex; pub type NotificationSink = Arc>>; +type SessionModels = Arc>>; #[allow(dead_code)] -pub struct ClientToProviderConnection { - provider: Arc>, +pub struct AcpProviderConnection { + /// Option so close_session can trigger session/close via Drop. + provider: Arc>>, permission_manager: Arc, auth_methods: Vec, session_counter: usize, notification_sink: NotificationSink, + session_models: SessionModels, + work_dir: std::path::PathBuf, + data_root: std::path::PathBuf, _openai: OpenAiFixture, _temp_dir: Option, _cwd: Option, } #[allow(dead_code)] -pub struct ClientToProviderSession { - provider: Arc>, +pub struct AcpProviderSession { + provider: Arc>>, session_id: sacp::schema::SessionId, notification_sink: NotificationSink, + session_models: SessionModels, + work_dir: std::path::PathBuf, } -impl ClientToProviderSession { +impl std::fmt::Debug for AcpProviderSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AcpProviderSession") + .field("session_id", &self.session_id) + .finish() + } +} + +impl AcpProviderSession { #[allow(dead_code)] - async fn send_message(&mut self, message: Message, decision: PermissionDecision) -> TestOutput { + async fn send_message( + &mut self, + message: Message, + decision: PermissionDecision, + ) -> anyhow::Result { let session_id = self.session_id.0.clone(); - let provider = self.provider.lock().await; + let guard = self.provider.lock().await; + let provider = guard.as_ref().unwrap(); self.notification_sink.lock().unwrap().clear(); - let model_config = provider.get_model_config(); + let model_config = self + .session_models + .lock() + .unwrap() + .get(session_id.as_ref()) + .cloned() + .unwrap_or_else(|| provider.get_model_config()); let mut stream = provider .stream(&model_config, &session_id, "", &[message], &[]) - .await - .unwrap(); + .await?; let mut text = String::new(); let mut tool_error = false; let mut saw_tool = false; @@ -101,13 +127,13 @@ impl ClientToProviderSession { None }; - TestOutput { text, tool_status } + Ok(TestOutput { text, tool_status }) } } #[async_trait] -impl Connection for ClientToProviderConnection { - type Session = ClientToProviderSession; +impl Connection for AcpProviderConnection { + type Session = AcpProviderSession; fn expected_session_id() -> Arc { Arc::new(IgnoreSessionId) @@ -125,12 +151,14 @@ impl Connection for ClientToProviderConnection { let goose_mode = config.goose_mode; let mcp_servers = config.mcp_servers; + let current_model = config.current_model.clone(); let (transport, _handle, permission_manager) = spawn_acp_server_in_process( openai.uri(), &config.builtins, data_root.as_path(), goose_mode, config.provider_factory, + ¤t_model, ) .await; @@ -138,31 +166,44 @@ impl Connection for ClientToProviderConnection { .cwd .as_ref() .map(|td| td.path().to_path_buf()) - .unwrap_or(data_root); + .unwrap_or_else(|| data_root.clone()); let notification_sink: NotificationSink = Arc::new(std::sync::Mutex::new(Vec::new())); + let session_models: SessionModels = Arc::new(std::sync::Mutex::new(HashMap::new())); let sink_clone = notification_sink.clone(); let provider_config = AcpProviderConfig { command: "unused".into(), args: vec![], env: vec![], env_remove: vec![], - work_dir: cwd_path, + work_dir: cwd_path.clone(), mcp_servers, session_mode_id: None, + mode_mapping: GooseMode::VARIANTS + .iter() + .map(|v| { + let mode = GooseMode::from_str(v).unwrap(); + (mode, mode.to_string()) + }) + .collect(), permission_mapping: PermissionMapping::default(), notification_callback: Some(Arc::new(move |n| { sink_clone.lock().unwrap().push(n.update.clone()); })), }; + // Server always advertises both configOptions and legacy; only the client fallback needs testing. + let transport: DynConnectTo = if config.strip_config_options { + DynConnectTo::new(strip_config_options(transport)) + } else { + DynConnectTo::new(transport) + }; let provider = AcpProvider::connect_with_transport( "acp-test".to_string(), ModelConfig::new(TEST_MODEL).unwrap(), goose_mode, provider_config, - transport.incoming, - transport.outgoing, + transport, ) .await .unwrap(); @@ -170,18 +211,21 @@ impl Connection for ClientToProviderConnection { let auth_methods = provider.auth_methods().to_vec(); Self { - provider: Arc::new(Mutex::new(provider)), + provider: Arc::new(Mutex::new(Some(provider))), permission_manager, auth_methods, session_counter: 0, notification_sink, + session_models, + work_dir: cwd_path, + data_root, _openai: openai, _temp_dir: temp_dir, _cwd: config.cwd, } } - async fn new_session(&mut self) -> SessionResult { + async fn new_session(&mut self) -> anyhow::Result> { // Tests like run_model_set call new_session() multiple times on the same // connection, so each needs a distinct key to avoid returning a cached session. self.session_counter += 1; @@ -190,58 +234,132 @@ impl Connection for ClientToProviderConnection { .provider .lock() .await + .as_ref() + .unwrap() .ensure_session(Some(&goose_id)) - .await - .unwrap(); + .await?; - let session = ClientToProviderSession { + let session = AcpProviderSession { provider: Arc::clone(&self.provider), session_id: sacp::schema::SessionId::new(goose_id), notification_sink: self.notification_sink.clone(), + session_models: self.session_models.clone(), + work_dir: self.work_dir.clone(), }; - SessionResult { + Ok(SessionData { session, models: response.models, modes: response.modes, - } + }) } async fn load_session( &mut self, _session_id: &str, _mcp_servers: Vec, - ) -> SessionResult { - unimplemented!("TODO: implement load_session in ACP provider") + ) -> anyhow::Result> { + Err(sacp::Error::internal_error() + .data("load_session not implemented for ACP provider") + .into()) } - async fn set_mode(&self, session_id: &str, mode_id: &str) -> anyhow::Result<()> { - let mode = GooseMode::from_str(mode_id).map_err(|_| { - sacp::Error::invalid_params().data(format!("Invalid mode: {}", mode_id)) - })?; + async fn list_sessions(&self) -> anyhow::Result { self.provider .lock() .await - .update_mode(session_id, mode) + .as_ref() + .unwrap() + .list_sessions() .await - .map_err(|e| match e { - ProviderError::RequestFailed(msg) => sacp::Error::invalid_params().data(msg), - other => sacp::Error::internal_error().data(other.to_string()), - })?; + } + + async fn close_session(&self, _session_id: &str) -> anyhow::Result<()> { + // ACP close exists but SessionManager isn't integrated with it; drop the provider instead. + self.provider.lock().await.take(); Ok(()) } - async fn set_model(&self, session_id: &str, model_id: &str) -> anyhow::Result<()> { - let provider = self.provider.lock().await; - let response = provider.ensure_session(Some(session_id)).await?; + async fn delete_session(&self, session_id: &str) -> anyhow::Result<()> { + self.provider + .lock() + .await + .as_ref() + .unwrap() + .delete_session(session_id) + .await + } + + fn data_root(&self) -> std::path::PathBuf { + self.data_root.clone() + } + + async fn set_mode(&self, session_id: &str, mode_id: &str) -> anyhow::Result<()> { + let mode = GooseMode::from_str(mode_id) + .map_err(|_| sacp::Error::invalid_params().data(format!("Invalid mode: {mode_id}")))?; + let guard = self.provider.lock().await; + let provider = guard.as_ref().unwrap(); + if !provider.has_session(session_id).await { + return Err( + sacp::Error::resource_not_found(Some(session_id.to_string())) + .data(format!("Session not found: {session_id}")) + .into(), + ); + } provider - .send_untyped( - "session/set_model", - serde_json::json!({ "sessionId": response.session_id, "modelId": model_id }), - ) - .await?; + .update_mode(session_id, mode) + .await + .map_err(|e| anyhow::anyhow!("{e}")) + } + + async fn set_model(&self, session_id: &str, model_id: &str) -> anyhow::Result<()> { + let config = ModelConfig::new(model_id).map_err(|e| anyhow::anyhow!("{e}"))?; + self.session_models + .lock() + .unwrap() + .insert(session_id.to_string(), config); Ok(()) } + async fn set_config_option( + &self, + session_id: &str, + config_id: &str, + value: &str, + ) -> anyhow::Result<()> { + // Check up front because the "model" branch doesn't go through the provider. + let guard = self.provider.lock().await; + let provider = guard.as_ref().unwrap(); + if !provider.has_session(session_id).await { + return Err( + sacp::Error::resource_not_found(Some(session_id.to_string())) + .data(format!("Session not found: {session_id}")) + .into(), + ); + } + match config_id { + "mode" => { + let mode = GooseMode::from_str(value).map_err(|_| { + sacp::Error::invalid_params().data(format!("Invalid mode: {value}")) + })?; + provider + .update_mode(session_id, mode) + .await + .map_err(|e| anyhow::anyhow!("{e}")) + } + "model" => { + let config = ModelConfig::new(value).map_err(|e| anyhow::anyhow!("{e}"))?; + self.session_models + .lock() + .unwrap() + .insert(session_id.to_string(), config); + Ok(()) + } + other => Err(sacp::Error::invalid_params() + .data(format!("Unsupported config option: {other}")) + .into()), + } + } + fn auth_methods(&self) -> &[AuthMethod] { &self.auth_methods } @@ -251,22 +369,31 @@ impl Connection for ClientToProviderConnection { } fn reset_permissions(&self) { + // "" matches all extensions, clearing all stored permission decisions self.permission_manager.remove_extension(""); } } #[async_trait] -impl Session for ClientToProviderSession { +impl Session for AcpProviderSession { fn session_id(&self) -> &sacp::schema::SessionId { &self.session_id } + fn work_dir(&self) -> std::path::PathBuf { + self.work_dir.clone() + } + fn notifications(&self) -> Vec { - let updates = self.notification_sink.lock().unwrap(); + let updates: Vec<_> = self.notification_sink.lock().unwrap().drain(..).collect(); super::to_notifications(&updates) } - async fn prompt(&mut self, prompt: &str, decision: PermissionDecision) -> TestOutput { + async fn prompt( + &mut self, + prompt: &str, + decision: PermissionDecision, + ) -> anyhow::Result { self.send_message(Message::user().with_text(prompt), decision) .await } @@ -277,10 +404,58 @@ impl Session for ClientToProviderSession { image_b64: &str, mime_type: &str, decision: PermissionDecision, - ) -> TestOutput { + ) -> anyhow::Result { let message = Message::user() .with_image(image_b64, mime_type) .with_text(prompt); self.send_message(message, decision).await } } + +// Strips config_options from responses so goose falls back to legacy set_mode/set_model. +#[allow(dead_code)] +fn strip_config_options(transport: DuplexTransport) -> Channel { + let (server, server_future) = ConnectTo::::into_channel_and_future(transport); + let (client_channel, filter) = Channel::duplex(); + + tokio::spawn(async move { + if let Err(e) = server_future.await { + tracing::error!("config_options filter transport error: {e}"); + } + }); + + tokio::spawn(async move { + let goose_to_server = async { + let mut from_goose = filter.rx; + while let Some(msg) = from_goose.next().await { + if server.tx.unbounded_send(msg).is_err() { + break; + } + } + }; + + let server_to_goose = async { + let mut from_server = server.rx; + while let Some(msg) = from_server.next().await { + let msg = msg.map(|m| match m { + sacp::jsonrpcmsg::Message::Response(mut resp) => { + if let Some(ref mut result) = resp.result { + if let Some(obj) = result.as_object_mut() { + obj.remove("configOptions"); + } + } + sacp::jsonrpcmsg::Message::Response(resp) + } + other => other, + }); + if filter.tx.unbounded_send(msg).is_err() { + break; + } + } + }; + + futures::join!(goose_to_server, server_to_goose); + }); + + client_channel +} diff --git a/crates/goose-acp/tests/fixtures/server.rs b/crates/goose-acp/tests/fixtures/server.rs index 623d0f61c11b..6a6ed2443b8c 100644 --- a/crates/goose-acp/tests/fixtures/server.rs +++ b/crates/goose-acp/tests/fixtures/server.rs @@ -1,28 +1,31 @@ use super::{ map_permission_response, spawn_acp_server_in_process, Connection, PermissionDecision, - PermissionMapping, Session, SessionResult, TestConnectionConfig, TestOutput, + PermissionMapping, Session, SessionData, TestConnectionConfig, TestOutput, }; use async_trait::async_trait; use goose::config::PermissionManager; use goose_test_support::{EnforceSessionId, ExpectedSessionId}; use sacp::schema::{ - AuthMethod, ClientCapabilities, ContentBlock, CreateTerminalRequest, FileSystemCapability, - ImageContent, InitializeRequest, KillTerminalCommandRequest, LoadSessionRequest, McpServer, - NewSessionRequest, PromptRequest, ProtocolVersion, ReadTextFileRequest, ReleaseTerminalRequest, - RequestPermissionRequest, SessionNotification, SessionUpdate, StopReason, - TerminalOutputRequest, TextContent, ToolCallStatus, WaitForTerminalExitRequest, - WriteTextFileRequest, + AuthMethod, ClientCapabilities, CloseSessionRequest, ContentBlock, CreateTerminalRequest, + FileSystemCapabilities, ImageContent, InitializeRequest, KillTerminalRequest, + ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, McpServer, NewSessionRequest, + PromptRequest, ProtocolVersion, ReadTextFileRequest, ReleaseTerminalRequest, + RequestPermissionRequest, SessionConfigOptionValue, SessionId, SessionModeId, + SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, SetSessionModeRequest, + SetSessionModelRequest, StopReason, TerminalOutputRequest, TextContent, ToolCallStatus, + WaitForTerminalExitRequest, WriteTextFileRequest, }; -use sacp::{ClientToAgent, JrConnectionCx}; +use sacp::{Agent, Client, ConnectionTo}; use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::sync::Notify; -pub struct ClientToAgentConnection { - cx: JrConnectionCx, +pub struct AcpServerConnection { + cx: ConnectionTo, // MCP servers from config, consumed by the first new_session call. pending_mcp_servers: Vec, cwd: Option, + data_root: std::path::PathBuf, updates: Arc>>, permission: Arc>, notify: Arc, @@ -32,8 +35,8 @@ pub struct ClientToAgentConnection { _temp_dir: Option, } -pub struct ClientToAgentSession { - cx: JrConnectionCx, +pub struct AcpServerSession { + cx: ConnectionTo, session_id: sacp::schema::SessionId, updates: Arc>>, permission: Arc>, @@ -41,12 +44,20 @@ pub struct ClientToAgentSession { _work_dir: tempfile::TempDir, } -impl ClientToAgentSession { +impl std::fmt::Debug for AcpServerSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AcpServerSession") + .field("session_id", &self.session_id) + .finish() + } +} + +impl AcpServerSession { async fn send_prompt( &mut self, content: Vec, decision: PermissionDecision, - ) -> TestOutput { + ) -> anyhow::Result { *self.permission.lock().unwrap() = decision; self.updates.lock().unwrap().clear(); @@ -54,8 +65,7 @@ impl ClientToAgentSession { .cx .send_request(PromptRequest::new(self.session_id.clone(), content)) .block_task() - .await - .unwrap(); + .await?; assert_eq!(response.stop_reason, StopReason::EndTurn); @@ -73,20 +83,20 @@ impl ClientToAgentSession { tool_status = extract_tool_status(&self.updates); } - TestOutput { text, tool_status } + Ok(TestOutput { text, tool_status }) } } -impl ClientToAgentConnection { +impl AcpServerConnection { #[allow(dead_code)] - pub fn cx(&self) -> &JrConnectionCx { + pub fn cx(&self) -> &ConnectionTo { &self.cx } } #[async_trait] -impl Connection for ClientToAgentConnection { - type Session = ClientToAgentSession; +impl Connection for AcpServerConnection { + type Session = AcpServerSession; fn expected_session_id() -> Arc { Arc::new(EnforceSessionId::default()) @@ -107,6 +117,7 @@ impl Connection for ClientToAgentConnection { data_root.as_path(), config.goose_mode, config.provider_factory, + &config.current_model, ) .await; @@ -114,7 +125,7 @@ impl Connection for ClientToAgentConnection { let notify = Arc::new(Notify::new()); let permission = Arc::new(Mutex::new(PermissionDecision::Cancel)); - let mut fs_cap = FileSystemCapability::default(); + let mut fs_cap = FileSystemCapabilities::default(); if config.read_text_file.is_some() { fs_cap = fs_cap.read_text_file(true); } @@ -130,8 +141,7 @@ impl Connection for ClientToAgentConnection { let write_handler = config.write_text_file; let terminal = config.terminal; - let cx_holder: Arc>>> = - Arc::new(Mutex::new(None)); + let cx_holder: Arc>>> = Arc::new(Mutex::new(None)); let cx_holder_clone = cx_holder.clone(); let auth_holder: Arc>> = Arc::new(Mutex::new(Vec::new())); let auth_holder_clone = auth_holder.clone(); @@ -141,7 +151,8 @@ impl Connection for ClientToAgentConnection { tokio::spawn(async move { let permission_mapping = PermissionMapping::default(); - let result = ClientToAgent::builder() + let result = Client + .builder() .on_receive_notification( { let updates = updates_clone.clone(); @@ -157,45 +168,42 @@ impl Connection for ClientToAgentConnection { .on_receive_request( { let permission = permission_clone.clone(); - async move |req: RequestPermissionRequest, - request_cx, - _connection_cx| { + async move |req: RequestPermissionRequest, responder, _connection_cx| { let decision = *permission.lock().unwrap(); let response = map_permission_response(&permission_mapping, &req, decision); - request_cx.respond(response) + responder.respond(response) } }, sacp::on_receive_request!(), ) .on_receive_request( - async move |req: ReadTextFileRequest, request_cx, _cx| match read_handler { + async move |req: ReadTextFileRequest, responder, _cx| match read_handler { Some(ref rh) => match rh(&req) { - Ok(resp) => request_cx.respond(resp), - Err(msg) => request_cx.respond_with_internal_error(msg), + Ok(resp) => responder.respond(resp), + Err(msg) => responder.respond_with_internal_error(msg), }, - None => request_cx.respond_with_error(sacp::Error::method_not_found()), + None => responder.respond_with_error(sacp::Error::method_not_found()), }, sacp::on_receive_request!(), ) .on_receive_request( - async move |req: WriteTextFileRequest, request_cx, _cx| match write_handler - { + async move |req: WriteTextFileRequest, responder, _cx| match write_handler { Some(ref wh) => match wh(&req) { - Ok(resp) => request_cx.respond(resp), - Err(msg) => request_cx.respond_with_internal_error(msg), + Ok(resp) => responder.respond(resp), + Err(msg) => responder.respond_with_internal_error(msg), }, - None => request_cx.respond_with_error(sacp::Error::method_not_found()), + None => responder.respond_with_error(sacp::Error::method_not_found()), }, sacp::on_receive_request!(), ) .on_receive_request( { let t = terminal.clone(); - async move |req: CreateTerminalRequest, request_cx, _cx| match t { - Some(ref f) => request_cx.respond(f.on_create(&req.command)), + async move |req: CreateTerminalRequest, responder, _cx| match t { + Some(ref f) => responder.respond(f.on_create(&req.command)), None => { - request_cx.respond_with_error(sacp::Error::method_not_found()) + responder.respond_with_error(sacp::Error::method_not_found()) } } }, @@ -204,12 +212,12 @@ impl Connection for ClientToAgentConnection { .on_receive_request( { let t = terminal.clone(); - async move |req: WaitForTerminalExitRequest, request_cx, _cx| match t { + async move |req: WaitForTerminalExitRequest, responder, _cx| match t { Some(ref f) => { - request_cx.respond(f.on_wait_for_exit(&req.terminal_id)) + responder.respond(f.on_wait_for_exit(&req.terminal_id)) } None => { - request_cx.respond_with_error(sacp::Error::method_not_found()) + responder.respond_with_error(sacp::Error::method_not_found()) } } }, @@ -218,10 +226,10 @@ impl Connection for ClientToAgentConnection { .on_receive_request( { let t = terminal.clone(); - async move |req: TerminalOutputRequest, request_cx, _cx| match t { - Some(ref f) => request_cx.respond(f.on_output(&req.terminal_id)), + async move |req: TerminalOutputRequest, responder, _cx| match t { + Some(ref f) => responder.respond(f.on_output(&req.terminal_id)), None => { - request_cx.respond_with_error(sacp::Error::method_not_found()) + responder.respond_with_error(sacp::Error::method_not_found()) } } }, @@ -230,10 +238,10 @@ impl Connection for ClientToAgentConnection { .on_receive_request( { let t = terminal.clone(); - async move |req: ReleaseTerminalRequest, request_cx, _cx| match t { - Some(ref f) => request_cx.respond(f.on_release(&req.terminal_id)), + async move |req: ReleaseTerminalRequest, responder, _cx| match t { + Some(ref f) => responder.respond(f.on_release(&req.terminal_id)), None => { - request_cx.respond_with_error(sacp::Error::method_not_found()) + responder.respond_with_error(sacp::Error::method_not_found()) } } }, @@ -242,21 +250,19 @@ impl Connection for ClientToAgentConnection { .on_receive_request( { let t = terminal.clone(); - async move |req: KillTerminalCommandRequest, request_cx, _cx| match t { - Some(ref f) => request_cx.respond(f.on_kill(&req.terminal_id)), + async move |req: KillTerminalRequest, responder, _cx| match t { + Some(ref f) => responder.respond(f.on_kill(&req.terminal_id)), None => { - request_cx.respond_with_error(sacp::Error::method_not_found()) + responder.respond_with_error(sacp::Error::method_not_found()) } } }, sacp::on_receive_request!(), ) - .connect_to(transport) - .unwrap() - .run_until({ + .connect_with(transport, { let cx_holder = cx_holder_clone; let auth_holder = auth_holder_clone; - move |cx: JrConnectionCx| async move { + async move |cx: ConnectionTo| { let resp = cx .send_request( InitializeRequest::new(ProtocolVersion::LATEST) @@ -294,6 +300,7 @@ impl Connection for ClientToAgentConnection { cx, pending_mcp_servers: config.mcp_servers, cwd: config.cwd, + data_root, updates, permission, notify, @@ -304,7 +311,7 @@ impl Connection for ClientToAgentConnection { } } - async fn new_session(&mut self) -> SessionResult { + async fn new_session(&mut self) -> anyhow::Result> { let work_dir = self .cwd .take() @@ -314,9 +321,8 @@ impl Connection for ClientToAgentConnection { .cx .send_request(NewSessionRequest::new(work_dir.path()).mcp_servers(mcp_servers)) .block_task() - .await - .unwrap(); - let session = ClientToAgentSession { + .await?; + let session = AcpServerSession { cx: self.cx.clone(), session_id: response.session_id.clone(), updates: self.updates.clone(), @@ -324,18 +330,18 @@ impl Connection for ClientToAgentConnection { notify: self.notify.clone(), _work_dir: work_dir, }; - SessionResult { + Ok(SessionData { session, models: response.models, modes: response.modes, - } + }) } async fn load_session( &mut self, session_id: &str, mcp_servers: Vec, - ) -> SessionResult { + ) -> anyhow::Result> { self.updates.lock().unwrap().clear(); let work_dir = tempfile::tempdir().unwrap(); let session_id = sacp::schema::SessionId::new(session_id.to_string()); @@ -346,9 +352,8 @@ impl Connection for ClientToAgentConnection { .mcp_servers(mcp_servers), ) .block_task() - .await - .unwrap(); - let session = ClientToAgentSession { + .await?; + let session = AcpServerSession { cx: self.cx.clone(), session_id, updates: self.updates.clone(), @@ -356,20 +361,47 @@ impl Connection for ClientToAgentConnection { notify: self.notify.clone(), _work_dir: work_dir, }; - SessionResult { + Ok(SessionData { session, models: response.models, modes: response.modes, - } + }) + } + + async fn list_sessions(&self) -> anyhow::Result { + self.cx + .send_request(ListSessionsRequest::new()) + .block_task() + .await + .map_err(|e| e.into()) + } + + async fn close_session(&self, session_id: &str) -> anyhow::Result<()> { + self.cx + .send_request(CloseSessionRequest::new(SessionId::new(session_id))) + .block_task() + .await + .map(|_| ()) + .map_err(|e| e.into()) + } + + async fn delete_session(&self, session_id: &str) -> anyhow::Result<()> { + super::send_custom( + &self.cx, + "session/delete", + serde_json::json!({ "sessionId": session_id }), + ) + .await + .map(|_| ()) + .map_err(|e| e.into()) } async fn set_mode(&self, session_id: &str, mode_id: &str) -> anyhow::Result<()> { - let msg = sacp::UntypedMessage::new( - "session/set_mode", - serde_json::json!({ "sessionId": session_id, "modeId": mode_id }), - )?; self.cx - .send_request(msg) + .send_request(SetSessionModeRequest::new( + SessionId::new(session_id), + SessionModeId::new(mode_id), + )) .block_task() .await .map(|_| ()) @@ -377,12 +409,29 @@ impl Connection for ClientToAgentConnection { } async fn set_model(&self, session_id: &str, model_id: &str) -> anyhow::Result<()> { - let msg = sacp::UntypedMessage::new( - "session/set_model", - serde_json::json!({ "sessionId": session_id, "modelId": model_id }), - )?; self.cx - .send_request(msg) + .send_request(SetSessionModelRequest::new( + SessionId::new(session_id), + model_id.to_string(), + )) + .block_task() + .await + .map(|_| ()) + .map_err(|e| e.into()) + } + + async fn set_config_option( + &self, + session_id: &str, + config_id: &str, + value: &str, + ) -> anyhow::Result<()> { + self.cx + .send_request(SetSessionConfigOptionRequest::new( + SessionId::new(session_id), + config_id.to_string(), + SessionConfigOptionValue::value_id(value.to_string()), + )) .block_task() .await .map(|_| ()) @@ -393,33 +442,46 @@ impl Connection for ClientToAgentConnection { &self.auth_methods } + fn data_root(&self) -> std::path::PathBuf { + self.data_root.clone() + } + fn reset_openai(&self) { self._openai.reset(); } fn reset_permissions(&self) { + // "" matches all extensions, clearing all stored permission decisions self.permission_manager.remove_extension(""); } } #[async_trait] -impl Session for ClientToAgentSession { +impl Session for AcpServerSession { fn session_id(&self) -> &sacp::schema::SessionId { &self.session_id } + fn work_dir(&self) -> std::path::PathBuf { + self._work_dir.path().to_path_buf() + } + fn notifications(&self) -> Vec { let updates: Vec<_> = self .updates .lock() .unwrap() - .iter() - .map(|n| n.update.clone()) + .drain(..) + .map(|n| n.update) .collect(); super::to_notifications(&updates) } - async fn prompt(&mut self, text: &str, decision: PermissionDecision) -> TestOutput { + async fn prompt( + &mut self, + text: &str, + decision: PermissionDecision, + ) -> anyhow::Result { self.send_prompt(vec![ContentBlock::Text(TextContent::new(text))], decision) .await } @@ -430,7 +492,7 @@ impl Session for ClientToAgentSession { image_b64: &str, mime_type: &str, decision: PermissionDecision, - ) -> TestOutput { + ) -> anyhow::Result { self.send_prompt( vec![ ContentBlock::Image(ImageContent::new(image_b64, mime_type)), diff --git a/crates/goose-acp/tests/provider_test.rs b/crates/goose-acp/tests/provider_test.rs index ff0a81395f45..a4fed9592c68 100644 --- a/crates/goose-acp/tests/provider_test.rs +++ b/crates/goose-acp/tests/provider_test.rs @@ -1,120 +1,173 @@ #![recursion_limit = "256"] mod common_tests; -use common_tests::fixtures::provider::ClientToProviderConnection; +use common_tests::fixtures::provider::AcpProviderConnection; use common_tests::fixtures::run_test; use common_tests::{ - run_config_mcp, run_fs_read_text_file_true, run_fs_write_text_file_false, - run_fs_write_text_file_true, run_initialize_doesnt_hit_provider, run_load_mode, run_load_model, - run_load_session_mcp, run_mode_set, run_model_list, run_model_set, run_permission_persistence, - run_prompt_basic, run_prompt_codemode, run_prompt_image, run_prompt_image_attachment, - run_prompt_mcp, run_prompt_skill, run_shell_terminal_false, run_shell_terminal_true, + run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set, + run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false, + run_fs_write_text_file_true, run_initialize_doesnt_hit_provider, run_list_sessions, + run_load_mode, run_load_model, run_load_session_error, run_load_session_mcp, run_mode_set, + run_model_list, run_model_set, run_model_set_error_session_not_found, + run_permission_persistence, run_prompt_basic, run_prompt_codemode, run_prompt_error, + run_prompt_image, run_prompt_image_attachment, run_prompt_mcp, run_prompt_model_mismatch, + run_prompt_skill, run_shell_terminal_false, run_shell_terminal_true, }; -tests_mode_set_error!(ClientToProviderConnection); +tests_config_option_set_error!(AcpProviderConnection); +tests_mode_set_error!(AcpProviderConnection); #[test] fn test_config_mcp() { - run_test(async { run_config_mcp::().await }); + run_test(async { run_config_mcp::().await }); +} + +#[test] +fn test_config_option_mode_set() { + run_test(async { run_config_option_mode_set::().await }); +} + +#[test] +fn test_list_sessions() { + run_test(async { run_list_sessions::().await }); +} + +#[test] +fn test_close_session() { + run_test(async { run_close_session::().await }); +} + +#[test] +fn test_config_option_model_set() { + run_test(async { run_config_option_model_set::().await }); +} + +#[test] +#[ignore = "delete is a server-side custom method not routed through the provider"] +fn test_delete_session() { + run_test(async { run_delete_session::().await }); } #[test] #[ignore = "provider is a plug-in to the goose CLI, UI and terminal clients, none of which handle buffered changes to files"] fn test_fs_read_text_file_true() { - run_test(async { run_fs_read_text_file_true::().await }); + run_test(async { run_fs_read_text_file_true::().await }); } #[test] fn test_fs_write_text_file_false() { - run_test(async { run_fs_write_text_file_false::().await }); + run_test(async { run_fs_write_text_file_false::().await }); } #[test] #[ignore = "provider is a plug-in to the goose CLI, UI and terminal clients, none of which handle buffered changes to files"] fn test_fs_write_text_file_true() { - run_test(async { run_fs_write_text_file_true::().await }); + run_test(async { run_fs_write_text_file_true::().await }); } #[test] fn test_initialize_doesnt_hit_provider() { - run_test(async { run_initialize_doesnt_hit_provider::().await }); + run_test(async { run_initialize_doesnt_hit_provider::().await }); } #[test] #[ignore = "TODO: implement load_session in ACP provider"] fn test_load_mode() { - run_test(async { run_load_mode::().await }); + run_test(async { run_load_mode::().await }); } #[test] #[ignore = "TODO: implement load_session in ACP provider"] fn test_load_model() { - run_test(async { run_load_model::().await }); + run_test(async { run_load_model::().await }); +} + +#[test] +#[ignore = "TODO: implement load_session in ACP provider"] +fn test_load_session_error_session_not_found() { + run_test(async { run_load_session_error::().await }); } #[test] #[ignore = "TODO: implement load_session in ACP provider"] fn test_load_session_mcp() { - run_test(async { run_load_session_mcp::().await }); + run_test(async { run_load_session_mcp::().await }); } #[test] fn test_mode_set() { - run_test(async { run_mode_set::().await }); + run_test(async { run_mode_set::().await }); } #[test] fn test_model_list() { - run_test(async { run_model_list::().await }); + run_test(async { run_model_list::().await }); } #[test] fn test_model_set() { - run_test(async { run_model_set::().await }); + run_test(async { run_model_set::().await }); +} + +#[test] +#[ignore = "ensure_session lazy-creates sessions so deleted ones reappear"] +fn test_model_set_error_session_not_found() { + run_test(async { run_model_set_error_session_not_found::().await }); } #[test] fn test_permission_persistence() { - run_test(async { run_permission_persistence::().await }); + run_test(async { run_permission_persistence::().await }); } #[test] fn test_prompt_basic() { - run_test(async { run_prompt_basic::().await }); + run_test(async { run_prompt_basic::().await }); } #[test] fn test_prompt_codemode() { - run_test(async { run_prompt_codemode::().await }); + run_test(async { run_prompt_codemode::().await }); +} + +#[test] +#[ignore = "ensure_session lazy-creates sessions so deleted ones reappear"] +fn test_prompt_error_session_not_found() { + run_test(async { run_prompt_error::().await }); } #[test] fn test_prompt_image() { - run_test(async { run_prompt_image::().await }); + run_test(async { run_prompt_image::().await }); } #[test] fn test_prompt_image_attachment() { - run_test(async { run_prompt_image_attachment::().await }); + run_test(async { run_prompt_image_attachment::().await }); } #[test] fn test_prompt_mcp() { - run_test(async { run_prompt_mcp::().await }); + run_test(async { run_prompt_mcp::().await }); +} + +#[test] +fn test_prompt_model_mismatch() { + run_test(async { run_prompt_model_mismatch::().await }); } #[test] fn test_prompt_skill() { - run_test(async { run_prompt_skill::().await }); + run_test(async { run_prompt_skill::().await }); } #[test] fn test_shell_terminal_false() { - run_test(async { run_shell_terminal_false::().await }); + run_test(async { run_shell_terminal_false::().await }); } #[test] #[ignore = "provider does not handle terminal delegation requests"] fn test_shell_terminal_true() { - run_test(async { run_shell_terminal_true::().await }); + run_test(async { run_shell_terminal_true::().await }); } diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs index 3f84b2347378..e2a278b58eca 100644 --- a/crates/goose-acp/tests/server_test.rs +++ b/crates/goose-acp/tests/server_test.rs @@ -1,112 +1,161 @@ mod common_tests; use common_tests::fixtures::run_test; -use common_tests::fixtures::server::ClientToAgentConnection; +use common_tests::fixtures::server::AcpServerConnection; use common_tests::{ - run_config_mcp, run_fs_read_text_file_true, run_fs_write_text_file_false, - run_fs_write_text_file_true, run_initialize_doesnt_hit_provider, run_load_mode, run_load_model, - run_load_session_mcp, run_mode_set, run_model_list, run_model_set, run_permission_persistence, - run_prompt_basic, run_prompt_codemode, run_prompt_image, run_prompt_image_attachment, - run_prompt_mcp, run_prompt_skill, run_shell_terminal_false, run_shell_terminal_true, + run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set, + run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false, + run_fs_write_text_file_true, run_initialize_doesnt_hit_provider, run_list_sessions, + run_load_mode, run_load_model, run_load_session_error, run_load_session_mcp, run_mode_set, + run_model_list, run_model_set, run_model_set_error_session_not_found, + run_permission_persistence, run_prompt_basic, run_prompt_codemode, run_prompt_error, + run_prompt_image, run_prompt_image_attachment, run_prompt_mcp, run_prompt_model_mismatch, + run_prompt_skill, run_shell_terminal_false, run_shell_terminal_true, }; -tests_mode_set_error!(ClientToAgentConnection); +tests_config_option_set_error!(AcpServerConnection); +tests_mode_set_error!(AcpServerConnection); #[test] fn test_config_mcp() { - run_test(async { run_config_mcp::().await }); + run_test(async { run_config_mcp::().await }); +} + +#[test] +fn test_config_option_mode_set() { + run_test(async { run_config_option_mode_set::().await }); +} + +#[test] +fn test_list_sessions() { + run_test(async { run_list_sessions::().await }); +} + +#[test] +fn test_close_session() { + run_test(async { run_close_session::().await }); +} + +#[test] +fn test_config_option_model_set() { + run_test(async { run_config_option_model_set::().await }); +} + +#[test] +fn test_delete_session() { + run_test(async { run_delete_session::().await }); } #[test] fn test_fs_read_text_file_true() { - run_test(async { run_fs_read_text_file_true::().await }); + run_test(async { run_fs_read_text_file_true::().await }); } #[test] fn test_fs_write_text_file_false() { - run_test(async { run_fs_write_text_file_false::().await }); + run_test(async { run_fs_write_text_file_false::().await }); } #[test] fn test_fs_write_text_file_true() { - run_test(async { run_fs_write_text_file_true::().await }); + run_test(async { run_fs_write_text_file_true::().await }); } #[test] fn test_initialize_doesnt_hit_provider() { - run_test(async { run_initialize_doesnt_hit_provider::().await }); + run_test(async { run_initialize_doesnt_hit_provider::().await }); } #[test] fn test_load_mode() { - run_test(async { run_load_mode::().await }); + run_test(async { run_load_mode::().await }); } #[test] fn test_load_model() { - run_test(async { run_load_model::().await }); + run_test(async { run_load_model::().await }); +} + +#[test] +fn test_load_session_error_session_not_found() { + run_test(async { run_load_session_error::().await }); } #[test] fn test_load_session_mcp() { - run_test(async { run_load_session_mcp::().await }); + run_test(async { run_load_session_mcp::().await }); } #[test] fn test_mode_set() { - run_test(async { run_mode_set::().await }); + run_test(async { run_mode_set::().await }); } #[test] fn test_model_list() { - run_test(async { run_model_list::().await }); + run_test(async { run_model_list::().await }); } #[test] fn test_model_set() { - run_test(async { run_model_set::().await }); + run_test(async { run_model_set::().await }); +} + +#[test] +fn test_model_set_error_session_not_found() { + run_test(async { run_model_set_error_session_not_found::().await }); } #[test] fn test_permission_persistence() { - run_test(async { run_permission_persistence::().await }); + run_test(async { run_permission_persistence::().await }); } #[test] fn test_prompt_basic() { - run_test(async { run_prompt_basic::().await }); + run_test(async { run_prompt_basic::().await }); } #[test] fn test_prompt_codemode() { - run_test(async { run_prompt_codemode::().await }); + run_test(async { run_prompt_codemode::().await }); +} + +#[test] +fn test_prompt_error_session_not_found() { + run_test(async { run_prompt_error::().await }); } #[test] fn test_prompt_image() { - run_test(async { run_prompt_image::().await }); + run_test(async { run_prompt_image::().await }); } #[test] fn test_prompt_image_attachment() { - run_test(async { run_prompt_image_attachment::().await }); + run_test(async { run_prompt_image_attachment::().await }); } #[test] fn test_prompt_mcp() { - run_test(async { run_prompt_mcp::().await }); + run_test(async { run_prompt_mcp::().await }); +} + +#[test] +fn test_prompt_model_mismatch() { + run_test(async { run_prompt_model_mismatch::().await }); } #[test] fn test_prompt_skill() { - run_test(async { run_prompt_skill::().await }); + run_test(async { run_prompt_skill::().await }); } #[test] fn test_shell_terminal_false() { - run_test(async { run_shell_terminal_false::().await }); + run_test(async { run_shell_terminal_false::().await }); } #[test] fn test_shell_terminal_true() { - run_test(async { run_shell_terminal_true::().await }); + run_test(async { run_shell_terminal_true::().await }); } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 10c5f3a8072d..b3969a135d11 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -578,7 +578,7 @@ impl CliSession { } InputResult::GooseMode(mode) => { history.save(editor); - self.handle_goose_mode(&mode)?; + self.handle_goose_mode(&mode).await?; } InputResult::Plan(options) => { self.handle_plan_mode(options).await?; @@ -712,7 +712,7 @@ impl CliSession { } } - fn handle_goose_mode(&self, mode: &str) -> Result<()> { + async fn handle_goose_mode(&self, mode: &str) -> Result<()> { let config = Config::global(); let mode = match GooseMode::from_str(&mode.to_lowercase()) { Ok(mode) => mode, @@ -724,6 +724,7 @@ impl CliSession { return Ok(()); } }; + self.agent.update_goose_mode(mode, &self.session_id).await?; config.set_goose_mode(mode)?; output::goose_mode_message(&format!("Goose mode set to '{mode}'")); Ok(()) diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index ca94fe8d8b04..bf8565eec989 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -71,6 +71,7 @@ serde_yaml = { workspace = true } strum = { workspace = true } once_cell = { workspace = true } etcetera = { workspace = true } +fs-err = "3" rand = { workspace = true } utoipa = { workspace = true, features = ["chrono"] } tokio-cron-scheduler = "0.14.0" @@ -103,8 +104,8 @@ tempfile = { workspace = true } dashmap = "6.1" ahash = "0.8" tokio-util = { workspace = true, features = ["compat"] } -sacp = { workspace = true } -agent-client-protocol-schema = { version = "0.10", features = ["unstable"] } +agent-client-protocol-schema = { workspace = true } +sacp = { workspace = true, features = ["unstable"] } unicode-normalization = "0.1" # For local Whisper transcription @@ -188,7 +189,7 @@ path = "src/providers/canonical/build_canonical_models.rs" ignored = [ # Used only on windows - "winapi", - # Used to provide sacp additional schemas to deserialization - "agent-client-protocol-schema" + "winapi", + # Used to provide extras imports for sacp + "agent-client-protocol-schema", ] diff --git a/crates/goose/src/acp/mod.rs b/crates/goose/src/acp/mod.rs index a29d207dd4e8..4692125e532c 100644 --- a/crates/goose/src/acp/mod.rs +++ b/crates/goose/src/acp/mod.rs @@ -2,4 +2,6 @@ mod common; mod provider; pub use common::{map_permission_response, PermissionDecision, PermissionMapping}; -pub use provider::{extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig}; +pub use provider::{ + extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig, ACP_CURRENT_MODEL, +}; diff --git a/crates/goose/src/acp/provider.rs b/crates/goose/src/acp/provider.rs index 1619d9e3fe5e..4461351c6f9e 100644 --- a/crates/goose/src/acp/provider.rs +++ b/crates/goose/src/acp/provider.rs @@ -1,22 +1,27 @@ +use agent_client_protocol_schema::AGENT_METHOD_NAMES; use anyhow::{Context, Result}; use async_stream::try_stream; +use futures::future::BoxFuture; use rmcp::model::{Role, Tool}; use sacp::schema::{ - AuthMethod, ContentBlock, ContentChunk, EnvVariable, HttpHeader, ImageContent, - InitializeRequest, InitializeResponse, McpCapabilities, McpServer, McpServerHttp, - McpServerStdio, NewSessionRequest, NewSessionResponse, PromptRequest, ProtocolVersion, - RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse, SessionId, - SessionNotification, SessionUpdate, SetSessionModeRequest, StopReason, TextContent, - ToolCallContent, + AuthMethod, CloseSessionRequest, ContentBlock, ContentChunk, EnvVariable, HttpHeader, + ImageContent, InitializeRequest, InitializeResponse, ListSessionsRequest, ListSessionsResponse, + McpCapabilities, McpServer, McpServerHttp, McpServerStdio, NewSessionRequest, + NewSessionResponse, PromptRequest, PromptResponse, ProtocolVersion, RequestPermissionOutcome, + RequestPermissionRequest, RequestPermissionResponse, SessionConfigKind, + SessionConfigOptionCategory, SessionConfigSelectOptions, SessionId, SessionNotification, + SessionUpdate, SetSessionConfigOptionRequest, SetSessionModeRequest, SetSessionModeResponse, + SetSessionModelRequest, StopReason, TextContent, ToolCallContent, }; -use sacp::{ClientToAgent, JrConnectionCx}; +use sacp::{Agent, Client, ConnectionTo}; use std::collections::{HashMap, HashSet}; +use std::future::Future; use std::path::PathBuf; use std::process::Stdio; -use std::str::FromStr; use std::sync::{Arc, Mutex}; +use std::thread::JoinHandle; use tokio::process::{Child, Command}; -use tokio::sync::{mpsc, oneshot, Mutex as TokioMutex}; +use tokio::sync::{mpsc, oneshot, Mutex as TokioMutex, OnceCell}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use crate::acp::{map_permission_response, PermissionDecision, PermissionMapping}; @@ -28,6 +33,9 @@ use crate::permission::{Permission, PermissionConfirmation}; use crate::providers::base::{MessageStream, PermissionRouting, Provider}; use crate::providers::errors::ProviderError; +/// Sentinel: resolved to SessionModelState.current_model_id at connect time. +pub const ACP_CURRENT_MODEL: &str = "current"; + pub struct AcpProviderConfig { pub command: PathBuf, pub args: Vec, @@ -36,6 +44,7 @@ pub struct AcpProviderConfig { pub work_dir: PathBuf, pub mcp_servers: Vec, pub session_mode_id: Option, + pub mode_mapping: HashMap, pub permission_mapping: PermissionMapping, pub notification_callback: Option>, } @@ -44,19 +53,53 @@ enum ClientRequest { NewSession { response_tx: oneshot::Sender>, }, - Untyped { - method: String, - params: serde_json::Value, - response_tx: oneshot::Sender>, + ListSessions { + response_tx: oneshot::Sender>, + }, + SetMode { + session_id: SessionId, + mode_id: String, + response_tx: oneshot::Sender>, + }, + SetModel { + session_id: SessionId, + model_id: String, + response_tx: oneshot::Sender>, + }, + SetConfigOption { + session_id: SessionId, + config_id: String, + value: String, + response_tx: oneshot::Sender>, }, Prompt { session_id: SessionId, content: Vec, response_tx: mpsc::Sender, }, - Shutdown, + CloseSession { + session_id: SessionId, + response_tx: oneshot::Sender>, + }, + // For ACP methods not yet in agent-client-protocol-schema (e.g. session/delete) + Untyped { + method: String, + params: serde_json::Value, + response_tx: oneshot::Sender>, + }, } +// tokio I/O handles can't move between runtimes, so the child process must be +// spawned inside the OS thread. This closure lets start() share all other logic. +type ClientLoopFn = Box< + dyn FnOnce( + AcpClientLoop, + mpsc::Receiver, + oneshot::Sender>, + ) -> BoxFuture<'static, ()> + + Send, +>; + #[derive(Debug)] enum AcpUpdate { Text(String), @@ -79,13 +122,20 @@ pub struct AcpProvider { name: String, model: ModelConfig, goose_mode: Arc>, - tx: mpsc::Sender, + tx: Option>, + loop_thread: Option>, + mode_mapping: HashMap, permission_mapping: PermissionMapping, rejected_tool_calls: Arc>>, pending_confirmations: Arc>>>, goose_to_acp_id: Arc>>, + acp_to_goose_id: Arc>>, + /// Per-session model tracking for detecting model changes in stream(). + session_model: Arc>>, auth_methods: Vec, + supports_close: bool, + init_session: OnceCell, } impl std::fmt::Debug for AcpProvider { @@ -97,6 +147,18 @@ impl std::fmt::Debug for AcpProvider { } } +// Dedicated runtime on an OS thread so session/close completes even during +// main runtime shutdown. See reqwest InnerClientHandle. +fn spawn_client_loop(fut: impl Future + Send + 'static) -> JoinHandle<()> { + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to build ACP client runtime"); + rt.block_on(fut) + }) +} + impl AcpProvider { pub async fn connect( name: String, @@ -104,89 +166,115 @@ impl AcpProvider { goose_mode: GooseMode, config: AcpProviderConfig, ) -> Result { - let (tx, rx) = mpsc::channel(32); - let (init_tx, init_rx) = oneshot::channel(); - let permission_mapping = config.permission_mapping.clone(); - let rejected_tool_calls = Arc::new(TokioMutex::new(HashSet::new())); - let goose_mode = Arc::new(Mutex::new(goose_mode)); - - let client_loop = AcpClientLoop::new(config, goose_mode.clone()); - tokio::spawn(client_loop.spawn(rx, init_tx)); - - let init_response = init_rx - .await - .context("ACP client initialization cancelled")??; + Self::start( + name, + model, + goose_mode, + config, + Box::new(|cl, rx, init_tx| Box::pin(cl.spawn(rx, init_tx))), + ) + .await + } - Ok(Self::new_with_runtime( + #[doc(hidden)] + pub async fn connect_with_transport( + name: String, + model: ModelConfig, + goose_mode: GooseMode, + config: AcpProviderConfig, + transport: impl sacp::ConnectTo + 'static, + ) -> Result { + Self::start( name, model, goose_mode, - tx, - permission_mapping, - rejected_tool_calls, - init_response.auth_methods, - )) + config, + Box::new(move |cl, mut rx, init_tx| { + Box::pin(async move { + if let Err(e) = cl.run(transport, &mut rx, init_tx).await { + tracing::error!("ACP protocol error: {e}"); + } + }) + }), + ) + .await } - pub async fn connect_with_transport( + async fn start( name: String, model: ModelConfig, goose_mode: GooseMode, config: AcpProviderConfig, - read: R, - write: W, - ) -> Result - where - R: futures::AsyncRead + Unpin + Send + 'static, - W: futures::AsyncWrite + Unpin + Send + 'static, - { - let (tx, mut rx) = mpsc::channel(32); + run: ClientLoopFn, + ) -> Result { + let (tx, rx) = mpsc::channel(32); let (init_tx, init_rx) = oneshot::channel(); + let mode_mapping = config.mode_mapping.clone(); let permission_mapping = config.permission_mapping.clone(); let rejected_tool_calls = Arc::new(TokioMutex::new(HashSet::new())); let goose_mode = Arc::new(Mutex::new(goose_mode)); - let transport = sacp::ByteStreams::new(write, read); let client_loop = AcpClientLoop::new(config, goose_mode.clone()); - tokio::spawn(async move { - if let Err(e) = client_loop.run(transport, &mut rx, init_tx).await { - tracing::error!("ACP protocol error: {e}"); - } - }); + let loop_thread = spawn_client_loop(run(client_loop, rx, init_tx)); let init_response = init_rx .await .context("ACP client initialization cancelled")??; - Ok(Self::new_with_runtime( + let supports_close = init_response + .agent_capabilities + .session_capabilities + .close + .is_some(); + let mut provider = Self::new_with_runtime( name, model, goose_mode, tx, + loop_thread, + mode_mapping, permission_mapping, rejected_tool_calls, init_response.auth_methods, - )) + supports_close, + ); + if provider.model.model_name == ACP_CURRENT_MODEL { + let response = provider.get_init_session().await?; + let (current_model, _) = resolve_model_info(&provider.name, response)?; + tracing::info!(from = ACP_CURRENT_MODEL, to = %current_model, "resolved ACP model"); + provider.model.model_name = current_model; + } + Ok(provider) } + #[allow(clippy::too_many_arguments)] fn new_with_runtime( name: String, model: ModelConfig, goose_mode: Arc>, tx: mpsc::Sender, + loop_thread: JoinHandle<()>, + mode_mapping: HashMap, permission_mapping: PermissionMapping, rejected_tool_calls: Arc>>, auth_methods: Vec, + supports_close: bool, ) -> Self { Self { name, model, goose_mode, - tx, + tx: Some(tx), + loop_thread: Some(loop_thread), + mode_mapping, permission_mapping, rejected_tool_calls, pending_confirmations: Arc::new(TokioMutex::new(HashMap::new())), goose_to_acp_id: Arc::new(TokioMutex::new(HashMap::new())), + acp_to_goose_id: Arc::new(TokioMutex::new(HashMap::new())), + session_model: Arc::new(TokioMutex::new(HashMap::new())), auth_methods, + supports_close, + init_session: OnceCell::new(), } } @@ -197,10 +285,112 @@ impl AcpProvider { pub async fn new_session(&self) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.tx + .as_ref() + .unwrap() .send(ClientRequest::NewSession { response_tx }) .await .context("ACP client is unavailable")?; - response_rx.await.context("ACP session/new cancelled")? + response_rx + .await + .context(format!("ACP {} cancelled", AGENT_METHOD_NAMES.session_new))? + } + + pub async fn list_sessions(&self) -> Result { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .as_ref() + .unwrap() + .send(ClientRequest::ListSessions { response_tx }) + .await + .context("ACP client is unavailable")?; + let raw = response_rx.await.context("ACP request cancelled")??; + let acp_to_goose = self.acp_to_goose_id.lock().await; + Ok(map_sessions_to_goose_ids(raw, &acp_to_goose)) + } + + async fn resolve_acp_session_id(&self, goose_id: &str) -> Result { + let map = self.goose_to_acp_id.lock().await; + map.get(goose_id) + .map(|r| r.session_id.clone()) + .ok_or_else(|| { + sacp::Error::resource_not_found(Some(goose_id.to_string())) + .data(format!("Session not found: {goose_id}")) + .into() + }) + } + + pub(crate) async fn send_set_mode(&self, goose_id: &str, mode_id: String) -> Result<()> { + let session_id = self.resolve_acp_session_id(goose_id).await?; + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .as_ref() + .unwrap() + .send(ClientRequest::SetMode { + session_id, + mode_id, + response_tx, + }) + .await + .context("ACP client is unavailable")?; + response_rx.await.context("ACP request cancelled")? + } + + pub(crate) async fn send_set_model(&self, goose_id: &str, model_id: String) -> Result<()> { + let session_id = self.resolve_acp_session_id(goose_id).await?; + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .as_ref() + .unwrap() + .send(ClientRequest::SetModel { + session_id, + model_id, + response_tx, + }) + .await + .context("ACP client is unavailable")?; + response_rx.await.context("ACP request cancelled")? + } + + pub(crate) async fn send_set_config_option( + &self, + goose_id: &str, + config_id: String, + value: String, + ) -> Result<()> { + let session_id = self.resolve_acp_session_id(goose_id).await?; + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .as_ref() + .unwrap() + .send(ClientRequest::SetConfigOption { + session_id, + config_id, + value, + response_tx, + }) + .await + .context("ACP client is unavailable")?; + response_rx.await.context("ACP request cancelled")? + } + + // Only used by tests; session/delete has no typed request in agent-client-protocol-schema yet. + #[doc(hidden)] + pub async fn delete_session(&self, goose_id: &str) -> Result<()> { + let session_id = self.resolve_acp_session_id(goose_id).await?; + self.send_untyped( + "session/delete", + serde_json::json!({ "sessionId": session_id.0 }), + ) + .await?; + + // Clean up cached mappings so ensure_session doesn't return a stale entry. + self.goose_to_acp_id.lock().await.remove(goose_id); + self.acp_to_goose_id + .lock() + .await + .remove(session_id.0.as_ref()); + self.session_model.lock().await.remove(goose_id); + Ok(()) } pub async fn send_untyped( @@ -210,6 +400,8 @@ impl AcpProvider { ) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.tx + .as_ref() + .unwrap() .send(ClientRequest::Untyped { method: method.to_string(), params, @@ -220,6 +412,22 @@ impl AcpProvider { response_rx.await.context("ACP request cancelled")? } + pub async fn has_session(&self, goose_id: &str) -> bool { + self.goose_to_acp_id.lock().await.contains_key(goose_id) + } + + // If false, callers fall back to legacy set_mode/set_model. + async fn session_has_config_option( + &self, + goose_id: &str, + category: SessionConfigOptionCategory, + ) -> bool { + let map = self.goose_to_acp_id.lock().await; + map.get(goose_id) + .and_then(|r| r.config_options.as_ref()) + .is_some_and(|opts| opts.iter().any(|o| o.category.as_ref() == Some(&category))) + } + pub async fn handle_permission_confirmation( &self, request_id: &str, @@ -252,6 +460,18 @@ impl AcpProvider { .lock() .await .insert(session_id.to_string(), response.clone()); + self.acp_to_goose_id + .lock() + .await + .insert(response.session_id.0.to_string(), session_id.to_string()); + + // Initialize model tracking so stream() can detect changes. + let (current_model, _) = resolve_model_info(&self.name, &response)?; + self.session_model + .lock() + .await + .entry(session_id.to_string()) + .or_insert(current_model); } Ok(response) @@ -264,6 +484,8 @@ impl AcpProvider { ) -> Result> { let (response_tx, response_rx) = mpsc::channel(64); self.tx + .as_ref() + .unwrap() .send(ClientRequest::Prompt { session_id, content, @@ -273,6 +495,33 @@ impl AcpProvider { .context("ACP client is unavailable")?; Ok(response_rx) } + + async fn get_init_session(&self) -> Result<&NewSessionResponse> { + self.init_session + .get_or_try_init(|| async { + let response = self.new_session().await?; + if self.supports_close { + self.close_session_by_acp_id(response.session_id.clone()) + .await?; + } + Ok(response) + }) + .await + } + + async fn close_session_by_acp_id(&self, session_id: SessionId) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .as_ref() + .unwrap() + .send(ClientRequest::CloseSession { + session_id, + response_tx, + }) + .await + .context("ACP client is unavailable")?; + response_rx.await.context("ACP request cancelled")? + } } #[async_trait::async_trait] @@ -291,21 +540,25 @@ impl Provider for AcpProvider { // Pre-initialization: no ACP session yet, just store the mode. // The shared Arc> is read at session creation time. drop(map); - } else if let Some(acp_session_id) = map.get(session_id).map(|r| r.session_id.clone()) { - drop(map); - self.send_untyped( - "session/set_mode", - serde_json::json!({ - "sessionId": acp_session_id, - "modeId": mode.to_string().to_lowercase() - }), - ) - .await - .map_err(|e| ProviderError::RequestFailed(format!("Failed to set mode: {e}")))?; } else { - return Err(ProviderError::RequestFailed(format!( - "Session not found: {session_id}" - ))); + drop(map); + let mode_str = self.mode_mapping[&mode].clone(); + if self + .session_has_config_option(session_id, SessionConfigOptionCategory::Mode) + .await + { + self.send_set_config_option(session_id, "mode".into(), mode_str) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to set mode: {e}")) + })?; + } else { + self.send_set_mode(session_id, mode_str) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to set mode: {e}")) + })?; + } } let mut current = self @@ -330,13 +583,42 @@ impl Provider for AcpProvider { async fn stream( &self, - _model_config: &ModelConfig, + model_config: &ModelConfig, session_id: &str, _system: &str, messages: &[Message], _tools: &[Tool], ) -> Result { let response = self.ensure_session(Some(session_id)).await?; + + // Provider trait has no update_model — stream() is the only place to forward model changes. + { + let new_model = &model_config.model_name; + let tracked = self.session_model.lock().await.get(session_id).cloned(); + if tracked.as_deref() != Some(new_model) { + if self + .session_has_config_option(session_id, SessionConfigOptionCategory::Model) + .await + { + self.send_set_config_option(session_id, "model".into(), new_model.clone()) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to set model: {e}")) + })?; + } else { + self.send_set_model(session_id, new_model.clone()) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to set model: {e}")) + })?; + } + self.session_model + .lock() + .await + .insert(session_id.to_string(), new_model.clone()); + } + } + let prompt_blocks = messages_to_prompt(messages); let mut rx = self .prompt(response.session_id, prompt_blocks) @@ -437,26 +719,23 @@ impl Provider for AcpProvider { } async fn fetch_supported_models(&self) -> Result, ProviderError> { - let response = self.ensure_session(None).await?; - Ok(response - .models - .map(|state| { - state - .available_models - .iter() - .map(|m| m.model_id.0.to_string()) - .collect() - }) - .unwrap_or_default()) + let response = self.get_init_session().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to create ACP session: {e}")) + })?; + let (_, available) = resolve_model_info(&self.name, response)?; + Ok(available) } } impl Drop for AcpProvider { fn drop(&mut self) { - let tx = self.tx.clone(); - tokio::spawn(async move { - let _ = tx.send(ClientRequest::Shutdown).await; - }); + // Join OS thread so session/close completes before runtime exits (reqwest InnerClientHandle pattern). + self.tx.take(); + if let Some(h) = self.loop_thread.take() { + if let Err(e) = h.join() { + tracing::debug!("AcpClientLoop thread panicked: {e:?}"); + } + } } } @@ -507,39 +786,62 @@ impl AcpClientLoop { self.run(transport, rx, init_tx).await } - async fn run( + async fn run( self, - transport: sacp::ByteStreams, + transport: impl sacp::ConnectTo + 'static, rx: &mut mpsc::Receiver, init_tx: oneshot::Sender>, - ) -> Result<()> - where - R: futures::AsyncRead + Unpin + Send + 'static, - W: futures::AsyncWrite + Unpin + Send + 'static, - { + ) -> Result<()> { let AcpClientLoop { config, goose_mode, prompt_response_tx, } = self; let notification_callback = config.notification_callback.clone(); + let reverse_modes = reverse_mode_mapping(&config.mode_mapping); - ClientToAgent::builder() + Client + .builder() .on_receive_notification( { let prompt_response_tx = prompt_response_tx.clone(); + let reverse_modes = reverse_modes.clone(); async move |notification: SessionNotification, _cx| { if let Some(ref cb) = notification_callback { cb(notification.clone()); } // stream() reads goose_mode at call time, so it must // reflect any prior set_mode before the next prompt. - if let SessionUpdate::CurrentModeUpdate(update) = ¬ification.update { - if let Ok(mode) = GooseMode::from_str(&update.current_mode_id.0) { - if let Ok(mut guard) = goose_mode.lock() { - *guard = mode; + match ¬ification.update { + SessionUpdate::CurrentModeUpdate(update) => { + if let Some(mode) = resolve_mode( + &reverse_modes, + update.current_mode_id.0.as_ref(), + &goose_mode, + ) { + if let Ok(mut guard) = goose_mode.lock() { + *guard = mode; + } + } + } + SessionUpdate::ConfigOptionUpdate(update) => { + for opt in &update.config_options { + if opt.category == Some(SessionConfigOptionCategory::Mode) { + if let SessionConfigKind::Select(sel) = &opt.kind { + if let Some(mode) = resolve_mode( + &reverse_modes, + sel.current_value.0.as_ref(), + &goose_mode, + ) { + if let Ok(mut guard) = goose_mode.lock() { + *guard = mode; + } + } + } + } } } + _ => {} } if let Some(tx) = prompt_response_tx .lock() @@ -583,7 +885,7 @@ impl AcpClientLoop { .on_receive_request( { let prompt_response_tx = prompt_response_tx.clone(); - async move |request: RequestPermissionRequest, request_cx, _connection_cx| { + async move |request: RequestPermissionRequest, responder, _connection_cx| { let (response_tx, response_rx) = oneshot::channel(); let handler = prompt_response_tx @@ -606,14 +908,13 @@ impl AcpClientLoop { let response = response_rx.await.unwrap_or_else(|_| { RequestPermissionResponse::new(RequestPermissionOutcome::Cancelled) }); - request_cx.respond(response) + responder.respond(response) } }, sacp::on_receive_request!(), ) - .connect_to(transport)? - .run_until(move |cx: JrConnectionCx| { - handle_requests(config, cx, rx, prompt_response_tx, init_tx) + .connect_with(transport, async move |cx: ConnectionTo| { + handle_requests(config, cx, rx, prompt_response_tx, init_tx).await }) .await?; @@ -640,53 +941,156 @@ async fn spawn_acp_process(config: &AcpProviderConfig) -> Result { cmd.spawn().context("failed to spawn ACP process") } +// sacp panics on Err from connect_with handlers, so log send failures instead of ?. +fn log_undelivered(result: Result<(), E>, method: &str) { + if let Err(e) = result { + tracing::debug!(method, error = ?e, "response not delivered"); + } +} + async fn handle_requests( config: AcpProviderConfig, - cx: JrConnectionCx, + cx: ConnectionTo, rx: &mut mpsc::Receiver, prompt_response_tx: Arc>>>, init_tx: oneshot::Sender>, ) -> Result<(), sacp::Error> { let mut init_tx = Some(init_tx); - let init_response = cx + let init_response: InitializeResponse = cx .send_request(InitializeRequest::new(ProtocolVersion::LATEST)) .block_task() .await .map_err(|err| { - let message = format!("ACP initialize failed: {err}"); + let message = format!("ACP {} failed: {err}", AGENT_METHOD_NAMES.initialize); + // Attempt to send a specific error to the ctor waiting on init_rx; if let Some(tx) = init_tx.take() { let _ = tx.send(Err(anyhow::anyhow!(message.clone()))); } sacp::Error::internal_error().data(message) })?; + let supports_close = init_response + .agent_capabilities + .session_capabilities + .close + .is_some(); let mcp_capabilities = init_response.agent_capabilities.mcp_capabilities.clone(); if let Some(tx) = init_tx.take() { - let _ = tx.send(Ok(init_response)); + log_undelivered(tx.send(Ok(init_response)), AGENT_METHOD_NAMES.initialize); } + let mut session_ids: Vec = Vec::new(); + while let Some(request) = rx.recv().await { match request { ClientRequest::NewSession { response_tx } => { - handle_new_session_request(&config, &cx, &mcp_capabilities, response_tx).await; + let mcp_servers = filter_supported_servers(&config.mcp_servers, &mcp_capabilities); + let session = cx + .send_request( + NewSessionRequest::new(config.work_dir.clone()).mcp_servers(mcp_servers), + ) + .block_task() + .await; + let result = match session { + Ok(session) => { + session_ids.push(session.session_id.clone()); + apply_session_mode(&config, &cx, session).await + } + Err(err) => Err(anyhow::anyhow!( + "ACP {} failed: {err}", + AGENT_METHOD_NAMES.session_new + )), + }; + log_undelivered(response_tx.send(result), AGENT_METHOD_NAMES.session_new); + } + ClientRequest::ListSessions { response_tx } => { + let result: Result = cx + .send_request(ListSessionsRequest::new()) + .block_task() + .await + .map_err(anyhow::Error::from); + log_undelivered(response_tx.send(result), AGENT_METHOD_NAMES.session_list); + } + ClientRequest::SetMode { + session_id, + mode_id, + response_tx, + } => { + let result: Result<()> = cx + .send_request(SetSessionModeRequest::new(session_id, mode_id)) + .block_task() + .await + .map(|_| ()) + .map_err(anyhow::Error::from); + log_undelivered( + response_tx.send(result), + AGENT_METHOD_NAMES.session_set_mode, + ); + } + ClientRequest::SetModel { + session_id, + model_id, + response_tx, + } => { + let result: Result<()> = cx + .send_request(SetSessionModelRequest::new(session_id, model_id)) + .block_task() + .await + .map(|_| ()) + .map_err(anyhow::Error::from); + log_undelivered( + response_tx.send(result), + AGENT_METHOD_NAMES.session_set_model, + ); + } + ClientRequest::SetConfigOption { + session_id, + config_id, + value, + response_tx, + } => { + let value_id = sacp::schema::SessionConfigValueId::new(value); + let req = SetSessionConfigOptionRequest::new(session_id, config_id, value_id); + let result: Result<()> = cx + .send_request(req) + .block_task() + .await + .map(|_| ()) + .map_err(anyhow::Error::from); + log_undelivered( + response_tx.send(result), + AGENT_METHOD_NAMES.session_set_config_option, + ); + } + ClientRequest::CloseSession { + session_id, + response_tx, + } => { + let result: Result<()> = cx + .send_request(CloseSessionRequest::new(session_id.clone())) + .block_task() + .await + .map(|_| ()) + .map_err(anyhow::Error::from); + session_ids.retain(|s| s != &session_id); + log_undelivered(response_tx.send(result), AGENT_METHOD_NAMES.session_close); } ClientRequest::Untyped { method, params, response_tx, } => { - // Untyped because sacp doesn't have typed client requests for - // session/set_mode and session/set_model yet. - let result = match sacp::UntypedMessage::new(&method, params) { - Ok(msg) => cx - .send_request(msg) - .block_task() - .await - .map_err(anyhow::Error::from), - Err(e) => Err(anyhow::Error::from(e)), - }; - let _ = response_tx.send(result); + let result: Result = + match sacp::UntypedMessage::new(&method, params) { + Ok(msg) => cx + .send_request(msg) + .block_task() + .await + .map_err(anyhow::Error::from), + Err(e) => Err(anyhow::Error::from(e)), + }; + log_undelivered(response_tx.send(result), &method); } ClientRequest::Prompt { session_id, @@ -695,52 +1099,50 @@ async fn handle_requests( } => { *prompt_response_tx.lock().unwrap() = Some(response_tx.clone()); - let response = cx + let response: Result = cx .send_request(PromptRequest::new(session_id, content)) .block_task() .await; match response { Ok(r) => { - let _ = response_tx.try_send(AcpUpdate::Complete(r.stop_reason)); + log_undelivered( + response_tx.try_send(AcpUpdate::Complete(r.stop_reason)), + AGENT_METHOD_NAMES.session_prompt, + ); } Err(e) => { - let _ = response_tx.try_send(AcpUpdate::Error(e.to_string())); + log_undelivered( + response_tx.try_send(AcpUpdate::Error(e.to_string())), + AGENT_METHOD_NAMES.session_prompt, + ); } } *prompt_response_tx.lock().unwrap() = None; } - ClientRequest::Shutdown => break, } } - Ok(()) -} - -async fn handle_new_session_request( - config: &AcpProviderConfig, - cx: &JrConnectionCx, - mcp_capabilities: &McpCapabilities, - response_tx: oneshot::Sender>, -) { - let mcp_servers = filter_supported_servers(&config.mcp_servers, mcp_capabilities); - let session = cx - .send_request(NewSessionRequest::new(config.work_dir.clone()).mcp_servers(mcp_servers)) - .block_task() - .await; - - let result = match session { - Ok(session) => apply_session_mode(config, cx, session).await, - Err(err) => Err(anyhow::anyhow!("ACP session/new failed: {err}")), - }; + // After loop exits (channel closed by Drop): + if supports_close { + for session_id in session_ids { + if let Err(e) = cx + .send_request(CloseSessionRequest::new(session_id.clone())) + .block_task() + .await + { + tracing::debug!(method = AGENT_METHOD_NAMES.session_close, session_id = %session_id, error = %e, "failed on shutdown"); + } + } + } - let _ = response_tx.send(result); + Ok(()) } async fn apply_session_mode( config: &AcpProviderConfig, - cx: &JrConnectionCx, + cx: &ConnectionTo, session: NewSessionResponse, ) -> Result { if let (Some(mode_id), Some(modes)) = (config.session_mode_id.clone(), session.modes.as_ref()) { @@ -758,13 +1160,19 @@ async fn apply_session_mode( available.join(", ") )); } - cx.send_request(SetSessionModeRequest::new( - session.session_id.clone(), - mode_id, - )) - .block_task() - .await - .map_err(|err| anyhow::anyhow!("ACP agent rejected session/set_mode: {err}"))?; + let _: SetSessionModeResponse = cx + .send_request(SetSessionModeRequest::new( + session.session_id.clone(), + mode_id, + )) + .block_task() + .await + .map_err(|err| { + anyhow::anyhow!( + "ACP agent rejected {}: {err}", + AGENT_METHOD_NAMES.session_set_mode + ) + })?; } } @@ -915,6 +1323,79 @@ fn build_action_required_message(request: &RequestPermissionRequest) -> Option Result<(String, Vec), ProviderError> { + if let Some(opts) = &response.config_options { + if let Some(sel) = opts.iter().find_map(|opt| { + if opt.category.as_ref() != Some(&SessionConfigOptionCategory::Model) { + return None; + } + match &opt.kind { + SessionConfigKind::Select(s) => Some(s), + _ => None, + } + }) { + let current = sel.current_value.0.to_string(); + let available = match &sel.options { + SessionConfigSelectOptions::Ungrouped(opts) => { + opts.iter().map(|o| o.value.0.to_string()).collect() + } + SessionConfigSelectOptions::Grouped(groups) => groups + .iter() + .flat_map(|g| g.options.iter().map(|o| o.value.0.to_string())) + .collect(), + _ => vec![], + }; + return Ok((current, available)); + } + } + + let models = response.models.as_ref().ok_or_else(|| { + ProviderError::RequestFailed(format!( + "{provider_name}: agent returned neither config_options nor models" + )) + })?; + let current = models.current_model_id.0.to_string(); + let available = models + .available_models + .iter() + .map(|am| am.model_id.0.to_string()) + .collect(); + Ok((current, available)) +} + +fn reverse_mode_mapping( + mode_mapping: &HashMap, +) -> HashMap> { + let mut reverse: HashMap> = HashMap::new(); + for (mode, id) in mode_mapping { + reverse.entry(id.clone()).or_default().push(*mode); + } + reverse +} + +// When multiple GooseModes map to the same provider ID (e.g. codex "read-only"), +// prefer the current mode if it's among candidates. +fn resolve_mode( + reverse_modes: &HashMap>, + mode_id: &str, + current: &Arc>, +) -> Option { + let candidates = reverse_modes.get(mode_id)?; + if candidates.len() == 1 { + return Some(candidates[0]); + } + let current = current.lock().ok()?; + if candidates.contains(&*current) { + Some(*current) + } else { + Some(candidates[0]) + } +} + fn permission_decision_from_mode(goose_mode: GooseMode) -> Option { match goose_mode { GooseMode::Auto => Some(PermissionDecision::AllowOnce), @@ -923,10 +1404,29 @@ fn permission_decision_from_mode(goose_mode: GooseMode) -> Option, +) -> ListSessionsResponse { + let sessions = response + .sessions + .into_iter() + .filter_map(|mut info| { + let goose_id = acp_to_goose.get(info.session_id.0.as_ref())?; + info.session_id = SessionId::new(goose_id.clone()); + Some(info) + }) + .collect(); + ListSessionsResponse::new(sessions) +} + #[cfg(test)] mod tests { use super::*; use crate::agents::extension::Envs; + use sacp::schema::{SessionConfigOption, SessionConfigSelectOption, SessionInfo}; use test_case::test_case; #[test_case( @@ -1020,4 +1520,225 @@ mod tests { let filtered = filter_supported_servers(&servers, &McpCapabilities::default()); assert!(filtered.is_empty()); } + + #[test_case( + ListSessionsResponse::new(vec![ + SessionInfo::new(SessionId::new("20260318_1"), "/Users/codefromthecrypt/oss/goose-2") + .title("Fix login bug".to_string()) + .updated_at("2026-03-18T07:02:42.549655Z".to_string()), + SessionInfo::new(SessionId::new("20260318_2"), "/tmp/test-acpx") + .title("Add caching layer".to_string()) + .updated_at("2026-03-18T07:05:01.123Z".to_string()), + ]), + HashMap::from([ + ("20260318_1".to_string(), "goose-session-1".to_string()), + ("20260318_2".to_string(), "goose-session-2".to_string()), + ]), + ListSessionsResponse::new(vec![ + SessionInfo::new(SessionId::new("goose-session-1"), "/Users/codefromthecrypt/oss/goose-2") + .title("Fix login bug".to_string()) + .updated_at("2026-03-18T07:02:42.549655Z".to_string()), + SessionInfo::new(SessionId::new("goose-session-2"), "/tmp/test-acpx") + .title("Add caching layer".to_string()) + .updated_at("2026-03-18T07:05:01.123Z".to_string()), + ]) + ; "all sessions mapped with all fields preserved" + )] + #[test_case( + ListSessionsResponse::new(vec![ + SessionInfo::new(SessionId::new("20260318_1"), "/Users/codefromthecrypt/oss/goose-2") + .title("Fix login bug".to_string()), + SessionInfo::new(SessionId::new("other-agent-session"), "/tmp/other") + .title("Not our session".to_string()), + ]), + HashMap::from([ + ("20260318_1".to_string(), "goose-session-1".to_string()), + ]), + ListSessionsResponse::new(vec![ + SessionInfo::new(SessionId::new("goose-session-1"), "/Users/codefromthecrypt/oss/goose-2") + .title("Fix login bug".to_string()), + ]) + ; "unmapped sessions filtered out" + )] + #[test_case( + ListSessionsResponse::new(vec![ + SessionInfo::new(SessionId::new("20260318_1"), "/Users/codefromthecrypt/oss/goose-2") + .title("ACP Session".to_string()) + .updated_at("2026-03-18T01:29:02.141700Z".to_string()), + ]), + HashMap::new(), + ListSessionsResponse::new(vec![]) + ; "empty map returns empty list" + )] + fn test_map_sessions_to_goose_ids( + response: ListSessionsResponse, + acp_to_goose: HashMap, + expected: ListSessionsResponse, + ) { + let result = map_sessions_to_goose_ids(response, &acp_to_goose); + assert_eq!(result, expected); + } + + #[test_case(GooseMode::Auto => Some(PermissionDecision::AllowOnce) ; "auto allows")] + #[test_case(GooseMode::Chat => Some(PermissionDecision::RejectOnce) ; "chat rejects")] + #[test_case(GooseMode::Approve => None ; "approve defers")] + #[test_case(GooseMode::SmartApprove => None ; "smart_approve defers")] + fn test_permission_decision_from_mode(mode: GooseMode) -> Option { + permission_decision_from_mode(mode) + } + + #[test_case( + HashMap::from([ + (GooseMode::Auto, "yolo".to_string()), + (GooseMode::Approve, "default".to_string()), + (GooseMode::SmartApprove, "auto_edit".to_string()), + (GooseMode::Chat, "plan".to_string()), + ]), + HashMap::from([ + ("yolo".to_string(), vec![GooseMode::Auto]), + ("default".to_string(), vec![GooseMode::Approve]), + ("auto_edit".to_string(), vec![GooseMode::SmartApprove]), + ("plan".to_string(), vec![GooseMode::Chat]), + ]) + ; "gemini provider mapping" + )] + #[test_case( + HashMap::from([ + (GooseMode::Auto, "bypassPermissions".to_string()), + (GooseMode::Approve, "default".to_string()), + (GooseMode::SmartApprove, "acceptEdits".to_string()), + (GooseMode::Chat, "plan".to_string()), + ]), + HashMap::from([ + ("bypassPermissions".to_string(), vec![GooseMode::Auto]), + ("default".to_string(), vec![GooseMode::Approve]), + ("acceptEdits".to_string(), vec![GooseMode::SmartApprove]), + ("plan".to_string(), vec![GooseMode::Chat]), + ]) + ; "claude provider mapping" + )] + #[test_case( + HashMap::from([ + (GooseMode::Auto, "full-access".to_string()), + (GooseMode::Approve, "read-only".to_string()), + (GooseMode::SmartApprove, "auto".to_string()), + (GooseMode::Chat, "read-only".to_string()), + ]), + HashMap::from([ + ("full-access".to_string(), vec![GooseMode::Auto]), + ("read-only".to_string(), vec![GooseMode::Approve, GooseMode::Chat]), + ("auto".to_string(), vec![GooseMode::SmartApprove]), + ]) + ; "codex duplicate read-only" + )] + fn test_reverse_mode_mapping( + forward: HashMap, + expected: HashMap>, + ) { + let result = reverse_mode_mapping(&forward); + assert_eq!(result.len(), expected.len()); + for (key, expected_modes) in &expected { + let actual = result.get(key).expect("missing key"); + assert_eq!( + actual.len(), + expected_modes.len(), + "length mismatch for key {key}" + ); + for mode in expected_modes { + assert!(actual.contains(mode), "missing {mode:?} for key {key}"); + } + } + } + + #[test_case( + NewSessionResponse::new("s1") + .models(sacp::schema::SessionModelState::new( + "default", + vec![ + sacp::schema::ModelInfo::new("default", "Default (recommended)"), + sacp::schema::ModelInfo::new("sonnet", "Sonnet"), + sacp::schema::ModelInfo::new("haiku", "Haiku"), + ], + )) + .config_options(vec![ + SessionConfigOption::select("model", "Model", "default", vec![ + SessionConfigSelectOption::new("default", "Default (recommended)"), + SessionConfigSelectOption::new("sonnet", "Sonnet"), + SessionConfigSelectOption::new("haiku", "Haiku"), + ]) + .category(SessionConfigOptionCategory::Model), + ]) + => Ok(("default".to_string(), vec!["default".to_string(), "sonnet".to_string(), "haiku".to_string()])) + ; "claude-agent-acp config_options supersedes models" + )] + #[test_case( + NewSessionResponse::new("s1") + .models(sacp::schema::SessionModelState::new( + "auto-gemini-3", + vec![ + sacp::schema::ModelInfo::new("auto-gemini-3", "Auto (Gemini 3)"), + sacp::schema::ModelInfo::new("auto-gemini-2.5", "Auto (Gemini 2.5)"), + sacp::schema::ModelInfo::new("gemini-2.5-pro", "gemini-2.5-pro"), + ], + )) + => Ok(("auto-gemini-3".to_string(), vec!["auto-gemini-3".to_string(), "auto-gemini-2.5".to_string(), "gemini-2.5-pro".to_string()])) + ; "gemini-acp falls back to models" + )] + #[test_case( + NewSessionResponse::new("s1") + => Err(ProviderError::RequestFailed( + "test: agent returned neither config_options nor models".to_string() + )) + ; "neither config_options nor models is an error" + )] + fn test_resolve_model_info( + response: NewSessionResponse, + ) -> Result<(String, Vec), ProviderError> { + resolve_model_info("test", &response) + } + + // Codex mapping: read-only maps to both Approve and Chat. + fn codex_reverse_modes() -> HashMap> { + HashMap::from([ + ("full-access".to_string(), vec![GooseMode::Auto]), + ( + "read-only".to_string(), + vec![GooseMode::Approve, GooseMode::Chat], + ), + ("auto".to_string(), vec![GooseMode::SmartApprove]), + ]) + } + + #[test_case( + "full-access", GooseMode::Auto, Some(GooseMode::Auto) + ; "unique mapping returns the only candidate" + )] + #[test_case( + "read-only", GooseMode::Approve, Some(GooseMode::Approve) + ; "duplicate prefers current when current is Approve" + )] + #[test_case( + "read-only", GooseMode::Chat, Some(GooseMode::Chat) + ; "duplicate prefers current when current is Chat" + )] + #[test_case( + "read-only", GooseMode::Auto, Some(GooseMode::Approve) + ; "duplicate falls back to first when current not in candidates" + )] + #[test_case( + "unknown-id", GooseMode::Auto, None + ; "unknown mode id returns None" + )] + fn test_resolve_mode(mode_id: &str, current: GooseMode, expected: Option) { + let reverse_modes = codex_reverse_modes(); + let current = Arc::new(Mutex::new(current)); + let result = resolve_mode(&reverse_modes, mode_id, ¤t); + // For the fallback case, just check we got *some* candidate (order is nondeterministic). + if mode_id == "read-only" && expected == Some(GooseMode::Approve) { + // Current (Auto) not in candidates — any candidate is valid. + assert!(result == Some(GooseMode::Approve) || result == Some(GooseMode::Chat)); + } else { + assert_eq!(result, expected); + } + } } diff --git a/crates/goose/src/config/goose_mode.rs b/crates/goose/src/config/goose_mode.rs index 7694e5ccfea0..dbfe8af1282f 100644 --- a/crates/goose/src/config/goose_mode.rs +++ b/crates/goose/src/config/goose_mode.rs @@ -8,6 +8,7 @@ use utoipa::ToSchema; Debug, Default, Eq, + Hash, PartialEq, Serialize, Deserialize, diff --git a/crates/goose/src/providers/claude_acp.rs b/crates/goose/src/providers/claude_acp.rs index 0003624d452b..468929c2ce4e 100644 --- a/crates/goose/src/providers/claude_acp.rs +++ b/crates/goose/src/providers/claude_acp.rs @@ -1,9 +1,11 @@ use anyhow::Result; use futures::future::BoxFuture; +use std::collections::HashMap; use std::path::PathBuf; use crate::acp::{ extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig, PermissionMapping, + ACP_CURRENT_MODEL, }; use crate::config::search_path::SearchPaths; use crate::config::{Config, GooseMode}; @@ -11,7 +13,6 @@ use crate::model::ModelConfig; use crate::providers::base::{ProviderDef, ProviderMetadata}; const CLAUDE_ACP_PROVIDER_NAME: &str = "claude-acp"; -pub const CLAUDE_ACP_DEFAULT_MODEL: &str = "default"; const CLAUDE_ACP_DOC_URL: &str = "https://github.com/zed-industries/claude-agent-acp"; const CLAUDE_ACP_BINARY: &str = "claude-agent-acp"; @@ -25,7 +26,7 @@ impl ProviderDef for ClaudeAcpProvider { CLAUDE_ACP_PROVIDER_NAME, "Claude Code", "ACP wrapper for Anthropic's Claude. Install: npm install -g @zed-industries/claude-agent-acp", - CLAUDE_ACP_DEFAULT_MODEL, + ACP_CURRENT_MODEL, vec![], CLAUDE_ACP_DOC_URL, vec![], @@ -51,6 +52,17 @@ impl ProviderDef for ClaudeAcpProvider { rejected_tool_status: sacp::schema::ToolCallStatus::Failed, }; + let mode_mapping = HashMap::from([ + // Closest to "autonomous": bypassPermissions skips confirmations. + (GooseMode::Auto, "bypassPermissions".to_string()), + // Claude Code's default matches "ask before risky actions". + (GooseMode::Approve, "default".to_string()), + // acceptEdits auto-accepts file edits but still prompts for risky ops. + (GooseMode::SmartApprove, "acceptEdits".to_string()), + // Plan mode disables tool execution, aligning with chat-only intent. + (GooseMode::Chat, "plan".to_string()), + ]); + let provider_config = AcpProviderConfig { command: resolved_command, args: vec![], @@ -59,7 +71,8 @@ impl ProviderDef for ClaudeAcpProvider { env_remove: vec!["CLAUDECODE".to_string()], work_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), mcp_servers: extension_configs_to_mcp_servers(&extensions), - session_mode_id: Some(map_goose_mode(goose_mode)), + session_mode_id: Some(mode_mapping[&goose_mode].clone()), + mode_mapping, permission_mapping, notification_callback: None, }; @@ -69,24 +82,3 @@ impl ProviderDef for ClaudeAcpProvider { }) } } - -fn map_goose_mode(goose_mode: GooseMode) -> String { - match goose_mode { - GooseMode::Auto => { - // Closest to "autonomous": Claude Code's bypassPermissions skips confirmations. - "bypassPermissions".to_string() - } - GooseMode::Approve => { - // Claude Code's default matches "ask before risky actions". - "default".to_string() - } - GooseMode::SmartApprove => { - // Best-effort: acceptEdits auto-accepts file edits but still prompts for risky ops. - "acceptEdits".to_string() - } - GooseMode::Chat => { - // Plan mode disables tool execution, aligning with chat-only intent. - "plan".to_string() - } - } -} diff --git a/crates/goose/src/providers/codex_acp.rs b/crates/goose/src/providers/codex_acp.rs index bf755ac15497..232f8145f85f 100644 --- a/crates/goose/src/providers/codex_acp.rs +++ b/crates/goose/src/providers/codex_acp.rs @@ -1,9 +1,11 @@ use anyhow::Result; use futures::future::BoxFuture; +use std::collections::HashMap; use std::path::PathBuf; use crate::acp::{ extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig, PermissionMapping, + ACP_CURRENT_MODEL, }; use crate::config::search_path::SearchPaths; use crate::config::{Config, GooseMode}; @@ -11,7 +13,6 @@ use crate::model::ModelConfig; use crate::providers::base::{ProviderDef, ProviderMetadata}; const CODEX_ACP_PROVIDER_NAME: &str = "codex-acp"; -pub const CODEX_ACP_DEFAULT_MODEL: &str = "gpt-5.2-codex"; const CODEX_ACP_DOC_URL: &str = "https://github.com/zed-industries/codex-acp"; pub struct CodexAcpProvider; @@ -24,7 +25,7 @@ impl ProviderDef for CodexAcpProvider { CODEX_ACP_PROVIDER_NAME, "Codex CLI", "ACP adapter for OpenAI's coding assistant. Install: npm install -g @zed-industries/codex-acp", - CODEX_ACP_DEFAULT_MODEL, + ACP_CURRENT_MODEL, vec![], CODEX_ACP_DOC_URL, vec![], @@ -74,6 +75,14 @@ impl ProviderDef for CodexAcpProvider { rejected_tool_status: sacp::schema::ToolCallStatus::Failed, }; + // Chat and Approve both map to "read-only". + let mode_mapping = HashMap::from([ + (GooseMode::Auto, "full-access".to_string()), + (GooseMode::Approve, "read-only".to_string()), + (GooseMode::SmartApprove, "auto".to_string()), + (GooseMode::Chat, "read-only".to_string()), + ]); + let provider_config = AcpProviderConfig { command: resolved_command, args, @@ -83,6 +92,7 @@ impl ProviderDef for CodexAcpProvider { mcp_servers, // Disabled until https://github.com/zed-industries/codex-acp/issues/179 is fixed. session_mode_id: None, + mode_mapping, permission_mapping, notification_callback: None, }; diff --git a/crates/goose/src/providers/gemini_acp.rs b/crates/goose/src/providers/gemini_acp.rs index 15c3d5dc9b5b..6b996e05091f 100644 --- a/crates/goose/src/providers/gemini_acp.rs +++ b/crates/goose/src/providers/gemini_acp.rs @@ -1,9 +1,11 @@ use anyhow::Result; use futures::future::BoxFuture; +use std::collections::HashMap; use std::path::PathBuf; use crate::acp::{ extension_configs_to_mcp_servers, AcpProvider, AcpProviderConfig, PermissionMapping, + ACP_CURRENT_MODEL, }; use crate::config::search_path::SearchPaths; use crate::config::{Config, GooseMode}; @@ -11,7 +13,6 @@ use crate::model::ModelConfig; use crate::providers::base::{ProviderDef, ProviderMetadata}; const GEMINI_ACP_PROVIDER_NAME: &str = "gemini-acp"; -pub const GEMINI_ACP_DEFAULT_MODEL: &str = "default"; const GEMINI_ACP_DOC_URL: &str = "https://github.com/google-gemini/gemini-cli"; pub struct GeminiAcpProvider; @@ -24,7 +25,7 @@ impl ProviderDef for GeminiAcpProvider { GEMINI_ACP_PROVIDER_NAME, "Gemini CLI (ACP)", "ACP provider for Google's Gemini CLI. Install: npm install -g @google/gemini-cli", - GEMINI_ACP_DEFAULT_MODEL, + ACP_CURRENT_MODEL, vec![], GEMINI_ACP_DOC_URL, vec![], @@ -48,11 +49,18 @@ impl ProviderDef for GeminiAcpProvider { }; let mut args = vec!["--acp".to_string()]; - if model.model_name != "default" { + if model.model_name != ACP_CURRENT_MODEL { args.push("--model".to_string()); args.push(model.model_name.clone()); } + let mode_mapping = HashMap::from([ + (GooseMode::Auto, "yolo".to_string()), + (GooseMode::Approve, "default".to_string()), + (GooseMode::SmartApprove, "auto_edit".to_string()), + (GooseMode::Chat, "plan".to_string()), + ]); + let provider_config = AcpProviderConfig { command: resolved_command, args, @@ -60,7 +68,8 @@ impl ProviderDef for GeminiAcpProvider { env_remove: vec![], work_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), mcp_servers: extension_configs_to_mcp_servers(&extensions), - session_mode_id: Some(map_goose_mode(goose_mode)), + session_mode_id: Some(mode_mapping[&goose_mode].clone()), + mode_mapping, permission_mapping, notification_callback: None, }; @@ -70,12 +79,3 @@ impl ProviderDef for GeminiAcpProvider { }) } } - -fn map_goose_mode(goose_mode: GooseMode) -> String { - match goose_mode { - GooseMode::Auto => "yolo".to_string(), - GooseMode::Approve => "default".to_string(), - GooseMode::SmartApprove => "auto_edit".to_string(), - GooseMode::Chat => "plan".to_string(), - } -} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 7e717fee8fd5..4a012e9b3067 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -5,13 +5,13 @@ use crate::model::ModelConfig; use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; use base64::Engine; +use fs_err::File; use regex::Regex; use reqwest::{Response, StatusCode}; use rmcp::model::{AnnotateAble, ImageContent, RawImageContent}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::fmt::Display; -use std::fs::File; use std::io::{BufWriter, Read, Write}; use std::path::{Path, PathBuf}; use std::sync::OnceLock; @@ -436,10 +436,10 @@ impl RequestLog { let log_path = |i| logs_dir.join(format!("llm_request.{}.jsonl", i)); for i in (0..LOGS_TO_KEEP - 1).rev() { - let _ = std::fs::rename(log_path(i), log_path(i + 1)); + let _ = fs_err::rename(log_path(i), log_path(i + 1)); } - std::fs::rename(&self.temp_path, log_path(0))?; + fs_err::rename(&self.temp_path, log_path(0))?; } Ok(()) } diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index f1b7c8655f4a..e9678cbaf41c 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -1,6 +1,7 @@ use anyhow::Result; use dotenvy::dotenv; use futures::StreamExt; +use goose::acp::ACP_CURRENT_MODEL; use goose::agents::{Agent, AgentConfig, AgentEvent, GoosePlatform, PromptManager, SessionConfig}; use goose::config::{ExtensionConfig, GooseMode, PermissionManager}; use goose::conversation::message::{ActionRequiredData, Message, MessageContent}; @@ -10,14 +11,11 @@ use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; use goose::providers::azure::AZURE_DEFAULT_MODEL; use goose::providers::base::Provider; use goose::providers::bedrock::BEDROCK_DEFAULT_MODEL; -use goose::providers::claude_acp::CLAUDE_ACP_DEFAULT_MODEL; use goose::providers::claude_code::CLAUDE_CODE_DEFAULT_MODEL; use goose::providers::codex::CODEX_DEFAULT_MODEL; -use goose::providers::codex_acp::CODEX_ACP_DEFAULT_MODEL; use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::errors::ProviderError; -use goose::providers::gemini_acp::GEMINI_ACP_DEFAULT_MODEL; use goose::providers::google::GOOGLE_DEFAULT_MODEL; use goose::providers::litellm::LITELLM_DEFAULT_MODEL; use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; @@ -492,11 +490,11 @@ impl ProviderFixture { println!("==================="); assert!(!models.is_empty()); - let model_name = &self.provider.get_model_config().model_name; - // model names may be substrings (e.g. "sonnet" vs "claude-sonnet-4-5-20250929") + let resolved = &self.provider.get_model_config().model_name; + assert_ne!(resolved.as_str(), ACP_CURRENT_MODEL); assert!(models .iter() - .any(|m| m == model_name || m.contains(model_name) || model_name.contains(m))); + .any(|m| m == resolved || m.contains(resolved) || resolved.contains(m))); if let Some(alt) = &self.model_switch_name { assert!(models .iter() @@ -873,20 +871,17 @@ async fn test_codex_provider() -> Result<()> { // Requires: npm install -g @zed-industries/claude-agent-acp #[tokio::test] async fn test_claude_acp_provider() -> Result<()> { - ProviderTestConfig::with_agentic_provider( - "claude-acp", - CLAUDE_ACP_DEFAULT_MODEL, - "claude-agent-acp", - ) - .model_switch_name("sonnet") - .run() - .await + ProviderTestConfig::with_agentic_provider("claude-acp", ACP_CURRENT_MODEL, "claude-agent-acp") + .model_switch_name("sonnet") + .run() + .await } // Requires: npm install -g @zed-industries/codex-acp #[tokio::test] async fn test_codex_acp_provider() -> Result<()> { - ProviderTestConfig::with_agentic_provider("codex-acp", CODEX_ACP_DEFAULT_MODEL, "codex-acp") + ProviderTestConfig::with_agentic_provider("codex-acp", ACP_CURRENT_MODEL, "codex-acp") + .model_switch_name("gpt-5.4-mini") .run() .await } @@ -894,7 +889,11 @@ async fn test_codex_acp_provider() -> Result<()> { // Requires: npm install -g @google/gemini-cli #[tokio::test] async fn test_gemini_acp_provider() -> Result<()> { - ProviderTestConfig::with_agentic_provider("gemini-acp", GEMINI_ACP_DEFAULT_MODEL, "gemini") + // Don't run tests with ACP_CURRENT_MODEL, as gemini sets "auto-gemini-3" even when the user + // has no access to the Preview Release Channel, resulting in "Requested entity was not found." + // See https://github.com/google-gemini/gemini-cli/issues/22803 + ProviderTestConfig::with_agentic_provider("gemini-acp", "auto-gemini-2.5", "gemini") + .model_switch_name("gemini-2.5-flash") .run() .await } diff --git a/ui/acp/generate-schema.ts b/ui/acp/generate-schema.ts index 909542f84284..4f80d84d4d34 100644 --- a/ui/acp/generate-schema.ts +++ b/ui/acp/generate-schema.ts @@ -138,7 +138,7 @@ async function generateClient(meta: { methods: MethodMeta[] }) { for (const m of meta.methods) { const fnName = methodToCamelCase(m.method); - const fullMethod = `_goose/${m.method}`; + const fullMethod = m.method; let paramType = ""; let paramArg = ""; diff --git a/ui/acp/src/generated/client.gen.ts b/ui/acp/src/generated/client.gen.ts index 8de00f5a1efd..748c3fdc0a3c 100644 --- a/ui/acp/src/generated/client.gen.ts +++ b/ui/acp/src/generated/client.gen.ts @@ -19,7 +19,6 @@ import type { GetToolsResponse, ImportSessionRequest, ImportSessionResponse, - ListSessionsResponse, ReadResourceRequest, ReadResourceResponse, RemoveExtensionRequest, @@ -31,70 +30,60 @@ import { zGetSessionResponse, zGetToolsResponse, zImportSessionResponse, - zListSessionsResponse, zReadResourceResponse, } from './zod.gen.js'; -/** - * Typed client for Goose custom extension methods. - * Wraps an ExtMethodProvider (e.g. ClientSideConnection) with proper types and Zod validation. - */ export class GooseExtClient { constructor(private conn: ExtMethodProvider) {} - async extensionsAdd(params: AddExtensionRequest): Promise { + async GooseExtensionsAdd(params: AddExtensionRequest): Promise { await this.conn.extMethod("_goose/extensions/add", params); } - async extensionsRemove(params: RemoveExtensionRequest): Promise { + async GooseExtensionsRemove(params: RemoveExtensionRequest): Promise { await this.conn.extMethod("_goose/extensions/remove", params); } - async tools(params: GetToolsRequest): Promise { + async GooseTools(params: GetToolsRequest): Promise { const raw = await this.conn.extMethod("_goose/tools", params); return zGetToolsResponse.parse(raw) as GetToolsResponse; } - async resourceRead( + async GooseResourceRead( params: ReadResourceRequest, ): Promise { const raw = await this.conn.extMethod("_goose/resource/read", params); return zReadResourceResponse.parse(raw) as ReadResourceResponse; } - async workingDirUpdate(params: UpdateWorkingDirRequest): Promise { + async GooseWorkingDirUpdate(params: UpdateWorkingDirRequest): Promise { await this.conn.extMethod("_goose/working_dir/update", params); } - async sessionList(): Promise { - const raw = await this.conn.extMethod("_goose/session/list", {}); - return zListSessionsResponse.parse(raw) as ListSessionsResponse; - } - async sessionGet(params: GetSessionRequest): Promise { - const raw = await this.conn.extMethod("_goose/session/get", params); + const raw = await this.conn.extMethod("session/get", params); return zGetSessionResponse.parse(raw) as GetSessionResponse; } async sessionDelete(params: DeleteSessionRequest): Promise { - await this.conn.extMethod("_goose/session/delete", params); + await this.conn.extMethod("session/delete", params); } - async sessionExport( + async GooseSessionExport( params: ExportSessionRequest, ): Promise { const raw = await this.conn.extMethod("_goose/session/export", params); return zExportSessionResponse.parse(raw) as ExportSessionResponse; } - async sessionImport( + async GooseSessionImport( params: ImportSessionRequest, ): Promise { const raw = await this.conn.extMethod("_goose/session/import", params); return zImportSessionResponse.parse(raw) as ImportSessionResponse; } - async configExtensions(): Promise { + async GooseConfigExtensions(): Promise { const raw = await this.conn.extMethod("_goose/config/extensions", {}); return zGetExtensionsResponse.parse(raw) as GetExtensionsResponse; } diff --git a/ui/acp/src/generated/index.ts b/ui/acp/src/generated/index.ts index 2f6a85b441e9..a51e05eef1e2 100644 --- a/ui/acp/src/generated/index.ts +++ b/ui/acp/src/generated/index.ts @@ -1,38 +1,33 @@ // This file is auto-generated by @hey-api/openapi-ts -export type { AddExtensionRequest, DeleteSessionRequest, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsResponse, GetSessionRequest, GetSessionResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ListSessionsResponse, ReadResourceRequest, ReadResourceResponse, RemoveExtensionRequest, UpdateWorkingDirRequest } from './types.gen.js'; +export type { AddExtensionRequest, DeleteSessionRequest, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsResponse, GetSessionRequest, GetSessionResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ReadResourceRequest, ReadResourceResponse, RemoveExtensionRequest, UpdateWorkingDirRequest } from './types.gen.js'; export const GOOSE_EXT_METHODS = [ { - method: "extensions/add", + method: "_goose/extensions/add", requestType: "AddExtensionRequest", responseType: "EmptyResponse", }, { - method: "extensions/remove", + method: "_goose/extensions/remove", requestType: "RemoveExtensionRequest", responseType: "EmptyResponse", }, { - method: "tools", + method: "_goose/tools", requestType: "GetToolsRequest", responseType: "GetToolsResponse", }, { - method: "resource/read", + method: "_goose/resource/read", requestType: "ReadResourceRequest", responseType: "ReadResourceResponse", }, { - method: "working_dir/update", + method: "_goose/working_dir/update", requestType: "UpdateWorkingDirRequest", responseType: "EmptyResponse", }, - { - method: "session/list", - requestType: null, - responseType: "ListSessionsResponse", - }, { method: "session/get", requestType: "GetSessionRequest", @@ -44,17 +39,17 @@ export const GOOSE_EXT_METHODS = [ responseType: "EmptyResponse", }, { - method: "session/export", + method: "_goose/session/export", requestType: "ExportSessionRequest", responseType: "ExportSessionResponse", }, { - method: "session/import", + method: "_goose/session/import", requestType: "ImportSessionRequest", responseType: "ImportSessionResponse", }, { - method: "config/extensions", + method: "_goose/config/extensions", requestType: null, responseType: "GetExtensionsResponse", }, diff --git a/ui/acp/src/generated/types.gen.ts b/ui/acp/src/generated/types.gen.ts index e404c8f82c98..a787eb1ada77 100644 --- a/ui/acp/src/generated/types.gen.ts +++ b/ui/acp/src/generated/types.gen.ts @@ -3,10 +3,9 @@ /** * Add an extension to an active session. - * Method: `_agent/extensions/add` */ export type AddExtensionRequest = { - session_id: string; + sessionId: string; /** * Extension configuration (see ExtensionConfig variants: Stdio, StreamableHttp, Builtin, Platform). */ @@ -22,19 +21,17 @@ export type EmptyResponse = { /** * Remove an extension from an active session. - * Method: `_agent/extensions/remove` */ export type RemoveExtensionRequest = { - session_id: string; + sessionId: string; name: string; }; /** * List all tools available in a session. - * Method: `_agent/tools` */ export type GetToolsRequest = { - session_id: string; + sessionId: string; }; export type GetToolsResponse = { @@ -46,12 +43,11 @@ export type GetToolsResponse = { /** * Read a resource from an extension. - * Method: `_agent/resource/read` */ export type ReadResourceRequest = { - session_id: string; + sessionId: string; uri: string; - extension_name: string; + extensionName: string; }; export type ReadResourceResponse = { @@ -63,28 +59,18 @@ export type ReadResourceResponse = { /** * Update the working directory for a session. - * Method: `_agent/working_dir/update` */ export type UpdateWorkingDirRequest = { - session_id: string; - working_dir: string; -}; - -/** - * List all sessions. - * Method: `_session/list` - */ -export type ListSessionsResponse = { - sessions: Array; + sessionId: string; + workingDir: string; }; /** * Get a session by ID. - * Method: `_session/get` */ export type GetSessionRequest = { - session_id: string; - include_messages?: boolean; + sessionId: string; + includeMessages?: boolean; }; /** @@ -99,18 +85,16 @@ export type GetSessionResponse = { /** * Delete a session. - * Method: `_session/delete` */ export type DeleteSessionRequest = { - session_id: string; + sessionId: string; }; /** * Export a session as a JSON string. - * Method: `_session/export` */ export type ExportSessionRequest = { - session_id: string; + sessionId: string; }; export type ExportSessionResponse = { @@ -119,7 +103,6 @@ export type ExportSessionResponse = { /** * Import a session from a JSON string. - * Method: `_session/import` */ export type ImportSessionRequest = { data: string; @@ -134,7 +117,6 @@ export type ImportSessionResponse = { /** * List configured extensions and any warnings. - * Method: `_config/extensions` */ export type GetExtensionsResponse = { /** @@ -154,7 +136,7 @@ export type ExtRequest = { export type ExtResponse = { id: string; - result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | ListSessionsResponse | GetSessionResponse | ExportSessionResponse | ImportSessionResponse | GetExtensionsResponse | unknown; + result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | GetSessionResponse | ExportSessionResponse | ImportSessionResponse | GetExtensionsResponse | unknown; } | { error: { code: number; diff --git a/ui/acp/src/generated/zod.gen.ts b/ui/acp/src/generated/zod.gen.ts index 24cc5277007c..fe4fd5311b24 100644 --- a/ui/acp/src/generated/zod.gen.ts +++ b/ui/acp/src/generated/zod.gen.ts @@ -4,10 +4,9 @@ import { z } from 'zod'; /** * Add an extension to an active session. - * Method: `_agent/extensions/add` */ export const zAddExtensionRequest = z.object({ - session_id: z.string(), + sessionId: z.string(), config: z.unknown() }); @@ -18,19 +17,17 @@ export const zEmptyResponse = z.record(z.unknown()); /** * Remove an extension from an active session. - * Method: `_agent/extensions/remove` */ export const zRemoveExtensionRequest = z.object({ - session_id: z.string(), + sessionId: z.string(), name: z.string() }); /** * List all tools available in a session. - * Method: `_agent/tools` */ export const zGetToolsRequest = z.object({ - session_id: z.string() + sessionId: z.string() }); export const zGetToolsResponse = z.object({ @@ -39,12 +36,11 @@ export const zGetToolsResponse = z.object({ /** * Read a resource from an extension. - * Method: `_agent/resource/read` */ export const zReadResourceRequest = z.object({ - session_id: z.string(), + sessionId: z.string(), uri: z.string(), - extension_name: z.string() + extensionName: z.string() }); export const zReadResourceResponse = z.object({ @@ -53,28 +49,18 @@ export const zReadResourceResponse = z.object({ /** * Update the working directory for a session. - * Method: `_agent/working_dir/update` */ export const zUpdateWorkingDirRequest = z.object({ - session_id: z.string(), - working_dir: z.string() -}); - -/** - * List all sessions. - * Method: `_session/list` - */ -export const zListSessionsResponse = z.object({ - sessions: z.array(z.unknown()) + sessionId: z.string(), + workingDir: z.string() }); /** * Get a session by ID. - * Method: `_session/get` */ export const zGetSessionRequest = z.object({ - session_id: z.string(), - include_messages: z.boolean().optional().default(false) + sessionId: z.string(), + includeMessages: z.boolean().optional().default(false) }); /** @@ -86,18 +72,16 @@ export const zGetSessionResponse = z.object({ /** * Delete a session. - * Method: `_session/delete` */ export const zDeleteSessionRequest = z.object({ - session_id: z.string() + sessionId: z.string() }); /** * Export a session as a JSON string. - * Method: `_session/export` */ export const zExportSessionRequest = z.object({ - session_id: z.string() + sessionId: z.string() }); export const zExportSessionResponse = z.object({ @@ -106,7 +90,6 @@ export const zExportSessionResponse = z.object({ /** * Import a session from a JSON string. - * Method: `_session/import` */ export const zImportSessionRequest = z.object({ data: z.string() @@ -118,7 +101,6 @@ export const zImportSessionResponse = z.object({ /** * List configured extensions and any warnings. - * Method: `_config/extensions` */ export const zGetExtensionsResponse = z.object({ extensions: z.array(z.unknown()), @@ -155,7 +137,6 @@ export const zExtResponse = z.union([ zEmptyResponse, zGetToolsResponse, zReadResourceResponse, - zListSessionsResponse, zGetSessionResponse, zExportSessionResponse, zImportSessionResponse,