diff --git a/Cargo.lock b/Cargo.lock index dcec4a9c9676..181fe82ecfd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3248,6 +3248,7 @@ dependencies = [ "tracing-subscriber", "url", "utoipa", + "uuid", "winreg 0.55.0", "wiremock", ] diff --git a/crates/goose-acp/tests/common.rs b/crates/goose-acp/tests/common.rs index b2ab84911bb3..620d52199547 100644 --- a/crates/goose-acp/tests/common.rs +++ b/crates/goose-acp/tests/common.rs @@ -1,10 +1,13 @@ use assert_json_diff::{assert_json_matches_no_panic, CompareMode, Config}; +use goose::session_context::SESSION_ID_HEADER; +use rmcp::model::{ClientNotification, ClientRequest, Meta, ServerResult}; +use rmcp::service::{NotificationContext, RequestContext, ServiceRole}; use rmcp::transport::streamable_http_server::{ session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService, }; use rmcp::{ handler::server::router::tool::ToolRouter, model::*, tool, tool_handler, tool_router, - ErrorData as McpError, ServerHandler, + ErrorData as McpError, RoleServer, ServerHandler, Service, }; use std::collections::VecDeque; use std::sync::{Arc, Mutex}; @@ -14,57 +17,128 @@ use wiremock::{Mock, MockServer, ResponseTemplate}; pub const FAKE_CODE: &str = "test-uuid-12345-67890"; -/// Mock OpenAI streaming endpoint. Exchanges are (pattern, response) pairs. -/// On mismatch, returns 417 of the diff in OpenAI error format. -pub async fn setup_mock_openai(exchanges: Vec<(String, &'static str)>) -> MockServer { - let mock_server = MockServer::start().await; - let queue: VecDeque<(String, &'static str)> = exchanges.into_iter().collect(); - let queue = Arc::new(Mutex::new(queue)); - - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with({ - let queue = queue.clone(); - move |req: &wiremock::Request| { - let body = String::from_utf8_lossy(&req.body); - - // Special case session rename request which doesn't happen in a predictable order. - if body.contains("Reply with only a description in four words or less") { - return ResponseTemplate::new(200) - .insert_header("content-type", "application/json") - .set_body_string(include_str!( - "./test_data/openai_session_description.json" - )); - } +const NOT_YET_SET: &str = "session-id-not-yet-set"; - let (expected, response) = { - let mut q = queue.lock().unwrap(); - q.pop_front().unwrap_or_default() - }; +#[derive(Clone)] +pub struct ExpectedSessionId { + value: Arc>, + errors: Arc>>, +} - if body.contains(&expected) && !expected.is_empty() { - return ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_string(response); - } +impl Default for ExpectedSessionId { + fn default() -> Self { + Self { + value: Arc::new(Mutex::new(NOT_YET_SET.to_string())), + errors: Arc::new(Mutex::new(Vec::new())), + } + } +} + +impl ExpectedSessionId { + pub fn set(&self, id: &sacp::schema::SessionId) { + *self.value.lock().unwrap() = id.0.to_string(); + } + + pub fn validate(&self, actual: Option<&str>) -> Result<(), String> { + let expected = self.value.lock().unwrap(); - // Coerce non-json to allow a uniform JSON diff error response. 
- let exp = serde_json::from_str(&expected) - .unwrap_or(serde_json::Value::String(expected.clone())); - let act = serde_json::from_str(&body) - .unwrap_or(serde_json::Value::String(body.to_string())); - let diff = - assert_json_matches_no_panic(&exp, &act, Config::new(CompareMode::Strict)) - .unwrap_err(); - ResponseTemplate::new(417) - .insert_header("content-type", "text/event-stream") - .set_body_json(serde_json::json!({"error": {"message": diff}})) + let err = match actual { + Some(act) if act == *expected => None, + _ => Some(format!( + "{} mismatch: expected '{}', got {:?}", + SESSION_ID_HEADER, expected, actual + )), + }; + match err { + Some(e) => { + self.errors.lock().unwrap().push(e.clone()); + Err(e) } - }) - .mount(&mock_server) - .await; + None => Ok(()), + } + } - mock_server + /// Calling this ensures incidental requests that might error asynchronously, such as + /// session rename have coherent session IDs. + pub fn assert_no_errors(&self) { + let e = self.errors.lock().unwrap(); + assert!(e.is_empty(), "Session ID validation errors: {:?}", *e); + } +} + +pub struct OpenAiFixture { + pub server: MockServer, +} + +impl OpenAiFixture { + /// Mock OpenAI streaming endpoint. Exchanges are (pattern, response) pairs. + /// On mismatch, returns 417 of the diff in OpenAI error format. + pub async fn new( + exchanges: Vec<(String, &'static str)>, + expected_session_id: ExpectedSessionId, + ) -> Self { + let mock_server = MockServer::start().await; + let queue: VecDeque<(String, &'static str)> = exchanges.into_iter().collect(); + let queue = Arc::new(Mutex::new(queue)); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with({ + let queue = queue.clone(); + let expected_session_id = expected_session_id.clone(); + move |req: &wiremock::Request| { + let body = String::from_utf8_lossy(&req.body); + + let actual = req + .headers + .get(SESSION_ID_HEADER) + .and_then(|v| v.to_str().ok()); + if let Err(e) = expected_session_id.validate(actual) { + return ResponseTemplate::new(417) + .insert_header("content-type", "application/json") + .set_body_json(serde_json::json!({"error": {"message": e}})); + } + + // Session rename (async, unpredictable order) - canned response + if body.contains("Reply with only a description in four words or less") { + return ResponseTemplate::new(200) + .insert_header("content-type", "application/json") + .set_body_string(include_str!( + "./test_data/openai_session_description.json" + )); + } + + let (expected_body, response) = { + let mut q = queue.lock().unwrap(); + q.pop_front().unwrap_or_default() + }; + + if body.contains(&expected_body) && !expected_body.is_empty() { + return ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_string(response); + } + + // Coerce non-json to allow a uniform JSON diff error response. 
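+                        // Reached when no queued expectation matched the request body.
+                        // Parsing both sides as JSON (or wrapping them as JSON strings) lets
+                        // assert_json_matches_no_panic always produce a diff for the 417 body.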
+ let exp = serde_json::from_str(&expected_body) + .unwrap_or(serde_json::Value::String(expected_body.clone())); + let act = serde_json::from_str(&body) + .unwrap_or(serde_json::Value::String(body.to_string())); + let diff = + assert_json_matches_no_panic(&exp, &act, Config::new(CompareMode::Strict)) + .unwrap_err(); + ResponseTemplate::new(417) + .insert_header("content-type", "application/json") + .set_body_json(serde_json::json!({"error": {"message": diff}})) + } + }) + .mount(&mock_server) + .await; + + Self { + server: mock_server, + } + } } #[derive(Clone)] @@ -108,20 +182,108 @@ impl ServerHandler for Lookup { } } -pub async fn spawn_mcp_http_server() -> (String, JoinHandle<()>) { - let service = StreamableHttpService::new( - || Ok(Lookup::new()), - LocalSessionManager::default().into(), - StreamableHttpServerConfig::default(), - ); - let router = axum::Router::new().nest_service("/mcp", service); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let url = format!("http://{addr}/mcp"); +trait HasMeta { + fn meta(&self) -> &Meta; +} + +impl HasMeta for RequestContext { + fn meta(&self) -> &Meta { + &self.meta + } +} + +impl HasMeta for NotificationContext { + fn meta(&self) -> &Meta { + &self.meta + } +} + +pub struct ValidatingService { + inner: S, + expected_session_id: ExpectedSessionId, +} + +impl ValidatingService { + pub fn new(inner: S, expected_session_id: ExpectedSessionId) -> Self { + Self { + inner, + expected_session_id, + } + } + + fn validate(&self, context: &C) -> Result<(), McpError> { + let actual = context + .meta() + .0 + .get(SESSION_ID_HEADER) + .and_then(|v| v.as_str()); + self.expected_session_id + .validate(actual) + .map_err(|e| McpError::new(ErrorCode::INVALID_REQUEST, e, None)) + } +} + +impl> Service for ValidatingService { + async fn handle_request( + &self, + request: ClientRequest, + context: RequestContext, + ) -> Result { + if !matches!(request, ClientRequest::InitializeRequest(_)) { + self.validate(&context)?; + } + self.inner.handle_request(request, context).await + } + + async fn handle_notification( + &self, + notification: ClientNotification, + context: NotificationContext, + ) -> Result<(), McpError> { + if !matches!(notification, ClientNotification::InitializedNotification(_)) { + self.validate(&context).ok(); + } + self.inner.handle_notification(notification, context).await + } + + fn get_info(&self) -> ServerInfo { + self.inner.get_info() + } +} - let handle = tokio::spawn(async move { - axum::serve(listener, router).await.unwrap(); - }); +pub struct McpFixture { + pub url: String, + // Keep the server alive in tests; underscore avoids unused field warnings. 
+ _handle: JoinHandle<()>, +} + +impl McpFixture { + pub async fn new(expected_session_id: ExpectedSessionId) -> Self { + let service = StreamableHttpService::new( + { + let expected_session_id = expected_session_id.clone(); + move || { + Ok(ValidatingService::new( + Lookup::new(), + expected_session_id.clone(), + )) + } + }, + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let url = format!("http://{addr}/mcp"); - (url, handle) + let handle = tokio::spawn(async move { + axum::serve(listener, router).await.unwrap(); + }); + + Self { + url, + _handle: handle, + } + } } diff --git a/crates/goose-acp/tests/server_test.rs b/crates/goose-acp/tests/server_test.rs index 51d4bb1b73d2..4df7b4e7ef3e 100644 --- a/crates/goose-acp/tests/server_test.rs +++ b/crates/goose-acp/tests/server_test.rs @@ -1,6 +1,6 @@ mod common; -use common::{setup_mock_openai, spawn_mcp_http_server, FAKE_CODE}; +use common::{ExpectedSessionId, McpFixture, OpenAiFixture, FAKE_CODE}; use fs_err as fs; use goose::config::GooseMode; use goose::model::ModelConfig; @@ -26,19 +26,24 @@ use wiremock::MockServer; async fn test_acp_basic_completion() { let temp_dir = tempfile::tempdir().unwrap(); let prompt = "what is 1+1"; - let mock_server = setup_mock_openai(vec![( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_basic_response.txt"), - )]) + let expected_session_id = ExpectedSessionId::default(); + let openai = OpenAiFixture::new( + vec![( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_basic_response.txt"), + )], + expected_session_id.clone(), + ) .await; run_acp_session( - &mock_server, + &openai.server, vec![], &[], temp_dir.path(), GooseMode::Auto, None, + expected_session_id.clone(), |cx, session_id, updates| async move { let response = cx .send_request(PromptRequest::new( @@ -60,33 +65,39 @@ async fn test_acp_basic_completion() { }, ) .await; + + expected_session_id.assert_no_errors(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_acp_with_mcp_http_server() { let temp_dir = tempfile::tempdir().unwrap(); let prompt = "Use the get_code tool and output only its result."; - let (mcp_url, _handle) = spawn_mcp_http_server().await; - - let mock_server = setup_mock_openai(vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_tool_call_response.txt"), - ), - ( - format!(r#""content":"{FAKE_CODE}""#), - include_str!("./test_data/openai_tool_result_response.txt"), - ), - ]) + let expected_session_id = ExpectedSessionId::default(); + let mcp = McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("./test_data/openai_tool_result_response.txt"), + ), + ], + expected_session_id.clone(), + ) .await; run_acp_session( - &mock_server, - vec![McpServer::Http(McpServerHttp::new("lookup", mcp_url))], + &openai.server, + vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], &[], temp_dir.path(), GooseMode::Auto, None, + expected_session_id.clone(), |cx, session_id, updates| async move { let response = cx .send_request(PromptRequest::new( @@ -108,6 +119,8 @@ async fn test_acp_with_mcp_http_server() { }, ) .await; + + 
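+    // Surface any session-id mismatches the mocks recorded asynchronously
+    // (e.g. from the session-rename request) before the test ends.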
expected_session_id.assert_no_errors(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -115,35 +128,39 @@ async fn test_acp_with_builtin_and_mcp() { let temp_dir = tempfile::tempdir().unwrap(); let prompt = "Search for get_code and text_editor tools. Use them to save the code to /tmp/result.txt."; - let (lookup_url, _lookup_handle) = spawn_mcp_http_server().await; - - let mock_server = setup_mock_openai(vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_builtin_search.txt"), - ), - ( - r#"lookup/get_code: Get the code"#.into(), - include_str!("./test_data/openai_builtin_read_modules.txt"), - ), - ( - r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), - include_str!("./test_data/openai_builtin_execute.txt"), - ), - ( - r#"Successfully wrote to /tmp/result.txt"#.into(), - include_str!("./test_data/openai_builtin_final.txt"), - ), - ]) + let expected_session_id = ExpectedSessionId::default(); + let mcp = McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_builtin_search.txt"), + ), + ( + r#"lookup/get_code: Get the code"#.into(), + include_str!("./test_data/openai_builtin_read_modules.txt"), + ), + ( + r#"lookup[\"get_code\"]({}): string - Get the code"#.into(), + include_str!("./test_data/openai_builtin_execute.txt"), + ), + ( + r#"Successfully wrote to /tmp/result.txt"#.into(), + include_str!("./test_data/openai_builtin_final.txt"), + ), + ], + expected_session_id.clone(), + ) .await; run_acp_session( - &mock_server, - vec![McpServer::Http(McpServerHttp::new("lookup", lookup_url))], + &openai.server, + vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], &["code_execution", "developer"], temp_dir.path(), GooseMode::Auto, None, + expected_session_id.clone(), |cx, session_id, updates| async move { let response = cx .send_request(PromptRequest::new( @@ -165,6 +182,8 @@ async fn test_acp_with_builtin_and_mcp() { }, ) .await; + + expected_session_id.assert_no_errors(); } async fn wait_for(updates: &Arc>>, expected: &SessionUpdate) { @@ -260,6 +279,7 @@ async fn spawn_server_in_process( (client_read, client_write, handle) } +#[allow(clippy::too_many_arguments)] async fn run_acp_session( mock_server: &MockServer, mcp_servers: Vec, @@ -267,6 +287,7 @@ async fn run_acp_session( data_root: &Path, mode: GooseMode, select: Option, + expected_session_id: ExpectedSessionId, test_fn: F, ) where F: FnOnce( @@ -319,6 +340,7 @@ async fn run_acp_session( .unwrap() .run_until({ let updates = updates.clone(); + let expected_session_id = expected_session_id.clone(); move |cx: JrConnectionCx| async move { cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) .block_task() @@ -331,6 +353,8 @@ async fn run_acp_session( .await .unwrap(); + expected_session_id.set(&session.session_id); + test_fn(cx.clone(), session.session_id, updates).await; Ok(()) } @@ -352,27 +376,31 @@ async fn test_permission_persistence( ) { let temp_dir = tempfile::tempdir().unwrap(); let prompt = "Use the get_code tool and output only its result."; - let (mcp_url, _handle) = spawn_mcp_http_server().await; - - let mock_server = setup_mock_openai(vec![ - ( - format!(r#"\n{prompt}""#), - include_str!("./test_data/openai_tool_call_response.txt"), - ), - ( - format!(r#""content":"{FAKE_CODE}""#), - include_str!("./test_data/openai_tool_result_response.txt"), - ), - ]) + let expected_session_id = ExpectedSessionId::default(); + let mcp = 
McpFixture::new(expected_session_id.clone()).await; + let openai = OpenAiFixture::new( + vec![ + ( + format!(r#"\n{prompt}""#), + include_str!("./test_data/openai_tool_call_response.txt"), + ), + ( + format!(r#""content":"{FAKE_CODE}""#), + include_str!("./test_data/openai_tool_result_response.txt"), + ), + ], + expected_session_id.clone(), + ) .await; run_acp_session( - &mock_server, - vec![McpServer::Http(McpServerHttp::new("lookup", mcp_url))], + &openai.server, + vec![McpServer::Http(McpServerHttp::new("lookup", &mcp.url))], &[], temp_dir.path(), GooseMode::Approve, kind, + expected_session_id.clone(), |cx, session_id, updates| async move { cx.send_request(PromptRequest::new( session_id, @@ -393,6 +421,8 @@ async fn test_permission_persistence( ) .await; + expected_session_id.assert_no_errors(); + assert_eq!( fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), expected_yaml diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 8247221c7697..48c9c8c54883 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -25,6 +25,7 @@ use goose::providers::{create, providers, retry_operation, RetryConfig}; use goose::session::SessionType; use serde_json::Value; use std::collections::HashMap; +use uuid::Uuid; // useful for light themes where there is no dicernible colour contrast between // cursor-selected and cursor-unselected items. @@ -682,8 +683,10 @@ pub async fn configure_provider_dialog() -> anyhow::Result { let models_res = { let temp_model_config = ModelConfig::new(&provider_meta.default_model)?; let temp_provider = create(provider_name, temp_model_config).await?; + // Provider setup runs before any user session exists; use an ephemeral id. + let session_id = Uuid::new_v4().to_string(); retry_operation(&RetryConfig::default(), || async { - temp_provider.fetch_recommended_models().await + temp_provider.fetch_recommended_models(&session_id).await }) .await }; @@ -1655,9 +1658,11 @@ pub async fn handle_openrouter_auth() -> anyhow::Result<()> { match create("openrouter", model_config).await { Ok(provider) => { - // Simple test request + // Config verification runs before any user session exists; use an ephemeral id. + let session_id = Uuid::new_v4().to_string(); let test_result = provider .complete( + &session_id, "You are goose, an AI assistant.", &[Message::user().with_text("Say 'Configuration test successful!'")], &[], @@ -1733,8 +1738,11 @@ pub async fn handle_tetrate_auth() -> anyhow::Result<()> { match create("tetrate", model_config).await { Ok(provider) => { + // Config verification runs before any user session exists; use an ephemeral id. + let session_id = Uuid::new_v4().to_string(); let test_result = provider .complete( + &session_id, "You are goose, an AI assistant.", &[Message::user().with_text("Say 'Configuration test successful!'")], &[], diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index 6bb6e6ceb36d..149e79eb0e8b 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -1,7 +1,7 @@ //! MockClient is a mock implementation of the McpClientTrait for testing purposes. //! 
add a tool you want to have around and then add the client to the extension router -use goose::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use goose::agents::mcp_client::{Error, McpClientTrait}; use rmcp::{ model::{ CallToolResult, Content, ErrorData, GetPromptResult, ListPromptsResult, @@ -44,6 +44,7 @@ impl MockClient { impl McpClientTrait for MockClient { async fn list_resources( &self, + _session_id: &str, _next_cursor: Option, _cancel_token: CancellationToken, ) -> Result { @@ -60,6 +61,7 @@ impl McpClientTrait for MockClient { async fn read_resource( &self, + _session_id: &str, _uri: &str, _cancel_token: CancellationToken, ) -> Result { @@ -68,6 +70,7 @@ impl McpClientTrait for MockClient { async fn list_tools( &self, + _session_id: &str, _: Option, _cancel_token: CancellationToken, ) -> Result { @@ -92,9 +95,9 @@ impl McpClientTrait for MockClient { async fn call_tool( &self, + _session_id: &str, name: &str, arguments: Option>, - _meta: McpMeta, _cancel_token: CancellationToken, ) -> Result { if let Some(handler) = self.handlers.get(name) { @@ -114,6 +117,7 @@ impl McpClientTrait for MockClient { async fn list_prompts( &self, + _session_id: &str, _next_cursor: Option, _cancel_token: CancellationToken, ) -> Result { @@ -126,6 +130,7 @@ impl McpClientTrait for MockClient { async fn get_prompt( &self, + _session_id: &str, _name: &str, _arguments: Value, _cancel_token: CancellationToken, diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 0ee42f406168..db351e20f93d 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -193,15 +193,16 @@ pub enum PlannerResponseType { /// to the user's message. The response is either a plan or a clarifying /// question. pub async fn classify_planner_response( + session_id: &str, message_text: String, provider: Arc, ) -> Result { let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions. Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}"); - // Generate the description let message = Message::user().with_text(&prompt); let (result, _usage) = provider .complete( + session_id, "Reply only with the classification label: \"plan\" or \"clarifying questions\"", &[message], &[], @@ -367,7 +368,7 @@ impl CliSession { &mut self, extension: Option, ) -> Result>> { - let prompts = self.agent.list_extension_prompts().await; + let prompts = self.agent.list_extension_prompts(&self.session_id).await; // Early validation if filtering by extension if let Some(filter) = &extension { @@ -388,7 +389,7 @@ impl CliSession { } pub async fn get_prompt_info(&mut self, name: &str) -> Result> { - let prompts = self.agent.list_extension_prompts().await; + let prompts = self.agent.list_extension_prompts(&self.session_id).await; // Find which extension has this prompt for (extension, prompt_list) in prompts { @@ -406,7 +407,11 @@ impl CliSession { } pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result> { - Ok(self.agent.get_prompt(name, arguments).await?.messages) + Ok(self + .agent + .get_prompt(&self.session_id, name, arguments) + .await? 
+ .messages) } /// Process a single message and get the response @@ -728,7 +733,10 @@ impl CliSession { println!("{}", console::style("Generating Recipe").green()); output::show_thinking(); - let recipe = self.agent.create_recipe(self.messages.clone()).await; + let recipe = self + .agent + .create_recipe(&self.session_id, self.messages.clone()) + .await; output::hide_thinking(); match recipe { @@ -782,16 +790,24 @@ impl CliSession { plan_messages: Conversation, reasoner: Arc, ) -> Result<(), anyhow::Error> { - let plan_prompt = self.agent.get_plan_prompt().await?; + let plan_prompt = self.agent.get_plan_prompt(&self.session_id).await?; output::show_thinking(); let (plan_response, _usage) = reasoner - .complete(&plan_prompt, plan_messages.messages(), &[]) + .complete( + &self.session_id, + &plan_prompt, + plan_messages.messages(), + &[], + ) .await?; output::render_message(&plan_response, self.debug); output::hide_thinking(); - let planner_response_type = - classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?) - .await?; + let planner_response_type = classify_planner_response( + &self.session_id, + plan_response.as_concat_text(), + self.agent.provider().await?, + ) + .await?; match planner_response_type { PlannerResponseType::Plan => { @@ -1154,7 +1170,7 @@ impl CliSession { /// This should be called before the interactive session starts pub async fn update_completion_cache(&mut self) -> Result<()> { // Get fresh data - let prompts = self.agent.list_extension_prompts().await; + let prompts = self.agent.list_extension_prompts(&self.session_id).await; // Update the cache with write lock let mut cache = self.completion_cache.write().unwrap(); diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index f3e42576f244..7a6142a1bcb1 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -44,6 +44,7 @@ hex = "0.4.3" socket2 = "0.6.1" fs2 = "0.4.3" rustls = { version = "0.23", features = ["ring"] } +uuid = { version = "1.19.0", features = ["v4"] } [target.'cfg(windows)'.dependencies] winreg = { version = "0.55.0" } diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 692da6dec831..ee0eafddaf6b 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -846,6 +846,7 @@ async fn read_resource( let read_result = agent .extension_manager .read_resource( + &payload.session_id, &payload.uri, &payload.extension_name, CancellationToken::default(), @@ -984,14 +985,14 @@ async fn list_apps( }; let agent = state - .get_agent_for_route(session_id) + .get_agent_for_route(session_id.clone()) .await .map_err(|status| ErrorResponse { message: "Failed to get agent".to_string(), status, })?; - let apps = fetch_mcp_apps(&agent.extension_manager) + let apps = fetch_mcp_apps(&agent.extension_manager, &session_id) .await .map_err(|e| ErrorResponse { message: format!("Failed to list apps: {}", e.message), diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index d1d5a18902a4..05b0d013269c 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -28,6 +28,7 @@ use serde_json::Value; use serde_yaml; use std::{collections::HashMap, sync::Arc}; use utoipa::ToSchema; +use uuid::Uuid; #[derive(Serialize, ToSchema)] pub struct ExtensionResponse { @@ -401,8 +402,10 @@ pub async fn get_provider_models( .await .map_err(|_| 
StatusCode::INTERNAL_SERVER_ERROR)?; + // Config endpoints have no user session; use an ephemeral id for the probe. + let session_id = Uuid::new_v4().to_string(); let models_result = retry_operation(&RetryConfig::default(), || async { - provider.fetch_recommended_models().await + provider.fetch_recommended_models(&session_id).await }) .await; @@ -582,8 +585,10 @@ pub async fn detect_provider( Json(detect_request): Json, ) -> Result, StatusCode> { let api_key = detect_request.api_key.trim(); + // Provider detection runs without a user session; use an ephemeral id. + let session_id = Uuid::new_v4().to_string(); - match detect_provider_from_api_key(api_key).await { + match detect_provider_from_api_key(&session_id, api_key).await { Some((provider_name, models)) => Ok(Json(DetectProviderResponse { provider_name, models, diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index 34a4f0e73dc6..7e2875c35d71 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -179,7 +179,7 @@ async fn create_recipe( } }; - let conversation = match session.conversation { + let conversation = match session.conversation.clone() { Some(conversation) => conversation, None => { let error_message = "Session has no conversation".to_string(); @@ -193,7 +193,7 @@ async fn create_recipe( let agent = state.get_agent_for_route(request.session_id).await?; - let recipe_result = agent.create_recipe(conversation).await; + let recipe_result = agent.create_recipe(&session.id, conversation).await; match recipe_result { Ok(mut recipe) => { diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index 45815bef9a05..c241a9eb8c13 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -4,6 +4,7 @@ use goose::conversation::message::Message; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::{base::Usage, create_with_named_model}; use tokio_stream::StreamExt; +use uuid::Uuid; #[tokio::main] async fn main() -> Result<()> { @@ -19,8 +20,9 @@ async fn main() -> Result<()> { let message = Message::user().with_text("Tell me a short joke about programming."); // Get a response + let session_id = Uuid::new_v4().to_string(); let mut stream = provider - .stream("You are a helpful assistant.", &[message], &[]) + .stream(&session_id, "You are a helpful assistant.", &[message], &[]) .await?; println!("\nResponse from AI:"); diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index b2ffc188da6f..3bfbe2027cba 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -10,6 +10,7 @@ use rmcp::model::{CallToolRequestParam, Content, Tool}; use rmcp::object; use std::fs; use std::sync::Arc; +use uuid::Uuid; #[tokio::main] async fn main() -> Result<()> { @@ -61,8 +62,10 @@ async fn main() -> Result<()> { }, } }); + let session_id = Uuid::new_v4().to_string(); let (response, usage) = provider .complete( + &session_id, "You are a helpful assistant. 
Please describe any text you see in the image.", &messages, &[Tool::new("view_image", "View an image", input_schema)], diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 5f89c9c5d4e0..e62c5c33dbf8 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -23,8 +23,7 @@ use crate::agents::subagent_task_config::TaskConfig; use crate::agents::subagent_tool::{ create_subagent_tool, handle_subagent_tool, SUBAGENT_TOOL_NAME, }; -use crate::agents::types::SessionConfig; -use crate::agents::types::{FrontendTool, SharedProvider, ToolResultReceiver}; +use crate::agents::types::{FrontendTool, SessionConfig, SharedProvider, ToolResultReceiver}; use crate::config::permission::PermissionManager; use crate::config::{get_enabled_extensions, Config, GooseMode}; use crate::context_mgmt::{ @@ -785,7 +784,7 @@ impl Agent { pub async fn list_tools(&self, session_id: &str, extension_name: Option) -> Vec { let mut prefixed_tools = self .extension_manager - .get_prefixed_tools(extension_name.clone()) + .get_prefixed_tools(session_id, extension_name.clone()) .await .unwrap_or_default(); @@ -999,7 +998,14 @@ impl Agent { ) ); - match compact_messages(self.provider().await?.as_ref(), &conversation_to_compact, false).await { + match compact_messages( + self.provider().await?.as_ref(), + &session_config.id, + &conversation_to_compact, + false, + ) + .await + { Ok((compacted_conversation, summarization_usage)) => { session_manager.replace_conversation(&session_config.id, &compacted_conversation).await?; self.update_session_metrics(&session_config, &summarization_usage, true).await?; @@ -1041,7 +1047,7 @@ impl Agent { cancel_token: Option, ) -> Result>> { let context = self - .prepare_reply_context(&session_config.id, conversation, &session.working_dir) + .prepare_reply_context(&session.id, conversation, session.working_dir.as_path()) .await?; let ReplyContext { mut conversation, @@ -1108,6 +1114,7 @@ impl Agent { let mut stream = Self::stream_response_from_provider( self.provider().await?, + &session_config.id, &system_prompt, conversation_with_moim.messages(), &tools, @@ -1388,7 +1395,14 @@ impl Agent { ) ); - match compact_messages(self.provider().await?.as_ref(), &conversation, false).await { + match compact_messages( + self.provider().await?.as_ref(), + &session_config.id, + &conversation, + false, + ) + .await + { Ok((compacted_conversation, usage)) => { session_manager.replace_conversation(&session_config.id, &compacted_conversation).await?; self.update_session_metrics(&session_config, &usage, true).await?; @@ -1533,18 +1547,23 @@ impl Agent { prompt_manager.set_system_prompt_override(template); } - pub async fn list_extension_prompts(&self) -> HashMap> { + pub async fn list_extension_prompts(&self, session_id: &str) -> HashMap> { self.extension_manager - .list_prompts(CancellationToken::default()) + .list_prompts(session_id, CancellationToken::default()) .await .expect("Failed to list prompts") } - pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + pub async fn get_prompt( + &self, + session_id: &str, + name: &str, + arguments: Value, + ) -> Result { // First find which extension has this prompt let prompts = self .extension_manager - .list_prompts(CancellationToken::default()) + .list_prompts(session_id, CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; @@ -1555,7 +1574,13 @@ impl Agent { { return self .extension_manager - .get_prompt(extension, name, arguments, 
CancellationToken::default()) + .get_prompt( + session_id, + extension, + name, + arguments, + CancellationToken::default(), + ) .await .map_err(|e| anyhow!("Failed to get prompt: {}", e)); } @@ -1563,8 +1588,11 @@ impl Agent { Err(anyhow!("Prompt '{}' not found", name)) } - pub async fn get_plan_prompt(&self) -> Result { - let tools = self.extension_manager.get_prefixed_tools(None).await?; + pub async fn get_plan_prompt(&self, session_id: &str) -> Result { + let tools = self + .extension_manager + .get_prefixed_tools(session_id, None) + .await?; let tools_info = tools .into_iter() .map(|tool| { @@ -1591,13 +1619,19 @@ impl Agent { } } - pub async fn create_recipe(&self, mut messages: Conversation) -> Result { + pub async fn create_recipe( + &self, + session_id: &str, + mut messages: Conversation, + ) -> Result { tracing::info!("Starting recipe creation with {} messages", messages.len()); let extensions_info = self.extension_manager.get_extensions_info().await; tracing::debug!("Retrieved {} extensions info", extensions_info.len()); - let (extension_count, tool_count) = - self.extension_manager.get_extension_and_tool_counts().await; + let (extension_count, tool_count) = self + .extension_manager + .get_extension_and_tool_counts(session_id) + .await; // Get model name from provider let provider = self.provider().await.map_err(|e| { @@ -1619,7 +1653,7 @@ impl Agent { let recipe_prompt = prompt_manager.get_recipe_prompt().await; let tools = self .extension_manager - .get_prefixed_tools(None) + .get_prefixed_tools(session_id, None) .await .map_err(|e| { tracing::error!("Failed to get tools for recipe creation: {}", e); @@ -1651,7 +1685,7 @@ impl Agent { tracing::error!("{}", error); error })? - .complete(&system_prompt, messages.messages(), &tools) + .complete(session_id, &system_prompt, messages.messages(), &tools) .await .map_err(|e| { tracing::error!("Provider completion failed during recipe creation: {}", e); diff --git a/crates/goose/src/agents/chatrecall_extension.rs b/crates/goose/src/agents/chatrecall_extension.rs index f8c337e08d13..e946d5407be4 100644 --- a/crates/goose/src/agents/chatrecall_extension.rs +++ b/crates/goose/src/agents/chatrecall_extension.rs @@ -1,5 +1,5 @@ use crate::agents::extension::PlatformExtensionContext; -use crate::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use crate::agents::mcp_client::{Error, McpClientTrait}; use anyhow::Result; use async_trait::async_trait; use indoc::indoc; @@ -281,6 +281,7 @@ impl ChatRecallClient { impl McpClientTrait for ChatRecallClient { async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -293,12 +294,11 @@ impl McpClientTrait for ChatRecallClient { async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { - let session_id = &meta.session_id; let content = match name { "chatrecall" => self.handle_chatrecall(session_id, arguments).await, _ => Err(format!("Unknown tool: {}", name)), diff --git a/crates/goose/src/agents/code_execution_extension.rs b/crates/goose/src/agents/code_execution_extension.rs index 85a529689f95..61f0b6d72201 100644 --- a/crates/goose/src/agents/code_execution_extension.rs +++ b/crates/goose/src/agents/code_execution_extension.rs @@ -1,6 +1,6 @@ use crate::agents::extension::PlatformExtensionContext; use crate::agents::extension_manager::get_parameter_names; -use crate::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use 
crate::agents::mcp_client::{Error, McpClientTrait}; use anyhow::Result; use async_trait::async_trait; use boa_engine::builtins::promise::PromiseState; @@ -458,7 +458,7 @@ impl CodeExecutionClient { Ok(Self { info, context }) } - async fn get_tool_infos(&self) -> Vec { + async fn get_tool_infos(&self, session_id: &str) -> Vec { let Some(manager) = self .context .extension_manager @@ -468,7 +468,10 @@ impl CodeExecutionClient { return Vec::new(); }; - match manager.get_prefixed_tools_excluding(EXTENSION_NAME).await { + match manager + .get_prefixed_tools_excluding(session_id, EXTENSION_NAME) + .await + { Ok(tools) if !tools.is_empty() => { tools.iter().filter_map(ToolInfo::from_mcp_tool).collect() } @@ -488,7 +491,7 @@ impl CodeExecutionClient { .ok_or("Missing required parameter: code")? .to_string(); - let tools = self.get_tool_infos().await; + let tools = self.get_tool_infos(session_id).await; let (call_tx, call_rx) = mpsc::unbounded_channel(); let tool_handler = tokio::spawn(Self::run_tool_handler( session_id.to_string(), @@ -506,6 +509,7 @@ impl CodeExecutionClient { async fn handle_read_module( &self, + session_id: &str, arguments: Option, ) -> Result, String> { let path = arguments @@ -514,7 +518,7 @@ impl CodeExecutionClient { .and_then(|v| v.as_str()) .ok_or("Missing required parameter: module_path")?; - let tools = self.get_tool_infos().await; + let tools = self.get_tool_infos(session_id).await; let parts: Vec<&str> = path.trim_start_matches('/').split('/').collect(); match parts.as_slice() { @@ -549,6 +553,7 @@ impl CodeExecutionClient { async fn handle_search_modules( &self, + session_id: &str, arguments: Option, ) -> Result, String> { let terms = arguments @@ -580,7 +585,7 @@ impl CodeExecutionClient { .and_then(|v| v.as_bool()) .unwrap_or(false); - let tools = self.get_tool_infos().await; + let tools = self.get_tool_infos(session_id).await; Self::handle_search(&tools, &terms_vec, use_regex) } @@ -707,6 +712,7 @@ impl McpClientTrait for CodeExecutionClient { #[allow(clippy::too_many_lines)] async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -831,15 +837,15 @@ impl McpClientTrait for CodeExecutionClient { async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { let content = match name { - "execute_code" => self.handle_execute_code(&meta.session_id, arguments).await, - "read_module" => self.handle_read_module(arguments).await, - "search_modules" => self.handle_search_modules(arguments).await, + "execute_code" => self.handle_execute_code(session_id, arguments).await, + "read_module" => self.handle_read_module(session_id, arguments).await, + "search_modules" => self.handle_search_modules(session_id, arguments).await, _ => Err(format!("Unknown tool: {name}")), }; @@ -855,8 +861,8 @@ impl McpClientTrait for CodeExecutionClient { Some(&self.info) } - async fn get_moim(&self, _session_id: &str) -> Option { - let tools = self.get_tool_infos().await; + async fn get_moim(&self, session_id: &str) -> Option { + let tools = self.get_tool_infos(session_id).await; if tools.is_empty() { return None; } @@ -909,9 +915,9 @@ mod tests { let result = client .call_tool( + "test-session-id", "execute_code", Some(args), - McpMeta::new("test-session-id"), CancellationToken::new(), ) .await @@ -946,9 +952,9 @@ mod tests { let result = client .call_tool( + "test-session-id", "execute_code", Some(args), - McpMeta::new("test-session-id"), 
CancellationToken::new(), ) .await @@ -984,7 +990,9 @@ mod tests { Value::String("nonexistent".to_string()), ); - let result = client.handle_read_module(Some(args)).await; + let result = client + .handle_read_module("test-session-id", Some(args)) + .await; assert!(result.is_err()); } diff --git a/crates/goose/src/agents/execute_commands.rs b/crates/goose/src/agents/execute_commands.rs index a545b1d58354..bd81da503555 100644 --- a/crates/goose/src/agents/execute_commands.rs +++ b/crates/goose/src/agents/execute_commands.rs @@ -88,6 +88,7 @@ impl Agent { let (compacted_conversation, _usage) = compact_messages( self.provider().await?.as_ref(), + session_id, &conversation, true, // is_manual_compact ) @@ -128,11 +129,11 @@ impl Agent { async fn handle_prompts_command( &self, params: &[&str], - _session_id: &str, + session_id: &str, ) -> Result> { let extension_filter = params.first().map(|s| s.to_string()); - let prompts = self.list_extension_prompts().await; + let prompts = self.list_extension_prompts(session_id).await; if let Some(filter) = &extension_filter { if !prompts.contains_key(filter) { @@ -182,7 +183,7 @@ impl Agent { let is_info = params.get(1).map(|s| *s == "--info").unwrap_or(false); if is_info { - let prompts = self.list_extension_prompts().await; + let prompts = self.list_extension_prompts(session_id).await; let mut prompt_info = None; for (extension, prompt_list) in prompts { @@ -225,7 +226,10 @@ impl Agent { let arguments_value = serde_json::to_value(arguments) .map_err(|e| anyhow!("Failed to serialize arguments: {}", e))?; - match self.get_prompt(&prompt_name, arguments_value).await { + match self + .get_prompt(session_id, &prompt_name, arguments_value) + .await + { Ok(prompt_result) => { for (i, prompt_message) in prompt_result.messages.into_iter().enumerate() { let msg = Message::from(prompt_message); diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 15ef6fcd2cbf..be1a8f88b6aa 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -33,7 +33,7 @@ use super::tool_execution::ToolCallResult; use super::types::SharedProvider; use crate::agents::extension::{Envs, ProcessExit}; use crate::agents::extension_malware_check; -use crate::agents::mcp_client::{McpClient, McpClientTrait, McpMeta}; +use crate::agents::mcp_client::{McpClient, McpClientTrait}; use crate::config::search_path::SearchPaths; use crate::config::{get_all_extensions, Config}; use crate::oauth::oauth_flow; @@ -685,11 +685,11 @@ impl ExtensionManager { Ok(()) } - pub async fn get_extension_and_tool_counts(&self) -> (usize, usize) { + pub async fn get_extension_and_tool_counts(&self, session_id: &str) -> (usize, usize) { let enabled_extensions_count = self.extensions.lock().await.len(); let total_tools = self - .get_prefixed_tools(None) + .get_prefixed_tools(session_id, None) .await .map(|tools| tools.len()) .unwrap_or(0); @@ -717,14 +717,19 @@ impl ExtensionManager { /// Get all tools from all clients with proper prefixing pub async fn get_prefixed_tools( &self, + session_id: &str, extension_name: Option, ) -> ExtensionResult> { - let all_tools = self.get_all_tools_cached().await?; + let all_tools = self.get_all_tools_cached(session_id).await?; Ok(self.filter_tools(&all_tools, extension_name.as_deref(), None)) } - pub async fn get_prefixed_tools_excluding(&self, exclude: &str) -> ExtensionResult> { - let all_tools = self.get_all_tools_cached().await?; + pub async fn get_prefixed_tools_excluding( + 
&self, + session_id: &str, + exclude: &str, + ) -> ExtensionResult> { + let all_tools = self.get_all_tools_cached(session_id).await?; Ok(self.filter_tools(&all_tools, None, Some(exclude))) } @@ -755,7 +760,7 @@ impl ExtensionManager { .collect() } - async fn get_all_tools_cached(&self) -> ExtensionResult>> { + async fn get_all_tools_cached(&self, session_id: &str) -> ExtensionResult>> { { let cache = self.tools_cache.lock().await; if let Some(ref tools) = *cache { @@ -764,7 +769,7 @@ impl ExtensionManager { } let version_before = self.tools_cache_version.load(Ordering::SeqCst); - let tools = Arc::new(self.fetch_all_tools().await?); + let tools = Arc::new(self.fetch_all_tools(session_id).await?); { let mut cache = self.tools_cache.lock().await; @@ -782,7 +787,7 @@ impl ExtensionManager { *self.tools_cache.lock().await = None; } - async fn fetch_all_tools(&self) -> ExtensionResult> { + async fn fetch_all_tools(&self, session_id: &str) -> ExtensionResult> { let clients: Vec<_> = self .extensions .lock() @@ -799,7 +804,7 @@ impl ExtensionManager { let mut tools = Vec::new(); let client_guard = client.lock().await; let mut client_tools = match client_guard - .list_tools(None, cancel_token.clone()) + .list_tools(session_id, None, cancel_token.clone()) .await { Ok(t) => t, @@ -830,7 +835,7 @@ impl ExtensionManager { } client_tools = match client_guard - .list_tools(client_tools.next_cursor, cancel_token.clone()) + .list_tools(session_id, client_tools.next_cursor, cancel_token.clone()) .await { Ok(t) => t, @@ -876,6 +881,7 @@ impl ExtensionManager { // Function that gets executed for read_resource tool pub async fn read_resource_tool( &self, + session_id: &str, params: Value, cancellation_token: CancellationToken, ) -> Result, ErrorData> { @@ -886,7 +892,7 @@ impl ExtensionManager { // If extension name is provided, we can just look it up if let Some(ext_name) = extension_name { let read_result = self - .read_resource(uri, ext_name, cancellation_token.clone()) + .read_resource(session_id, uri, ext_name, cancellation_token.clone()) .await?; let mut result = Vec::new(); @@ -909,7 +915,7 @@ impl ExtensionManager { for extension_name in extension_names { let read_result = self - .read_resource(uri, &extension_name, cancellation_token.clone()) + .read_resource(session_id, uri, &extension_name, cancellation_token.clone()) .await; match read_result { Ok(read_result) => { @@ -949,6 +955,7 @@ impl ExtensionManager { pub async fn read_resource( &self, + session_id: &str, uri: &str, extension_name: &str, cancellation_token: CancellationToken, @@ -973,7 +980,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .read_resource(uri, cancellation_token) + .read_resource(session_id, uri, cancellation_token) .await .map_err(|_| { ErrorData::new( @@ -984,7 +991,10 @@ impl ExtensionManager { }) } - pub async fn get_ui_resources(&self) -> Result, ErrorData> { + pub async fn get_ui_resources( + &self, + session_id: &str, + ) -> Result, ErrorData> { let mut ui_resources = Vec::new(); let extensions_to_check: Vec<(String, McpClientBox)> = { @@ -999,7 +1009,7 @@ impl ExtensionManager { let client_guard = client.lock().await; match client_guard - .list_resources(None, CancellationToken::default()) + .list_resources(session_id, None, CancellationToken::default()) .await { Ok(list_response) => { @@ -1020,6 +1030,7 @@ impl ExtensionManager { async fn list_resources_from_extension( &self, + session_id: &str, extension_name: &str, cancellation_token: CancellationToken, ) -> Result, ErrorData> { @@ 
-1036,7 +1047,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_resources(None, cancellation_token) + .list_resources(session_id, None, cancellation_token) .await .map_err(|e| { ErrorData::new( @@ -1059,6 +1070,7 @@ impl ExtensionManager { pub async fn list_resources( &self, + session_id: &str, params: Value, cancellation_token: CancellationToken, ) -> Result, ErrorData> { @@ -1067,7 +1079,7 @@ impl ExtensionManager { match extension { Some(extension_name) => { // Handle single extension case - self.list_resources_from_extension(extension_name, cancellation_token) + self.list_resources_from_extension(session_id, extension_name, cancellation_token) .await } None => { @@ -1084,7 +1096,7 @@ impl ExtensionManager { .for_each(|name| { let token = cancellation_token.clone(); futures.push(async move { - self.list_resources_from_extension(&name.clone(), token) + self.list_resources_from_extension(session_id, name.as_str(), token) .await }); }); @@ -1190,9 +1202,8 @@ impl ExtensionManager { session_id ); let client_guard = client.lock().await; - let meta = McpMeta::new(&session_id); client_guard - .call_tool(&tool_name, arguments, meta, cancellation_token) + .call_tool(&session_id, &tool_name, arguments, cancellation_token) .await .map_err(|e| match e { ServiceError::McpError(error_data) => error_data, @@ -1210,6 +1221,7 @@ impl ExtensionManager { pub async fn list_prompts_from_extension( &self, + session_id: &str, extension_name: &str, cancellation_token: CancellationToken, ) -> Result, ErrorData> { @@ -1226,7 +1238,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_prompts(None, cancellation_token) + .list_prompts(session_id, None, cancellation_token) .await .map_err(|e| { ErrorData::new( @@ -1240,6 +1252,7 @@ impl ExtensionManager { pub async fn list_prompts( &self, + session_id: &str, cancellation_token: CancellationToken, ) -> Result>, ErrorData> { let mut futures = FuturesUnordered::new(); @@ -1250,7 +1263,7 @@ impl ExtensionManager { futures.push(async move { ( extension_name.clone(), - self.list_prompts_from_extension(extension_name.as_str(), token) + self.list_prompts_from_extension(session_id, extension_name.as_str(), token) .await, ) }); @@ -1287,6 +1300,7 @@ impl ExtensionManager { pub async fn get_prompt( &self, + session_id: &str, extension_name: &str, name: &str, arguments: Value, @@ -1299,7 +1313,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .get_prompt(name, arguments, cancellation_token) + .get_prompt(session_id, name, arguments, cancellation_token) .await .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) } @@ -1470,6 +1484,7 @@ mod tests { async fn list_resources( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -1478,6 +1493,7 @@ mod tests { async fn read_resource( &self, + _session_id: &str, _uri: &str, _cancellation_token: CancellationToken, ) -> Result { @@ -1486,6 +1502,7 @@ mod tests { async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -1516,9 +1533,9 @@ mod tests { async fn call_tool( &self, + _session_id: &str, name: &str, _arguments: Option, - _meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { match name { @@ -1534,6 +1551,7 @@ mod tests { async fn list_prompts( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -1542,6 +1560,7 @@ mod 
tests { async fn get_prompt( &self, + _session_id: &str, _name: &str, _arguments: Value, _cancellation_token: CancellationToken, @@ -1767,7 +1786,10 @@ mod tests { ) .await; - let tools = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools.iter().map(|t| t.name.to_string()).collect(); assert!(!tool_names.iter().any(|name| name == "test_extension__tool")); // Default unavailable @@ -1794,7 +1816,10 @@ mod tests { ) .await; - let tools = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools.iter().map(|t| t.name.to_string()).collect(); assert!(tool_names.iter().any(|name| name == "test_extension__tool")); @@ -1956,7 +1981,10 @@ mod tests { ) .await; - let tools_after_first = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools_after_first = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools_after_first .iter() .map(|t| t.name.to_string()) @@ -1971,7 +1999,10 @@ mod tests { ) .await; - let tools_after_second = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools_after_second = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools_after_second .iter() .map(|t| t.name.to_string()) @@ -1999,14 +2030,20 @@ mod tests { ) .await; - let tools_before = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools_before = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools_before.iter().map(|t| t.name.to_string()).collect(); assert!(tool_names.iter().any(|n| n.starts_with("ext_a__"))); assert!(tool_names.iter().any(|n| n.starts_with("ext_b__"))); extension_manager.remove_extension("ext_b").await.unwrap(); - let tools_after = extension_manager.get_prefixed_tools(None).await.unwrap(); + let tools_after = extension_manager + .get_prefixed_tools("test-session-id", None) + .await + .unwrap(); let tool_names: Vec = tools_after.iter().map(|t| t.name.to_string()).collect(); assert!(tool_names.iter().any(|n| n.starts_with("ext_a__"))); assert!(!tool_names.iter().any(|n| n.starts_with("ext_b__"))); @@ -2032,7 +2069,7 @@ mod tests { .await; let tools = extension_manager - .get_prefixed_tools_excluding("ext_a") + .get_prefixed_tools_excluding("test-session-id", "ext_a") .await .unwrap(); let tool_names: Vec = tools.iter().map(|t| t.name.to_string()).collect(); @@ -2061,7 +2098,7 @@ mod tests { .await; let tools = extension_manager - .get_prefixed_tools(Some("ext_a".to_string())) + .get_prefixed_tools("test-session-id", Some("ext_a".to_string())) .await .unwrap(); let tool_names: Vec = tools.iter().map(|t| t.name.to_string()).collect(); diff --git a/crates/goose/src/agents/extension_manager_extension.rs b/crates/goose/src/agents/extension_manager_extension.rs index 377c592d3559..d57ba85d3861 100644 --- a/crates/goose/src/agents/extension_manager_extension.rs +++ b/crates/goose/src/agents/extension_manager_extension.rs @@ -1,5 +1,5 @@ use crate::agents::extension::PlatformExtensionContext; -use crate::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use crate::agents::mcp_client::{Error, McpClientTrait}; use crate::config::get_extension_by_name; use anyhow::Result; use async_trait::async_trait; @@ -224,6 +224,7 @@ 
impl ExtensionManagerClient { async fn handle_list_resources( &self, + session_id: &str, arguments: Option, ) -> Result, ExtensionManagerToolError> { if let Some(weak_ref) = &self.context.extension_manager { @@ -233,7 +234,11 @@ impl ExtensionManagerClient { .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); match extension_manager - .list_resources(params, tokio_util::sync::CancellationToken::default()) + .list_resources( + session_id, + params, + tokio_util::sync::CancellationToken::default(), + ) .await { Ok(content) => Ok(content), @@ -251,6 +256,7 @@ impl ExtensionManagerClient { async fn handle_read_resource( &self, + session_id: &str, arguments: Option, ) -> Result, ExtensionManagerToolError> { if let Some(weak_ref) = &self.context.extension_manager { @@ -260,7 +266,11 @@ impl ExtensionManagerClient { .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); match extension_manager - .read_resource_tool(params, tokio_util::sync::CancellationToken::default()) + .read_resource_tool( + session_id, + params, + tokio_util::sync::CancellationToken::default(), + ) .await { Ok(content) => Ok(content), @@ -390,6 +400,7 @@ impl ExtensionManagerClient { impl McpClientTrait for ExtensionManagerClient { async fn list_resources( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -398,6 +409,7 @@ impl McpClientTrait for ExtensionManagerClient { async fn read_resource( &self, + _session_id: &str, _uri: &str, _cancellation_token: CancellationToken, ) -> Result { @@ -407,6 +419,7 @@ impl McpClientTrait for ExtensionManagerClient { async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -419,9 +432,9 @@ impl McpClientTrait for ExtensionManagerClient { async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - _meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { let result = match name { @@ -429,8 +442,8 @@ impl McpClientTrait for ExtensionManagerClient { self.handle_search_available_extensions().await } MANAGE_EXTENSIONS_TOOL_NAME => self.handle_manage_extensions(arguments).await, - LIST_RESOURCES_TOOL_NAME => self.handle_list_resources(arguments).await, - READ_RESOURCE_TOOL_NAME => self.handle_read_resource(arguments).await, + LIST_RESOURCES_TOOL_NAME => self.handle_list_resources(session_id, arguments).await, + READ_RESOURCE_TOOL_NAME => self.handle_read_resource(session_id, arguments).await, _ => Err(ExtensionManagerToolError::UnknownTool { tool_name: name.to_string(), }), @@ -455,6 +468,7 @@ impl McpClientTrait for ExtensionManagerClient { async fn list_prompts( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -463,6 +477,7 @@ impl McpClientTrait for ExtensionManagerClient { async fn get_prompt( &self, + _session_id: &str, _name: &str, _arguments: Value, _cancellation_token: CancellationToken, diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index 3e46ee118800..bb767b5ce5c5 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -26,47 +26,35 @@ use rmcp::{ ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt, }; use serde_json::Value; -use std::{sync::Arc, time::Duration}; +use std::{ + sync::{Arc, OnceLock}, + time::Duration, +}; use tokio::sync::{ mpsc::{self, Sender}, Mutex, }; use tokio_util::sync::CancellationToken; +use uuid::Uuid; pub type BoxError = Box; 
pub type Error = rmcp::ServiceError; -#[derive(Clone, Debug)] -pub struct McpMeta { - pub session_id: String, -} - -impl McpMeta { - pub fn new(session_id: impl Into) -> Self { - Self { - session_id: session_id.into(), - } - } - - fn inject_into_extensions(&self, extensions: Extensions) -> Extensions { - inject_session_id_into_extensions(extensions, &self.session_id) - } -} - #[async_trait::async_trait] pub trait McpClientTrait: Send + Sync { async fn list_tools( &self, + session_id: &str, next_cursor: Option, cancel_token: CancellationToken, ) -> Result; async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - meta: McpMeta, cancel_token: CancellationToken, ) -> Result; @@ -74,6 +62,7 @@ pub trait McpClientTrait: Send + Sync { async fn list_resources( &self, + _session_id: &str, _next_cursor: Option, _cancel_token: CancellationToken, ) -> Result { @@ -82,6 +71,7 @@ pub trait McpClientTrait: Send + Sync { async fn read_resource( &self, + _session_id: &str, _uri: &str, _cancel_token: CancellationToken, ) -> Result { @@ -90,6 +80,7 @@ pub trait McpClientTrait: Send + Sync { async fn list_prompts( &self, + _session_id: &str, _next_cursor: Option, _cancel_token: CancellationToken, ) -> Result { @@ -98,6 +89,7 @@ pub trait McpClientTrait: Send + Sync { async fn get_prompt( &self, + _session_id: &str, _name: &str, _arguments: Value, _cancel_token: CancellationToken, @@ -117,6 +109,10 @@ pub trait McpClientTrait: Send + Sync { pub struct GooseClient { notification_handlers: Arc>>>, provider: SharedProvider, + // Single-slot because calls are serialized per MCP client; see send_request_with_session. + current_session_id: Arc>>, + // Connection-scoped fallback for server-initiated sampling. + client_session_id: OnceLock, } impl GooseClient { @@ -127,7 +123,51 @@ impl GooseClient { GooseClient { notification_handlers: handlers, provider, + current_session_id: Arc::new(Mutex::new(None)), + client_session_id: OnceLock::new(), + } + } + + async fn set_current_session_id(&self, session_id: &str) { + let mut slot = self.current_session_id.lock().await; + *slot = Some(session_id.to_string()); + } + + async fn clear_current_session_id(&self) { + let mut slot = self.current_session_id.lock().await; + *slot = None; + } + + async fn current_session_id(&self) -> Option { + let slot = self.current_session_id.lock().await; + slot.clone() + } + + async fn resolve_session_id(&self, extensions: &Extensions) -> String { + // Prefer explicit MCP metadata, then the active request scope. + if let Some(session_id) = Self::session_id_from_extensions(extensions) { + return session_id; } + if let Some(session_id) = self.current_session_id().await { + return session_id; + } + // Fallback for server-initiated sampling not tied to a request session. + self.client_session_id() + } + + fn client_session_id(&self) -> String { + self.client_session_id + .get_or_init(|| Uuid::new_v4().to_string()) + .clone() + } + + fn session_id_from_extensions(extensions: &Extensions) -> Option { + let meta = extensions.get::()?; + meta.0 + .iter() + .find(|(key, _)| key.eq_ignore_ascii_case(SESSION_ID_HEADER)) + .and_then(|(_, value)| value.as_str()) + .map(|value| value.to_string()) } } @@ -175,7 +215,7 @@ impl ClientHandler for GooseClient { async fn create_message( &self, params: CreateMessageRequestParam, - _context: RequestContext, + context: RequestContext, ) -> Result { let provider = self .provider @@ -189,6 +229,9 @@ impl ClientHandler for GooseClient { ))? 
.clone(); + // Prefer explicit MCP metadata, then the active request scope. + let session_id = self.resolve_session_id(&context.extensions).await; + let provider_ready_messages: Vec = params .messages .iter() @@ -211,7 +254,7 @@ impl ClientHandler for GooseClient { .unwrap_or("You are a general-purpose AI agent called goose"); let (response, usage) = provider - .complete(system_prompt, &provider_ready_messages, &[]) + .complete(&session_id, system_prompt, &provider_ready_messages, &[]) .await .map_err(|e| { ErrorData::new( @@ -336,19 +379,38 @@ impl McpClient { }) } - async fn send_request( + async fn send_request_with_session( &self, + session_id: &str, request: ClientRequest, cancel_token: CancellationToken, ) -> Result { - let handle = self - .client - .lock() - .await - .send_cancellable_request(request, PeerRequestOptions::no_options()) - .await?; + let request = inject_session_id_into_request(request, session_id); + // ExtensionManager serializes calls per MCP connection, so one current_session_id slot + // is sufficient for mapping callbacks to the active request session. + let handle = { + let client = self.client.lock().await; + client.service().set_current_session_id(session_id).await; + client + .send_cancellable_request(request, PeerRequestOptions::no_options()) + .await + }; - await_response(handle, self.timeout, &cancel_token).await + let handle = match handle { + Ok(handle) => handle, + Err(err) => { + let client = self.client.lock().await; + client.service().clear_current_session_id().await; + return Err(err); + } + }; + + let result = await_response(handle, self.timeout, &cancel_token).await; + + let client = self.client.lock().await; + client.service().clear_current_session_id().await; + + result } } @@ -399,15 +461,17 @@ impl McpClientTrait for McpClient { async fn list_resources( &self, + session_id: &str, cursor: Option, cancel_token: CancellationToken, ) -> Result { let res = self - .send_request( + .send_request_with_session( + session_id, ClientRequest::ListResourcesRequest(ListResourcesRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_current_session_id_into_extensions(Default::default()), + extensions: Default::default(), }), cancel_token, ) @@ -421,17 +485,19 @@ impl McpClientTrait for McpClient { async fn read_resource( &self, + session_id: &str, uri: &str, cancel_token: CancellationToken, ) -> Result { let res = self - .send_request( + .send_request_with_session( + session_id, ClientRequest::ReadResourceRequest(ReadResourceRequest { params: ReadResourceRequestParam { uri: uri.to_string(), }, method: Default::default(), - extensions: inject_current_session_id_into_extensions(Default::default()), + extensions: Default::default(), }), cancel_token, ) @@ -445,15 +511,17 @@ impl McpClientTrait for McpClient { async fn list_tools( &self, + session_id: &str, cursor: Option, cancel_token: CancellationToken, ) -> Result { let res = self - .send_request( + .send_request_with_session( + session_id, ClientRequest::ListToolsRequest(ListToolsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_current_session_id_into_extensions(Default::default()), + extensions: Default::default(), }), cancel_token, ) @@ -467,27 +535,26 @@ impl McpClientTrait for McpClient { async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - meta: McpMeta, cancel_token: CancellationToken, ) -> Result { - let res = self - .send_request( - 
ClientRequest::CallToolRequest(CallToolRequest { - params: CallToolRequestParam { - task: None, - name: name.to_string().into(), - arguments, - }, - method: Default::default(), - extensions: meta.inject_into_extensions(Default::default()), - }), - cancel_token, - ) - .await?; + let request = ClientRequest::CallToolRequest(CallToolRequest { + params: CallToolRequestParam { + task: None, + name: name.to_string().into(), + arguments, + }, + method: Default::default(), + extensions: Default::default(), + }); - match res { + let result = self + .send_request_with_session(session_id, request, cancel_token) + .await; + + match result? { ServerResult::CallToolResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } @@ -495,15 +562,17 @@ impl McpClientTrait for McpClient { async fn list_prompts( &self, + session_id: &str, cursor: Option, cancel_token: CancellationToken, ) -> Result { let res = self - .send_request( + .send_request_with_session( + session_id, ClientRequest::ListPromptsRequest(ListPromptsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: inject_current_session_id_into_extensions(Default::default()), + extensions: Default::default(), }), cancel_token, ) @@ -517,6 +586,7 @@ impl McpClientTrait for McpClient { async fn get_prompt( &self, + session_id: &str, name: &str, arguments: Value, cancel_token: CancellationToken, @@ -526,14 +596,15 @@ impl McpClientTrait for McpClient { _ => None, }; let res = self - .send_request( + .send_request_with_session( + session_id, ClientRequest::GetPromptRequest(GetPromptRequest { params: GetPromptRequestParam { name: name.to_string(), arguments, }, method: Default::default(), - extensions: inject_current_session_id_into_extensions(Default::default()), + extensions: Default::default(), }), cancel_token, ) @@ -571,103 +642,241 @@ fn inject_session_id_into_extensions(mut extensions: Extensions, session_id: &st extensions } -/// Injects session ID from task-local context into Extensions._meta. 
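Reviewer note: the task-local lookup removed just below is replaced by an explicit per-connection slot that send_request_with_session populates right before sending and clears afterwards on both the success and error paths, so server-initiated callbacks (e.g. sampling) can still recover the active session. A rough sketch of that lifecycle under simplified, hypothetical types; the real code also injects the id into the request's Extensions meta and supports cancellation:

use std::sync::Arc;
use tokio::sync::Mutex;

// Stand-in for the per-connection client service; a single Option slot is
// enough because requests on one MCP connection are serialized.
#[derive(Default)]
struct SessionSlot {
    current: Arc<Mutex<Option<String>>>,
}

impl SessionSlot {
    async fn set(&self, session_id: &str) {
        *self.current.lock().await = Some(session_id.to_string());
    }
    async fn clear(&self) {
        *self.current.lock().await = None;
    }
    async fn get(&self) -> Option<String> {
        self.current.lock().await.clone()
    }
}

// Placeholder request sender standing in for send_cancellable_request plus
// awaiting the response handle.
async fn send(slot: &SessionSlot, request: &str) -> Result<String, String> {
    Ok(format!("{request} (callback sees {:?})", slot.get().await))
}

// Shape of send_request_with_session: populate the slot for the duration of
// one request, then clear it whether the send succeeded or failed.
async fn send_with_session(
    slot: &SessionSlot,
    session_id: &str,
    request: &str,
) -> Result<String, String> {
    slot.set(session_id).await;
    let result = send(slot, request).await;
    slot.clear().await;
    result
}

#[tokio::main]
async fn main() {
    let slot = SessionSlot::default();
    println!("{:?}", send_with_session(&slot, "session-123", "tools/list").await);
}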
-fn inject_current_session_id_into_extensions(extensions: Extensions) -> Extensions { - if let Some(session_id) = crate::session_context::current_session_id() { - inject_session_id_into_extensions(extensions, &session_id) - } else { - extensions +fn inject_session_id_into_request(request: ClientRequest, session_id: &str) -> ClientRequest { + match request { + ClientRequest::ListResourcesRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::ListResourcesRequest(req) + } + ClientRequest::ReadResourceRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::ReadResourceRequest(req) + } + ClientRequest::ListToolsRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::ListToolsRequest(req) + } + ClientRequest::CallToolRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::CallToolRequest(req) + } + ClientRequest::ListPromptsRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::ListPromptsRequest(req) + } + ClientRequest::GetPromptRequest(mut req) => { + req.extensions = inject_session_id_into_extensions(req.extensions, session_id); + ClientRequest::GetPromptRequest(req) + } + other => other, } } #[cfg(test)] mod tests { use super::*; - use rmcp::model::Meta; + use test_case::test_case; - #[tokio::test] - async fn test_session_id_in_mcp_meta() { - use serde_json::json; + fn new_client() -> GooseClient { + GooseClient::new(Arc::new(Mutex::new(Vec::new())), Arc::new(Mutex::new(None))) + } - let session_id = "test-session-789"; - crate::session_context::with_session_id(Some(session_id.to_string()), async { - let extensions = inject_current_session_id_into_extensions(Default::default()); - let meta = extensions.get::().unwrap(); - - assert_eq!( - &meta.0, - json!({ - SESSION_ID_HEADER: session_id - }) - .as_object() - .unwrap() - ); + fn request_extensions(request: &ClientRequest) -> Option<&Extensions> { + match request { + ClientRequest::ListResourcesRequest(req) => Some(&req.extensions), + ClientRequest::ReadResourceRequest(req) => Some(&req.extensions), + ClientRequest::ListToolsRequest(req) => Some(&req.extensions), + ClientRequest::CallToolRequest(req) => Some(&req.extensions), + ClientRequest::ListPromptsRequest(req) => Some(&req.extensions), + ClientRequest::GetPromptRequest(req) => Some(&req.extensions), + _ => None, + } + } + + fn list_resources_request(extensions: Extensions) -> ClientRequest { + ClientRequest::ListResourcesRequest(ListResourcesRequest { + params: Some(PaginatedRequestParam { cursor: None }), + method: Default::default(), + extensions, }) - .await; } - #[tokio::test] - async fn test_no_session_id_in_mcp_when_absent() { - let extensions = inject_current_session_id_into_extensions(Default::default()); - let meta = extensions.get::(); + fn read_resource_request(extensions: Extensions) -> ClientRequest { + ClientRequest::ReadResourceRequest(ReadResourceRequest { + params: ReadResourceRequestParam { + uri: "test://resource".to_string(), + }, + method: Default::default(), + extensions, + }) + } - assert!(meta.is_none()); + fn list_tools_request(extensions: Extensions) -> ClientRequest { + ClientRequest::ListToolsRequest(ListToolsRequest { + params: Some(PaginatedRequestParam { cursor: None }), + method: Default::default(), + extensions, + }) } - 
#[tokio::test] - async fn test_all_mcp_operations_include_session() { - use serde_json::json; + fn call_tool_request(extensions: Extensions) -> ClientRequest { + ClientRequest::CallToolRequest(CallToolRequest { + params: CallToolRequestParam { + task: None, + name: "tool".to_string().into(), + arguments: None, + }, + method: Default::default(), + extensions, + }) + } - let session_id = "consistent-session-id"; - crate::session_context::with_session_id(Some(session_id.to_string()), async { - let ext1 = inject_current_session_id_into_extensions(Default::default()); - let ext2 = inject_current_session_id_into_extensions(Default::default()); - let ext3 = inject_current_session_id_into_extensions(Default::default()); - - for ext in [&ext1, &ext2, &ext3] { - assert_eq!( - &ext.get::().unwrap().0, - json!({ - SESSION_ID_HEADER: session_id - }) - .as_object() - .unwrap() - ); - } + fn list_prompts_request(extensions: Extensions) -> ClientRequest { + ClientRequest::ListPromptsRequest(ListPromptsRequest { + params: Some(PaginatedRequestParam { cursor: None }), + method: Default::default(), + extensions, + }) + } + + fn get_prompt_request(extensions: Extensions) -> ClientRequest { + ClientRequest::GetPromptRequest(GetPromptRequest { + params: GetPromptRequestParam { + name: "prompt".to_string(), + arguments: None, + }, + method: Default::default(), + extensions, }) - .await; } - #[tokio::test] - async fn test_session_id_case_insensitive_replacement() { - use rmcp::model::{Extensions, Meta}; + #[test_case( + Some("ext-session"), + Some("current-session"), + "ext-session"; + "extensions win" + )] + #[test_case( + None, + Some("current-session"), + "current-session"; + "current when no extensions" + )] + #[test_case( + None, + None, + "client-session"; + "client fallback when no session" + )] + fn test_resolve_session_id( + ext_session: Option<&str>, + current_session: Option<&str>, + expected: &str, + ) { + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let client = new_client(); + // Make the fallback deterministic so the expected value can live in the test_case row. 
+ client + .client_session_id + .get_or_init(|| "client-session".to_string()); + if let Some(session_id) = current_session { + let mut slot = client.current_session_id.lock().await; + *slot = Some(session_id.to_string()); + } + + let mut extensions = Extensions::new(); + if let Some(session_id) = ext_session { + extensions = inject_session_id_into_extensions(extensions, session_id); + } + + let resolved = client.resolve_session_id(&extensions).await; + + assert_eq!(resolved, expected); + }); + } + + #[test_case(list_resources_request; "list_resources")] + #[test_case(read_resource_request; "read_resource")] + #[test_case(list_tools_request; "list_tools")] + #[test_case(call_tool_request; "call_tool")] + #[test_case(list_prompts_request; "list_prompts")] + #[test_case(get_prompt_request; "get_prompt")] + fn test_request_injects_session(request_builder: fn(Extensions) -> ClientRequest) { + use serde_json::json; + + let session_id = "test-session-id"; + let mut extensions = Extensions::new(); + extensions.insert( + serde_json::from_value::(json!({ + "Goose-Session-Id": "old-session-id", + "other-key": "preserve-me" + })) + .unwrap(), + ); + + let request = request_builder(extensions); + let request = inject_session_id_into_request(request, session_id); + let extensions = request_extensions(&request).expect("request should have extensions"); + let meta = extensions + .get::() + .expect("extensions should contain meta"); + + assert_eq!( + meta.0.get(SESSION_ID_HEADER), + Some(&Value::String(session_id.to_string())) + ); + assert_eq!( + meta.0.get("other-key"), + Some(&Value::String("preserve-me".to_string())) + ); + } + + #[test] + fn test_session_id_in_mcp_meta() { + use serde_json::json; + + let session_id = "test-session-789"; + let extensions = inject_session_id_into_extensions(Default::default(), session_id); + let mcp_meta = extensions.get::().unwrap(); + + assert_eq!( + &mcp_meta.0, + json!({ + SESSION_ID_HEADER: session_id + }) + .as_object() + .unwrap() + ); + } + + #[test] + fn test_session_id_case_insensitive_replacement() { + use rmcp::model::Extensions; use serde_json::{from_value, json}; let session_id = "new-session-id"; - crate::session_context::with_session_id(Some(session_id.to_string()), async { - let mut extensions = Extensions::new(); - extensions.insert( - from_value::(json!({ - "GOOSE-SESSION-ID": "old-session-1", - "Goose-Session-Id": "old-session-2", - "other-key": "preserve-me" - })) - .unwrap(), - ); - - let extensions = inject_current_session_id_into_extensions(extensions); - let meta = extensions.get::().unwrap(); - - assert_eq!( - &meta.0, - json!({ - SESSION_ID_HEADER: session_id, - "other-key": "preserve-me" - }) - .as_object() - .unwrap() - ); - }) - .await; + let mut extensions = Extensions::new(); + extensions.insert( + from_value::(json!({ + "GOOSE-SESSION-ID": "old-session-1", + "Goose-Session-Id": "old-session-2", + "other-key": "preserve-me" + })) + .unwrap(), + ); + + let extensions = inject_session_id_into_extensions(extensions, session_id); + let mcp_meta = extensions.get::().unwrap(); + + assert_eq!( + &mcp_meta.0, + json!({ + SESSION_ID_HEADER: session_id, + "other-key": "preserve-me" + }) + .as_object() + .unwrap() + ); } } diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index ae1e00421058..f576ec5a1d63 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -7,6 +7,8 @@ use serde_json::{json, Value}; use tracing::debug; use super::super::agents::Agent; +use 
crate::agents::code_execution_extension::EXTENSION_NAME as CODE_EXECUTION_EXTENSION; +use crate::agents::subagent_tool::SUBAGENT_TOOL_NAME; use crate::conversation::message::{Message, MessageContent, ToolRequest}; use crate::conversation::Conversation; use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; @@ -15,11 +17,6 @@ use crate::providers::toolshim::{ augment_message_with_tool_calls, convert_tool_messages_to_text, modify_system_prompt_for_tool_json, OllamaInterpreter, }; - -use crate::agents::code_execution_extension::EXTENSION_NAME as CODE_EXECUTION_EXTENSION; -use crate::agents::subagent_tool::SUBAGENT_TOOL_NAME; -#[cfg(test)] -use crate::session::SessionType; use rmcp::model::Tool; fn coerce_value(s: &str, schema: &Value) -> Value { @@ -139,8 +136,10 @@ impl Agent { // Prepare system prompt let extensions_info = self.extension_manager.get_extensions_info().await; - let (extension_count, tool_count) = - self.extension_manager.get_extension_and_tool_counts().await; + let (extension_count, tool_count) = self + .extension_manager + .get_extension_and_tool_counts(session_id) + .await; // Get model name from provider let provider = self.provider().await?; @@ -175,6 +174,7 @@ impl Agent { /// Handles toolshim transformations if needed pub(crate) async fn stream_response_from_provider( provider: Arc, + session_id: &str, system_prompt: &str, messages: &[Message], tools: &[Tool], @@ -201,6 +201,7 @@ impl Agent { debug!("WAITING_LLM_STREAM_START"); let result = provider .stream( + session_id, system_prompt.as_str(), messages_for_provider.messages(), &tools, @@ -212,6 +213,7 @@ impl Agent { debug!("WAITING_LLM_START"); let complete_result = provider .complete( + session_id, system_prompt.as_str(), messages_for_provider.messages(), &tools, @@ -408,6 +410,7 @@ mod tests { use crate::model::ModelConfig; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; + use crate::session::session_manager::SessionType; use async_trait::async_trait; use rmcp::object; @@ -432,6 +435,7 @@ mod tests { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, _messages: &[Message], @@ -452,7 +456,7 @@ mod tests { .config .session_manager .create_session( - std::path::PathBuf::default(), + std::env::current_dir().unwrap(), "test-prepare-tools".to_string(), SessionType::Hidden, ) @@ -488,9 +492,8 @@ mod tests { .await .unwrap(); - let working_dir = std::env::current_dir()?; let (tools, _toolshim_tools, _system_prompt) = agent - .prepare_tools_and_prompt(&session.id, &working_dir) + .prepare_tools_and_prompt(&session.id, session.working_dir.as_path()) .await?; let names: Vec = tools.iter().map(|t| t.name.clone().into_owned()).collect(); diff --git a/crates/goose/src/agents/skills_extension.rs b/crates/goose/src/agents/skills_extension.rs index be841bb192bf..ac75f45b2111 100644 --- a/crates/goose/src/agents/skills_extension.rs +++ b/crates/goose/src/agents/skills_extension.rs @@ -1,5 +1,5 @@ use crate::agents::extension::PlatformExtensionContext; -use crate::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use crate::agents::mcp_client::{Error, McpClientTrait}; use crate::config::paths::Paths; use anyhow::Result; use async_trait::async_trait; @@ -263,6 +263,7 @@ impl SkillsClient { impl McpClientTrait for SkillsClient { async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -280,9 +281,9 @@ impl 
McpClientTrait for SkillsClient { async fn call_tool( &self, + _session_id: &str, name: &str, arguments: Option, - _meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { let content = match name { @@ -601,7 +602,7 @@ Content from dir3 }; let result = client - .list_tools(None, CancellationToken::new()) + .list_tools("test-session-id", None, CancellationToken::new()) .await .unwrap(); assert_eq!(result.tools.len(), 0); @@ -656,7 +657,7 @@ Content }; let result = client - .list_tools(None, CancellationToken::new()) + .list_tools("test-session-id", None, CancellationToken::new()) .await .unwrap(); assert_eq!(result.tools.len(), 1); diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 4561c59fc388..54e0d2dab827 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -182,13 +182,10 @@ fn get_agent_messages( retry_config: recipe.retry, }; - let mut stream = crate::session_context::with_session_id(Some(session_id.clone()), async { - agent - .reply(user_message, session_config, cancellation_token) - .await - }) - .await - .map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?; + let mut stream = agent + .reply(user_message, session_config, cancellation_token) + .await + .map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?; while let Some(message_result) = stream.next().await { match message_result { Ok(AgentEvent::Message(msg)) => conversation.push(msg), diff --git a/crates/goose/src/agents/todo_extension.rs b/crates/goose/src/agents/todo_extension.rs index 7be44ba9bdeb..7aa3ccb49211 100644 --- a/crates/goose/src/agents/todo_extension.rs +++ b/crates/goose/src/agents/todo_extension.rs @@ -1,5 +1,5 @@ use crate::agents::extension::PlatformExtensionContext; -use crate::agents::mcp_client::{Error, McpClientTrait, McpMeta}; +use crate::agents::mcp_client::{Error, McpClientTrait}; use crate::session::extension_data; use crate::session::extension_data::ExtensionState; use anyhow::Result; @@ -158,6 +158,7 @@ impl TodoClient { impl McpClientTrait for TodoClient { async fn list_tools( &self, + _session_id: &str, _next_cursor: Option, _cancellation_token: CancellationToken, ) -> Result { @@ -170,12 +171,11 @@ impl McpClientTrait for TodoClient { async fn call_tool( &self, + session_id: &str, name: &str, arguments: Option, - meta: McpMeta, _cancellation_token: CancellationToken, ) -> Result { - let session_id = &meta.session_id; let content = match name { "todo_write" => self.handle_write_todo(session_id, arguments).await, _ => Err(format!("Unknown tool: {}", name)), diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs index b2e8342e00c7..b6f6e095ec80 100644 --- a/crates/goose/src/context_mgmt/mod.rs +++ b/crates/goose/src/context_mgmt/mod.rs @@ -40,6 +40,7 @@ struct SummarizeContext { /// /// # Arguments /// * `provider` - The provider to use for summarization +/// * `session_id` - The session to use for summarization /// * `conversation` - The current conversation history /// * `manual_compact` - If true, this is a manual compaction (don't preserve user message) /// @@ -49,6 +50,7 @@ struct SummarizeContext { /// - `ProviderUsage`: Provider usage from summarization pub async fn compact_messages( provider: &dyn Provider, + session_id: &str, conversation: &Conversation, manual_compact: bool, ) -> Result<(Conversation, ProviderUsage)> { @@ -110,7 +112,8 @@ pub async fn compact_messages( let messages_to_compact = 
messages.as_slice(); - let (summary_message, summarization_usage) = do_compact(provider, messages_to_compact).await?; + let (summary_message, summarization_usage) = + do_compact(provider, session_id, messages_to_compact).await?; // Create the final message list with updated visibility metadata: // 1. Original messages become user_visible but not agent_visible @@ -271,6 +274,7 @@ fn filter_tool_responses<'a>(messages: &[&'a Message], remove_percent: u32) -> V async fn do_compact( provider: &dyn Provider, + session_id: &str, messages: &[Message], ) -> Result<(Message, ProviderUsage), anyhow::Error> { let agent_visible_messages: Vec<&Message> = messages @@ -301,7 +305,7 @@ async fn do_compact( let summarization_request = vec![user_message]; match provider - .complete_fast(&system_prompt, &summarization_request, &[]) + .complete_fast(session_id, &system_prompt, &summarization_request, &[]) .await { Ok((mut response, mut provider_usage)) => { @@ -468,6 +472,7 @@ mod tests { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, messages: &[Message], @@ -529,9 +534,10 @@ mod tests { ]; let conversation = Conversation::new_unvalidated(basic_conversation); - let (compacted_conversation, _usage) = compact_messages(&provider, &conversation, false) - .await - .unwrap(); + let (compacted_conversation, _usage) = + compact_messages(&provider, "test-session-id", &conversation, false) + .await + .unwrap(); let agent_conversation = compacted_conversation.agent_visible_messages(); @@ -568,7 +574,7 @@ mod tests { } let conversation = Conversation::new_unvalidated(messages); - let result = compact_messages(&provider, &conversation, false).await; + let result = compact_messages(&provider, "test-session-id", &conversation, false).await; // Should succeed after progressive removal assert!( diff --git a/crates/goose/src/goose_apps/mod.rs b/crates/goose/src/goose_apps/mod.rs index b93bf27aee08..298e5f4b4dd8 100644 --- a/crates/goose/src/goose_apps/mod.rs +++ b/crates/goose/src/goose_apps/mod.rs @@ -131,14 +131,20 @@ impl McpAppCache { pub async fn fetch_mcp_apps( extension_manager: &ExtensionManager, + session_id: &str, ) -> Result, ErrorData> { let mut apps = Vec::new(); - let ui_resources = extension_manager.get_ui_resources().await?; + let ui_resources = extension_manager.get_ui_resources(session_id).await?; for (extension_name, resource) in ui_resources { match extension_manager - .read_resource(&resource.uri, &extension_name, CancellationToken::default()) + .read_resource( + session_id, + &resource.uri, + &extension_name, + CancellationToken::default(), + ) .await { Ok(read_result) => { diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index b134a1863f14..99e415e82967 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -131,6 +131,7 @@ fn extract_read_only_tools(response: &Message) -> Option> { /// Executes the read-only tools detection and returns the list of tools with read-only operations. 
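Reviewer note: as elsewhere in this PR, the LLM-backed read-only detection now receives the session id explicitly and forwards it to provider.complete. A small, self-contained illustration of that calling convention with placeholder types (not goose's real Provider or ToolRequest):

// Minimal stand-in for a provider client; the real Provider trait is richer,
// but the relevant point is the explicit session_id argument on complete.
struct FakeProvider;

impl FakeProvider {
    async fn complete(&self, session_id: &str, system: &str, user: &str) -> String {
        // A real provider would attach the session id as a request header.
        format!("[{session_id}] {system}: {user}")
    }
}

// Sketch of the calling convention: the session id arrives as an explicit
// argument and is forwarded, unchanged, to the provider call.
async fn detect_read_only(provider: &FakeProvider, session_id: &str, tools: &[&str]) -> String {
    if tools.is_empty() {
        return String::new();
    }
    provider
        .complete(session_id, "which of these tools are read-only?", &tools.join(", "))
        .await
}

#[tokio::main]
async fn main() {
    println!("{}", detect_read_only(&FakeProvider, "session-abc", &["ls", "rm"]).await);
}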
pub async fn detect_read_only_tools( provider: Arc, + session_id: &str, tool_requests: Vec<&ToolRequest>, ) -> Vec { if tool_requests.is_empty() { @@ -145,6 +146,7 @@ pub async fn detect_read_only_tools( let res = provider .complete( + session_id, &system_prompt, check_messages.messages(), std::slice::from_ref(&tool), @@ -168,6 +170,7 @@ pub struct PermissionCheckResult { } pub async fn check_tool_permissions( + session_id: &str, candidate_requests: &[ToolRequest], mode: &str, tools_with_readonly_annotation: HashSet, @@ -238,7 +241,8 @@ pub async fn check_tool_permissions( // 3. LLM detect if !llm_detect_candidates.is_empty() && mode == "smart_approve" { let detected_readonly_tools = - detect_read_only_tools(provider, llm_detect_candidates.iter().collect()).await; + detect_read_only_tools(provider, session_id, llm_detect_candidates.iter().collect()) + .await; for request in llm_detect_candidates { if let Ok(tool_call) = request.tool_call.clone() { if detected_readonly_tools.contains(&tool_call.name.to_string()) { diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 5f2976af4e05..ff5a8c6b6e5a 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -116,8 +116,8 @@ impl AnthropicProvider { headers } - async fn post(&self, payload: &Value) -> Result { - let mut request = self.api_client.request("v1/messages"); + async fn post(&self, session_id: &str, payload: &Value) -> Result { + let mut request = self.api_client.request(session_id, "v1/messages"); for (key, value) in self.get_conditional_headers() { request = request.header(key, value)?; @@ -198,6 +198,7 @@ impl Provider for AnthropicProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -206,7 +207,7 @@ impl Provider for AnthropicProvider { let payload = create_request(model_config, system, messages, tools)?; let response = self - .with_retry(|| async { self.post(&payload).await }) + .with_retry(|| async { self.post(session_id, &payload).await }) .await?; let json_response = Self::anthropic_api_call_result(response)?; @@ -227,8 +228,11 @@ impl Provider for AnthropicProvider { Ok((message, provider_usage)) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { - let response = self.api_client.api_get("v1/models").await?; + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { + let response = self.api_client.api_get(session_id, "v1/models").await?; if response.status != StatusCode::OK { return Err(map_http_error_to_provider_error( @@ -253,6 +257,7 @@ impl Provider for AnthropicProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -263,7 +268,7 @@ impl Provider for AnthropicProvider { .unwrap() .insert("stream".to_string(), Value::Bool(true)); - let mut request = self.api_client.request("v1/messages"); + let mut request = self.api_client.request(session_id, "v1/messages"); let mut log = RequestLog::start(&self.model, &payload)?; for (key, value) in self.get_conditional_headers() { diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs index 449a74e2086c..541275285440 100644 --- a/crates/goose/src/providers/api_client.rs +++ b/crates/goose/src/providers/api_client.rs @@ -196,6 +196,7 @@ pub struct ApiRequestBuilder<'a> { client: &'a ApiClient, path: &'a str, headers: HeaderMap, + session_id: &'a str, } impl 
ApiClient { @@ -272,28 +273,39 @@ impl ApiClient { Ok(self) } - pub fn request<'a>(&'a self, path: &'a str) -> ApiRequestBuilder<'a> { + pub fn request<'a>(&'a self, session_id: &'a str, path: &'a str) -> ApiRequestBuilder<'a> { ApiRequestBuilder { client: self, + session_id, path, headers: HeaderMap::new(), } } - pub async fn api_post(&self, path: &str, payload: &Value) -> Result { - self.request(path).api_post(payload).await + pub async fn api_post( + &self, + session_id: &str, + path: &str, + payload: &Value, + ) -> Result { + self.request(session_id, path).api_post(payload).await } - pub async fn response_post(&self, path: &str, payload: &Value) -> Result { - self.request(path).response_post(payload).await + pub async fn response_post( + &self, + session_id: &str, + path: &str, + payload: &Value, + ) -> Result { + self.request(session_id, path).response_post(payload).await } - pub async fn api_get(&self, path: &str) -> Result { - self.request(path).api_get().await + pub async fn api_get(&self, session_id: &str, path: &str) -> Result { + self.request(session_id, path).api_get().await } - pub async fn response_get(&self, path: &str) -> Result { - self.request(path).response_get().await + pub async fn response_get(&self, session_id: &str, path: &str) -> Result { + self.request(session_id, path).response_get().await } fn build_url(&self, path: &str) -> Result { @@ -370,9 +382,7 @@ impl<'a> ApiRequestBuilder<'a> { let mut request = request_builder(url, &self.client.client); request = request.headers(self.headers.clone()); - if let Some(session_id) = crate::session_context::current_session_id() { - request = request.header(SESSION_ID_HEADER, session_id); - } + request = request.header(SESSION_ID_HEADER, self.session_id); request = match &self.client.auth { AuthMethod::BearerToken(token) => { @@ -416,35 +426,53 @@ mod tests { ) .unwrap(); - // Execute request within session context - crate::session_context::with_session_id(Some("test-session-456".to_string()), async { - let builder = client.request("/test"); - let request = builder - .send_request(|url, client| client.get(url)) - .await - .unwrap(); - - let headers = request.build().unwrap().headers().clone(); - - assert!(headers.contains_key(SESSION_ID_HEADER)); - assert_eq!( - headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), - "test-session-456" - ); - }) - .await; + let builder = client.request("test-session_id-456", "/test"); + let request = builder + .send_request(|url, client| client.get(url)) + .await + .unwrap(); + + let headers = request.build().unwrap().headers().clone(); + + assert!(headers.contains_key(SESSION_ID_HEADER)); + assert_eq!( + headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), + "test-session_id-456" + ); + } + + #[tokio::test] + async fn test_session_id_header_with_different_id() { + let client = ApiClient::new( + "http://localhost:8080".to_string(), + AuthMethod::BearerToken("test-token".to_string()), + ) + .unwrap(); + + let builder = client.request("another-session_id-789", "/test"); + let request = builder + .send_request(|url, client| client.get(url)) + .await + .unwrap(); + + let headers = request.build().unwrap().headers().clone(); + + assert!(headers.contains_key(SESSION_ID_HEADER)); + assert_eq!( + headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), + "another-session_id-789" + ); } #[tokio::test] - async fn test_no_session_id_header_when_absent() { + async fn test_session_id_header_always_present() { let client = ApiClient::new( "http://localhost:8080".to_string(), 
AuthMethod::BearerToken("test-token".to_string()), ) .unwrap(); - // Build a request without session context - let builder = client.request("/test"); + let builder = client.request("required-session_id", "/test"); let request = builder .send_request(|url, client| client.get(url)) .await @@ -452,6 +480,6 @@ mod tests { let headers = request.build().unwrap().headers().clone(); - assert!(!headers.contains_key(SESSION_ID_HEADER)); + assert!(headers.contains_key(SESSION_ID_HEADER)); } } diff --git a/crates/goose/src/providers/auto_detect.rs b/crates/goose/src/providers/auto_detect.rs index 0513fd928be7..36236b634f31 100644 --- a/crates/goose/src/providers/auto_detect.rs +++ b/crates/goose/src/providers/auto_detect.rs @@ -1,7 +1,10 @@ use crate::model::ModelConfig; use crate::providers::retry::{retry_operation, RetryConfig}; -pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec)> { +pub async fn detect_provider_from_api_key( + session_id: &str, + api_key: &str, +) -> Option<(String, Vec)> { let provider_tests = vec![ ("anthropic", "ANTHROPIC_API_KEY"), ("openai", "OPENAI_API_KEY"), @@ -15,6 +18,7 @@ pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec< .into_iter() .map(|(provider_name, env_key)| { let api_key = api_key.to_string(); + let session_id = session_id.to_string(); tokio::spawn(async move { let original_value = std::env::var(env_key).ok(); std::env::set_var(env_key, &api_key); @@ -27,7 +31,7 @@ pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec< { Ok(provider) => { match retry_operation(&RetryConfig::default(), || async { - provider.fetch_supported_models().await + provider.fetch_supported_models(&session_id).await }) .await { diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index d26d3196bd75..1e307e5c0781 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -99,14 +99,17 @@ impl AzureProvider { }) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { // Build the path for Azure OpenAI let path = format!( "openai/deployments/{}/chat/completions?api-version={}", self.deployment_name, self.api_version ); - let response = self.api_client.response_post(&path, payload).await?; + let response = self + .api_client + .response_post(session_id, &path, payload) + .await?; handle_response_openai_compat(response).await } } @@ -144,6 +147,7 @@ impl Provider for AzureProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -160,7 +164,7 @@ impl Provider for AzureProvider { let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await?; diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 86c69828c20c..c1ee2bcb75cd 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -370,6 +370,7 @@ pub trait Provider: Send + Sync { // Providers should override this to implement their actual completion logic async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -379,18 +380,20 @@ pub trait Provider: Send + Sync { // Default implementation: use the provider's configured model async fn complete( &self, + session_id: &str, system: &str, messages: &[Message], 
tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { let model_config = self.get_model_config(); - self.complete_with_model(&model_config, system, messages, tools) + self.complete_with_model(session_id, &model_config, system, messages, tools) .await } // Check if a fast model is configured, otherwise fall back to regular model async fn complete_fast( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -399,7 +402,7 @@ pub trait Provider: Send + Sync { let fast_config = model_config.use_fast_model(); match self - .complete_with_model(&fast_config, system, messages, tools) + .complete_with_model(session_id, &fast_config, system, messages, tools) .await { Ok(result) => Ok(result), @@ -411,7 +414,7 @@ pub trait Provider: Send + Sync { e, model_config.model_name ); - self.complete_with_model(&model_config, system, messages, tools) + self.complete_with_model(session_id, &model_config, system, messages, tools) .await } else { Err(e) @@ -427,13 +430,19 @@ pub trait Provider: Send + Sync { RetryConfig::default() } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + _session_id: &str, + ) -> Result>, ProviderError> { Ok(None) } /// Fetch models filtered by canonical registry and usability - async fn fetch_recommended_models(&self) -> Result>, ProviderError> { - let all_models = match self.fetch_supported_models().await? { + async fn fetch_recommended_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { + let all_models = match self.fetch_supported_models(session_id).await? { Some(models) => models, None => return Ok(None), }; @@ -481,12 +490,16 @@ pub trait Provider: Send + Sync { false } - async fn supports_cache_control(&self) -> bool { + async fn supports_cache_control(&self, _session_id: &str) -> bool { false } /// Create embeddings if supported. Default implementation returns an error. - async fn create_embeddings(&self, _texts: Vec) -> Result>, ProviderError> { + async fn create_embeddings( + &self, + _session_id: &str, + _texts: Vec, + ) -> Result>, ProviderError> { Err(ProviderError::ExecutionError( "This provider does not support embeddings".to_string(), )) @@ -500,6 +513,7 @@ pub trait Provider: Send + Sync { async fn stream( &self, + _session_id: &str, _system: &str, _messages: &[Message], _tools: &[Tool], @@ -538,6 +552,7 @@ pub trait Provider: Send + Sync { /// Creates a prompt asking for a concise description in 4 words or less. 
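Reviewer note: complete and complete_fast remain default trait methods; they now thread the session id through to complete_with_model, and complete_fast still retries with the regular model when the fast model fails (the real code gates that retry on the kind of provider error). A compact sketch of the default-method pattern with simplified, hypothetical types:

// Simplified model/provider stand-ins; the real trait is async and gates the
// fallback on the provider error rather than retrying unconditionally.
#[derive(Clone)]
struct ModelConfig {
    name: String,
    fast_name: Option<String>,
}

impl ModelConfig {
    fn use_fast_model(&self) -> ModelConfig {
        ModelConfig {
            name: self.fast_name.clone().unwrap_or_else(|| self.name.clone()),
            fast_name: None,
        }
    }
}

trait MiniProvider {
    fn get_model_config(&self) -> ModelConfig;

    fn complete_with_model(
        &self,
        session_id: &str,
        model: &ModelConfig,
        prompt: &str,
    ) -> Result<String, String>;

    // Default method mirroring the shape of complete_fast: try the fast model
    // first, then fall back to the regular model, threading the same session id.
    fn complete_fast(&self, session_id: &str, prompt: &str) -> Result<String, String> {
        let model = self.get_model_config();
        let fast = model.use_fast_model();
        match self.complete_with_model(session_id, &fast, prompt) {
            Ok(out) => Ok(out),
            Err(err) => {
                eprintln!("fast model '{}' failed ({err}); retrying with '{}'", fast.name, model.name);
                self.complete_with_model(session_id, &model, prompt)
            }
        }
    }
}

struct DemoProvider;

impl MiniProvider for DemoProvider {
    fn get_model_config(&self) -> ModelConfig {
        ModelConfig { name: "big-model".into(), fast_name: Some("small-model".into()) }
    }
    fn complete_with_model(
        &self,
        session_id: &str,
        model: &ModelConfig,
        prompt: &str,
    ) -> Result<String, String> {
        if model.name == "small-model" {
            Err("fast model unavailable".into())
        } else {
            Ok(format!("[{session_id}] {} answered: {prompt}", model.name))
        }
    }
}

fn main() {
    println!("{:?}", DemoProvider.complete_fast("session-1", "summarize this"));
}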
async fn generate_session_name( &self, + session_id: &str, messages: &Conversation, ) -> Result { let context = self.get_initial_user_messages(messages); @@ -545,6 +560,7 @@ pub trait Provider: Send + Sync { let message = Message::user().with_text(&prompt); let result = self .complete_fast( + session_id, "Reply with only a description in four words or less", &[message], &[], diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index d09ec44291ef..174a534edbbb 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -227,6 +227,7 @@ impl Provider for BedrockProvider { )] async fn complete_with_model( &self, + _session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/canonical/build_canonical_models.rs b/crates/goose/src/providers/canonical/build_canonical_models.rs index c3a6cb1db783..e8f0dcaf16e6 100644 --- a/crates/goose/src/providers/canonical/build_canonical_models.rs +++ b/crates/goose/src/providers/canonical/build_canonical_models.rs @@ -550,7 +550,9 @@ async fn check_provider( } }; - let fetched_models = match provider.fetch_supported_models().await { + // Provider probe runs outside any user session; use an ephemeral id. + let session_id = uuid::Uuid::new_v4().to_string(); + let fetched_models = match provider.fetch_supported_models(&session_id).await { Ok(Some(models)) => { println!(" ✓ Fetched {} models", models.len()); models diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 6c148aaa1fa7..070fee48bd7c 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -417,6 +417,7 @@ impl Provider for ClaudeCodeProvider { )] async fn complete_with_model( &self, + _session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index abfa54c299d8..f94c1484e18d 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -507,6 +507,7 @@ impl Provider for CodexProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index ffe7c1058de6..3f7af7002280 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -352,6 +352,7 @@ impl Provider for CursorAgentProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 6baa686d29c8..30cd7c0a987a 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -145,24 +145,27 @@ impl DatabricksProvider { }; // Check if the default fast model exists in the workspace - let model_with_fast = if let Ok(Some(models)) = provider.fetch_supported_models().await { - if models.contains(&DATABRICKS_DEFAULT_FAST_MODEL.to_string()) { - tracing::debug!( - "Found {} in Databricks workspace, setting as fast model", - DATABRICKS_DEFAULT_FAST_MODEL - ); - model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string()) + // Generate UUID for this initialization request since no user session exists yet + let session_id = uuid::Uuid::new_v4().to_string(); + let 
model_with_fast = + if let Ok(Some(models)) = provider.fetch_supported_models(&session_id).await { + if models.contains(&DATABRICKS_DEFAULT_FAST_MODEL.to_string()) { + tracing::debug!( + "Found {} in Databricks workspace, setting as fast model", + DATABRICKS_DEFAULT_FAST_MODEL + ); + model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string()) + } else { + tracing::debug!( + "{} not found in Databricks workspace, not setting fast model", + DATABRICKS_DEFAULT_FAST_MODEL + ); + model + } } else { - tracing::debug!( - "{} not found in Databricks workspace, not setting fast model", - DATABRICKS_DEFAULT_FAST_MODEL - ); + tracing::debug!("Could not fetch Databricks models, not setting fast model"); model - } - } else { - tracing::debug!("Could not fetch Databricks models, not setting fast model"); - model - }; + }; provider.model = model_with_fast; Ok(provider) @@ -226,12 +229,20 @@ impl DatabricksProvider { } } - async fn post(&self, payload: Value, model_name: Option<&str>) -> Result { + async fn post( + &self, + session_id: &str, + payload: Value, + model_name: Option<&str>, + ) -> Result { let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); let model_to_use = model_name.unwrap_or(&self.model.model_name); let path = self.get_endpoint_path(model_to_use, is_embedding); - let response = self.api_client.response_post(&path, &payload).await?; + let response = self + .api_client + .response_post(session_id, &path, &payload) + .await?; handle_response_openai_compat(response).await } } @@ -271,6 +282,7 @@ impl Provider for DatabricksProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -286,7 +298,7 @@ impl Provider for DatabricksProvider { let mut log = RequestLog::start(&self.model, &payload)?; let response = self - .with_retry(|| self.post(payload.clone(), Some(&model_config.model_name))) + .with_retry(|| self.post(session_id, payload.clone(), Some(&model_config.model_name))) .await?; let message = response_to_message(&response)?; @@ -302,6 +314,7 @@ impl Provider for DatabricksProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -324,7 +337,10 @@ impl Provider for DatabricksProvider { let mut log = RequestLog::start(&self.model, &payload)?; let response = self .with_retry(|| async { - let resp = self.api_client.response_post(&path, &payload).await?; + let resp = self + .api_client + .response_post(session_id, &path, &payload) + .await?; if !resp.status().is_success() { let status = resp.status(); let error_text = resp.text().await.unwrap_or_default(); @@ -351,16 +367,23 @@ impl Provider for DatabricksProvider { true } - async fn create_embeddings(&self, texts: Vec) -> Result>, ProviderError> { - EmbeddingCapable::create_embeddings(self, texts) + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>, ProviderError> { + EmbeddingCapable::create_embeddings(self, session_id, texts) .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { let response = match self .api_client - .response_get("api/2.0/serving-endpoints") + .response_get(session_id, "api/2.0/serving-endpoints") .await { Ok(resp) => resp, @@ -422,7 +445,11 @@ impl Provider for DatabricksProvider { #[async_trait] impl EmbeddingCapable for DatabricksProvider { - 
async fn create_embeddings(&self, texts: Vec) -> Result>> { + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>> { if texts.is_empty() { return Ok(vec![]); } @@ -431,7 +458,9 @@ impl EmbeddingCapable for DatabricksProvider { "input": texts, }); - let response = self.with_retry(|| self.post(request.clone(), None)).await?; + let response = self + .with_retry(|| self.post(session_id, request.clone(), None)) + .await?; let embeddings = response["data"] .as_array() diff --git a/crates/goose/src/providers/embedding.rs b/crates/goose/src/providers/embedding.rs index 469d22aeb57e..3f956b11c020 100644 --- a/crates/goose/src/providers/embedding.rs +++ b/crates/goose/src/providers/embedding.rs @@ -20,5 +20,9 @@ pub struct EmbeddingData { #[async_trait] pub trait EmbeddingCapable { - async fn create_embeddings(&self, texts: Vec) -> Result>>; + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>>; } diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 3ae7b1e2d8c7..bba880fbf69f 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -585,6 +585,7 @@ impl Provider for GcpVertexAIProvider { )] async fn complete_with_model( &self, + _session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -618,6 +619,7 @@ impl Provider for GcpVertexAIProvider { async fn stream( &self, + _session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -662,7 +664,10 @@ impl Provider for GcpVertexAIProvider { })) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + _session_id: &str, + ) -> Result>, ProviderError> { let models: Vec = KNOWN_MODELS.iter().map(|s| s.to_string()).collect(); let filtered = self.filter_by_org_policy(models).await; Ok(Some(filtered)) diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 40e111a31f03..c13302ca28e9 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -264,6 +264,7 @@ impl Provider for GeminiCliProvider { )] async fn complete_with_model( &self, + session_id: &str, _model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 0f745e7244e8..8f90d0458bb2 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -169,7 +169,7 @@ impl GithubCopilotProvider { }) } - async fn post(&self, payload: &mut Value) -> Result { + async fn post(&self, session_id: &str, payload: &mut Value) -> Result { let (endpoint, token) = self.get_api_info().await?; let auth = AuthMethod::BearerToken(token); let mut headers = self.get_github_headers(); @@ -179,7 +179,7 @@ impl GithubCopilotProvider { let api_client = ApiClient::new(endpoint.clone(), auth)?.with_headers(headers)?; api_client - .response_post("chat/completions", payload) + .response_post(session_id, "chat/completions", payload) .await .map_err(|e| e.into()) } @@ -411,6 +411,7 @@ impl Provider for GithubCopilotProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -430,7 +431,7 @@ impl Provider for GithubCopilotProvider { let response = self .with_retry(|| async { let mut payload_clone = payload.clone(); - self.post(&mut 
payload_clone).await + self.post(session_id, &mut payload_clone).await }) .await?; let response = handle_response_openai_compat(response).await?; @@ -450,6 +451,7 @@ impl Provider for GithubCopilotProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -467,7 +469,7 @@ impl Provider for GithubCopilotProvider { let response = self .with_retry(|| async { let mut payload_clone = payload.clone(); - let resp = self.post(&mut payload_clone).await?; + let resp = self.post(session_id, &mut payload_clone).await?; handle_status_openai_compat(resp).await }) .await @@ -478,7 +480,10 @@ impl Provider for GithubCopilotProvider { stream_openai_compat(response, log) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + _session_id: &str, + ) -> Result>, ProviderError> { let (endpoint, token) = self.get_api_info().await?; let url = format!("{}/models", endpoint); diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index eb907313fef6..b091cf2dd2ab 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -91,19 +91,31 @@ impl GoogleProvider { }) } - async fn post(&self, model_name: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: &str, + model_name: &str, + payload: &Value, + ) -> Result { let path = format!("v1beta/models/{}:generateContent", model_name); - let response = self.api_client.response_post(&path, payload).await?; + let response = self + .api_client + .response_post(session_id, &path, payload) + .await?; handle_response_google_compat(response).await } async fn post_stream( &self, + session_id: &str, model_name: &str, payload: &Value, ) -> Result { let path = format!("v1beta/models/{}:streamGenerateContent?alt=sse", model_name); - let response = self.api_client.response_post(&path, payload).await?; + let response = self + .api_client + .response_post(session_id, &path, payload) + .await?; handle_status_openai_compat(response).await } } @@ -139,6 +151,7 @@ impl Provider for GoogleProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -148,7 +161,10 @@ impl Provider for GoogleProvider { let mut log = RequestLog::start(model_config, &payload)?; let response = self - .with_retry(|| async { self.post(&model_config.model_name, &payload).await }) + .with_retry(|| async { + self.post(session_id, &model_config.model_name, &payload) + .await + }) .await?; let message = response_to_message(unescape_json_values(&response))?; @@ -162,8 +178,14 @@ impl Provider for GoogleProvider { Ok((message, provider_usage)) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { - let response = self.api_client.response_get("v1beta/models").await?; + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { + let response = self + .api_client + .response_get(session_id, "v1beta/models") + .await?; let json: serde_json::Value = response.json().await?; let arr = match json.get("models").and_then(|v| v.as_array()) { Some(arr) => arr, @@ -184,6 +206,7 @@ impl Provider for GoogleProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -192,7 +215,10 @@ impl Provider for GoogleProvider { let mut log = RequestLog::start(&self.model, &payload)?; let response = self - .with_retry(|| async { self.post_stream(&self.model.model_name, 
&payload).await }) + .with_retry(|| async { + self.post_stream(session_id, &self.model.model_name, &payload) + .await + }) .await .inspect_err(|e| { let _ = log.error(e); diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 9cc44347c70b..15b98ef48922 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -342,6 +342,7 @@ impl Provider for LeadWorkerProvider { async fn complete_with_model( &self, + session_id: &str, _model_config: &ModelConfig, system: &str, messages: &[Message], @@ -392,7 +393,7 @@ impl Provider for LeadWorkerProvider { } // Make the completion request - let result = provider.complete(system, messages, tools).await; + let result = provider.complete(session_id, system, messages, tools).await; // For technical failures, try with default model (lead provider) instead let final_result = match &result { @@ -400,7 +401,10 @@ impl Provider for LeadWorkerProvider { tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type); // Try with lead provider as the default/fallback for technical failures - let default_result = self.lead_provider.complete(system, messages, tools).await; + let default_result = self + .lead_provider + .complete(session_id, system, messages, tools) + .await; match &default_result { Ok(_) => { @@ -424,10 +428,19 @@ impl Provider for LeadWorkerProvider { final_result } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { // Combine models from both providers - let lead_models = self.lead_provider.fetch_supported_models().await?; - let worker_models = self.worker_provider.fetch_supported_models().await?; + let lead_models = self + .lead_provider + .fetch_supported_models(session_id) + .await?; + let worker_models = self + .worker_provider + .fetch_supported_models(session_id) + .await?; match (lead_models, worker_models) { (Some(lead), Some(worker)) => { @@ -447,12 +460,20 @@ impl Provider for LeadWorkerProvider { self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings() } - async fn create_embeddings(&self, texts: Vec) -> Result>, ProviderError> { + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>, ProviderError> { // Use the lead provider for embeddings if it supports them, otherwise use worker if self.lead_provider.supports_embeddings() { - self.lead_provider.create_embeddings(texts).await + self.lead_provider + .create_embeddings(session_id, texts) + .await } else if self.worker_provider.supports_embeddings() { - self.worker_provider.create_embeddings(texts).await + self.worker_provider + .create_embeddings(session_id, texts) + .await } else { Err(ProviderError::ExecutionError( "Neither lead nor worker provider supports embeddings".to_string(), @@ -496,6 +517,7 @@ mod tests { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, _messages: &[Message], @@ -534,7 +556,10 @@ mod tests { // First three turns should use lead provider for i in 0..3 { - let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + let (_message, usage) = provider + .complete("test-session-id", "system", &[], &[]) + .await + .unwrap(); assert_eq!(usage.model, "lead"); assert_eq!(provider.get_turn_count().await, i + 1); assert!(!provider.is_in_fallback_mode().await); @@ -542,7 +567,10 @@ 
mod tests { // Subsequent turns should use worker provider for i in 3..6 { - let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + let (_message, usage) = provider + .complete("test-session-id", "system", &[], &[]) + .await + .unwrap(); assert_eq!(usage.model, "worker"); assert_eq!(provider.get_turn_count().await, i + 1); assert!(!provider.is_in_fallback_mode().await); @@ -554,7 +582,10 @@ mod tests { assert_eq!(provider.get_failure_count().await, 0); assert!(!provider.is_in_fallback_mode().await); - let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + let (_message, usage) = provider + .complete("test-session-id", "system", &[], &[]) + .await + .unwrap(); assert_eq!(usage.model, "lead"); } @@ -576,21 +607,27 @@ mod tests { // First two turns use lead (should succeed) for _i in 0..2 { - let result = provider.complete("system", &[], &[]).await; + let result = provider + .complete("test-session-id", "system", &[], &[]) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); assert!(!provider.is_in_fallback_mode().await); } // Next turn uses worker (will fail, but should retry with lead and succeed) - let result = provider.complete("system", &[], &[]).await; + let result = provider + .complete("test-session-id", "system", &[], &[]) + .await; assert!(result.is_ok()); // Should succeed because lead provider is used as fallback assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode // Another turn - should still try worker first, then retry with lead - let result = provider.complete("system", &[], &[]).await; + let result = provider + .complete("test-session-id", "system", &[], &[]) + .await; assert!(result.is_ok()); // Should succeed because lead provider is used as fallback assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking @@ -627,13 +664,17 @@ mod tests { } // Should use lead provider in fallback mode - let result = provider.complete("system", &[], &[]).await; + let result = provider + .complete("test-session-id", "system", &[], &[]) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); assert!(provider.is_in_fallback_mode().await); // One more fallback turn - let result = provider.complete("system", &[], &[]).await; + let result = provider + .complete("test-session-id", "system", &[], &[]) + .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode @@ -662,6 +703,7 @@ mod tests { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, _messages: &[Message], diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index e62fea728960..0648ee9ce216 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -73,8 +73,8 @@ impl LiteLLMProvider { }) } - async fn fetch_models(&self) -> Result, ProviderError> { - let response = self.api_client.response_get("model/info").await?; + async fn fetch_models(&self, session: &str) -> Result, ProviderError> { + let response = self.api_client.response_get(session, "model/info").await?; if !response.status().is_success() { return 
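
The LeadWorkerProvider changes above are pure forwarding: whichever inner provider handles a turn (lead, worker, or the lead reused as a technical-failure fallback), it receives the caller's session_id unchanged, and the tests now pass a literal "test-session-id" as the first argument. A simplified, synchronous sketch of that delegation, with toy types standing in for the goose providers:

trait Provider {
    fn complete(&self, session_id: &str, system: &str) -> (String, String); // (reply, model name)
}

struct Named(&'static str);

impl Provider for Named {
    fn complete(&self, session_id: &str, _system: &str) -> (String, String) {
        (format!("reply for {}", session_id), self.0.to_string())
    }
}

struct LeadWorker {
    lead: Named,
    worker: Named,
    lead_turns: usize,
}

impl LeadWorker {
    fn complete(&self, session_id: &str, turn: usize, system: &str) -> (String, String) {
        // Both branches forward the same id: it never changes mid-session.
        if turn < self.lead_turns {
            self.lead.complete(session_id, system)
        } else {
            self.worker.complete(session_id, system)
        }
    }
}

fn main() {
    let lw = LeadWorker { lead: Named("lead"), worker: Named("worker"), lead_turns: 3 };
    for turn in 0..5 {
        let (_, model) = lw.complete("test-session-id", turn, "system");
        println!("turn {} used {}", turn, model);
    }
}
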
Err(ProviderError::RequestFailed(format!( @@ -112,10 +112,10 @@ impl LiteLLMProvider { Ok(models) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post(&self.base_path, payload) + .response_post(session_id, &self.base_path, payload) .await?; handle_response_openai_compat(response).await } @@ -168,6 +168,7 @@ impl Provider for LiteLLMProvider { #[tracing::instrument(skip_all, name = "provider_complete")] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -182,14 +183,14 @@ impl Provider for LiteLLMProvider { false, )?; - if self.supports_cache_control().await { + if self.supports_cache_control(session_id).await { payload = update_request_for_cache_control(&payload); } let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await?; @@ -205,8 +206,8 @@ impl Provider for LiteLLMProvider { true } - async fn supports_cache_control(&self) -> bool { - if let Ok(models) = self.fetch_models().await { + async fn supports_cache_control(&self, session_id: &str) -> bool { + if let Ok(models) = self.fetch_models(session_id).await { if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) { return model_info.supports_cache_control.unwrap_or(false); } @@ -215,8 +216,11 @@ impl Provider for LiteLLMProvider { self.model.model_name.to_lowercase().contains("claude") } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { - match self.fetch_models().await { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { + match self.fetch_models(session_id).await { Ok(models) => { let model_names: Vec = models.into_iter().map(|m| m.name).collect(); Ok(Some(model_names)) @@ -231,7 +235,11 @@ impl Provider for LiteLLMProvider { #[async_trait] impl EmbeddingCapable for LiteLLMProvider { - async fn create_embeddings(&self, texts: Vec) -> Result>, anyhow::Error> { + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>, anyhow::Error> { let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL") .unwrap_or_else(|_| "text-embedding-3-small".to_string()); @@ -243,7 +251,7 @@ impl EmbeddingCapable for LiteLLMProvider { let response = self .api_client - .response_post("v1/embeddings", &payload) + .response_post(session_id, "v1/embeddings", &payload) .await?; let response_text = response.text().await?; let response_json: Value = serde_json::from_str(&response_text)?; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index be1d927ff9d8..3d3746bedfe9 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -119,10 +119,10 @@ impl OllamaProvider { }) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post("v1/chat/completions", payload) + .response_post(session_id, "v1/chat/completions", payload) .await?; handle_response_openai_compat(response).await } @@ -173,6 +173,7 @@ impl Provider for OllamaProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -199,7 +200,7 @@ impl Provider for OllamaProvider { let response = self .with_retry(|| async { let 
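
In the LiteLLM hunks, supports_cache_control gains a session_id because answering it may require a network call to the model-info endpoint, and every outbound request now carries the session id. A reduced, synchronous sketch of that capability check, with illustrative model data standing in for the real endpoint:

struct ModelInfo {
    name: String,
    supports_cache_control: Option<bool>,
}

struct LiteLlmLike {
    model_name: String,
}

impl LiteLlmLike {
    // Stand-in for the async fetch_models call; the real one hits model/info
    // through the ApiClient with the session header attached.
    fn fetch_models(&self, _session_id: &str) -> Result<Vec<ModelInfo>, String> {
        Ok(vec![ModelInfo {
            name: "claude-3-5-sonnet".into(),
            supports_cache_control: Some(true),
        }])
    }

    fn supports_cache_control(&self, session_id: &str) -> bool {
        if let Ok(models) = self.fetch_models(session_id) {
            if let Some(info) = models.iter().find(|m| m.name == self.model_name) {
                return info.supports_cache_control.unwrap_or(false);
            }
        }
        // Fallback heuristic mirrored from the diff: assume Claude models support it.
        self.model_name.to_lowercase().contains("claude")
    }
}

fn main() {
    let p = LiteLlmLike { model_name: "claude-3-5-sonnet".into() };
    println!("cache control: {}", p.supports_cache_control("test-session-id"));
}
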
payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await .inspect_err(|e| { @@ -219,12 +220,14 @@ impl Provider for OllamaProvider { async fn generate_session_name( &self, + session_id: &str, messages: &Conversation, ) -> Result { let context = self.get_initial_user_messages(messages); let message = Message::user().with_text(self.create_session_name_prompt(&context)); let result = self .complete( + session_id, "You are a title generator. Output only the requested title of 4 words or less, with no additional text, reasoning, or explanations.", &[message], &[], @@ -243,6 +246,7 @@ impl Provider for OllamaProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -269,7 +273,7 @@ impl Provider for OllamaProvider { .with_retry(|| async { let resp = self .api_client - .response_post("v1/chat/completions", &payload) + .response_post(session_id, "v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) @@ -280,10 +284,13 @@ impl Provider for OllamaProvider { stream_openai_compat(response, log) } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { let response = self .api_client - .response_get("api/tags") + .response_get(session_id, "api/tags") .await .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 61e6afee25fe..20a069fb9c2f 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -192,18 +192,22 @@ impl OpenAiProvider { model_name.starts_with("gpt-5-codex") || model_name.starts_with("gpt-5.1-codex") } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post(&self.base_path, payload) + .response_post(session_id, &self.base_path, payload) .await?; handle_response_openai_compat(response).await } - async fn post_responses(&self, payload: &Value) -> Result { + async fn post_responses( + &self, + session_id: &str, + payload: &Value, + ) -> Result { let response = self .api_client - .response_post("v1/responses", payload) + .response_post(session_id, "v1/responses", payload) .await?; handle_response_openai_compat(response).await } @@ -249,6 +253,7 @@ impl Provider for OpenAiProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -261,7 +266,7 @@ impl Provider for OpenAiProvider { let json_response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post_responses(&payload_clone).await + self.post_responses(session_id, &payload_clone).await }) .await .inspect_err(|e| { @@ -296,7 +301,7 @@ impl Provider for OpenAiProvider { let json_response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await .inspect_err(|e| { @@ -318,9 +323,15 @@ impl Provider for OpenAiProvider { } } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { let models_path = self.base_path.replace("v1/chat/completions", "v1/models"); - let response = 
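
The Ollama generate_session_name change follows the same rule: helpers built on complete() forward the caller's session_id rather than issuing an unattributed request. A small sketch of that shape, with a stand-in provider trait and a trivial default implementation:

trait Provider {
    fn complete(&self, session_id: &str, system: &str, user: &str) -> String;

    // Default implementation built on complete(); the same session id is used throughout.
    fn generate_session_name(&self, session_id: &str, first_user_message: &str) -> String {
        self.complete(
            session_id,
            "You are a title generator. Output only a title of 4 words or less.",
            first_user_message,
        )
    }
}

struct Echo;

impl Provider for Echo {
    fn complete(&self, session_id: &str, _system: &str, user: &str) -> String {
        format!("[{}] title for: {}", session_id, user)
    }
}

fn main() {
    println!("{}", Echo.generate_session_name("test-session-id", "help me write a parser"));
}
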
self.api_client.response_get(&models_path).await?; + let response = self + .api_client + .response_get(session_id, &models_path) + .await?; let json = handle_response_openai_compat(response).await?; if let Some(err_obj) = json.get("error") { let msg = err_obj @@ -345,8 +356,12 @@ impl Provider for OpenAiProvider { true } - async fn create_embeddings(&self, texts: Vec) -> Result>, ProviderError> { - EmbeddingCapable::create_embeddings(self, texts) + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>, ProviderError> { + EmbeddingCapable::create_embeddings(self, session_id, texts) .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) } @@ -357,6 +372,7 @@ impl Provider for OpenAiProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -372,7 +388,7 @@ impl Provider for OpenAiProvider { let payload_clone = payload.clone(); let resp = self .api_client - .response_post("v1/responses", &payload_clone) + .response_post(session_id, "v1/responses", &payload_clone) .await?; handle_status_openai_compat(resp).await }) @@ -410,7 +426,7 @@ impl Provider for OpenAiProvider { .with_retry(|| async { let resp = self .api_client - .response_post(&self.base_path, &payload) + .response_post(session_id, &self.base_path, &payload) .await?; handle_status_openai_compat(resp).await }) @@ -437,7 +453,11 @@ fn parse_custom_headers(s: String) -> HashMap { #[async_trait] impl EmbeddingCapable for OpenAiProvider { - async fn create_embeddings(&self, texts: Vec) -> Result>> { + async fn create_embeddings( + &self, + session_id: &str, + texts: Vec, + ) -> Result>> { if texts.is_empty() { return Ok(vec![]); } @@ -459,7 +479,7 @@ impl EmbeddingCapable for OpenAiProvider { let request_value = serde_json::to_value(request_clone) .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; self.api_client - .api_post("v1/embeddings", &request_value) + .api_post(session_id, "v1/embeddings", &request_value) .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) }) diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 967659b36f9e..344206e315d7 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -69,10 +69,10 @@ impl OpenRouterProvider { }) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post("api/v1/chat/completions", payload) + .response_post(session_id, "api/v1/chat/completions", payload) .await?; let response_body = handle_response_openai_compat(response) @@ -189,6 +189,7 @@ fn is_gemini_model(model_name: &str) -> bool { async fn create_request_based_on_model( provider: &OpenRouterProvider, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -202,7 +203,7 @@ async fn create_request_based_on_model( false, )?; - if provider.supports_cache_control().await { + if provider.supports_cache_control(session_id).await { payload = update_request_for_anthropic(&payload); } @@ -253,18 +254,20 @@ impl Provider for OpenRouterProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request_based_on_model(self, system, messages, tools).await?; + let payload = + create_request_based_on_model(self, session_id, system, messages, 
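
For OpenAI, both the Provider-level create_embeddings and the EmbeddingCapable implementation take session_id, and the former simply delegates to the latter. A compact sketch of that delegation, with toy embedding logic in place of the real v1/embeddings call:

trait EmbeddingCapable {
    fn create_embeddings(&self, session_id: &str, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String>;
}

struct OpenAiLike;

impl EmbeddingCapable for OpenAiLike {
    fn create_embeddings(&self, session_id: &str, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        // The real call POSTs to v1/embeddings with the session header attached;
        // here each text just becomes a one-dimensional placeholder embedding.
        let _ = session_id;
        Ok(texts.iter().map(|t| vec![t.len() as f32]).collect())
    }
}

impl OpenAiLike {
    // Provider-facing wrapper, mirroring the delegation in the diff.
    fn provider_create_embeddings(&self, session_id: &str, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        EmbeddingCapable::create_embeddings(self, session_id, texts)
    }
}

fn main() {
    let p = OpenAiLike;
    let out = p.provider_create_embeddings("test-session-id", vec!["hello".into()]).unwrap();
    println!("{:?}", out);
}
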
tools).await?; let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await?; @@ -284,10 +287,17 @@ impl Provider for OpenRouterProvider { } /// Fetch supported models from OpenRouter API (only models with tool support) - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { // Handle request failures gracefully // If the request fails, fall back to manual entry - let response = match self.api_client.response_get("api/v1/models").await { + let response = match self + .api_client + .response_get(session_id, "api/v1/models") + .await + { Ok(response) => response, Err(e) => { tracing::warn!("Failed to fetch models from OpenRouter API: {}, falling back to manual model entry", e); @@ -360,7 +370,7 @@ impl Provider for OpenRouterProvider { Ok(Some(models)) } - async fn supports_cache_control(&self) -> bool { + async fn supports_cache_control(&self, _session_id: &str) -> bool { self.model .model_name .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) @@ -372,6 +382,7 @@ impl Provider for OpenRouterProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -385,7 +396,7 @@ impl Provider for OpenRouterProvider { true, )?; - if self.supports_cache_control().await { + if self.supports_cache_control(session_id).await { payload = update_request_for_anthropic(&payload); } @@ -403,7 +414,7 @@ impl Provider for OpenRouterProvider { .with_retry(|| async { let resp = self .api_client - .response_post("api/v1/chat/completions", &payload) + .response_post(session_id, "api/v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) diff --git a/crates/goose/src/providers/provider_test.rs b/crates/goose/src/providers/provider_test.rs index 386ddc491875..990e562e9f9b 100644 --- a/crates/goose/src/providers/provider_test.rs +++ b/crates/goose/src/providers/provider_test.rs @@ -27,9 +27,10 @@ pub async fn test_provider_configuration( let _result = provider .complete( + "test-session-id", "You are an AI agent called goose. 
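
The OpenRouter supports_cache_control change shows the other side of a uniform trait signature: the parameter exists everywhere, even where a provider decides the answer purely from the model name and ignores it (hence the `_session_id` binding). A brief sketch of that trade-off, with an illustrative prefix constant rather than the real goose one:

trait Provider {
    fn supports_cache_control(&self, session_id: &str) -> bool;
}

const ANTHROPIC_PREFIX: &str = "anthropic/"; // illustrative stand-in for the diff's model-prefix constant

struct OpenRouterLike {
    model_name: String,
}

impl Provider for OpenRouterLike {
    fn supports_cache_control(&self, _session_id: &str) -> bool {
        // No network call needed, so the session id is unused here; keeping the
        // uniform signature keeps every call site identical across providers.
        self.model_name.starts_with(ANTHROPIC_PREFIX)
    }
}

fn main() {
    let p = OpenRouterLike { model_name: "anthropic/claude-3.5-sonnet".into() };
    println!("{}", p.supports_cache_control("test-session-id"));
}
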
You use tools of connected extensions to solve problems.", &messages, - &tools.into_iter().collect::>() + &tools.into_iter().collect::>(), ) .await?; diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 74dbff3fa683..6cbabdbbdd63 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -289,6 +289,7 @@ impl Provider for SageMakerTgiProvider { )] async fn complete_with_model( &self, + _session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 7176d59d2055..0e0e452b0ca8 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -107,10 +107,10 @@ impl SnowflakeProvider { }) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post("api/v2/cortex/inference:complete", payload) + .response_post(session_id, "api/v2/cortex/inference:complete", payload) .await?; let status = response.status(); @@ -319,6 +319,7 @@ impl Provider for SnowflakeProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -331,7 +332,7 @@ impl Provider for SnowflakeProvider { let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await?; diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index c9e455bd69d3..8415cc4d0a71 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -121,6 +121,7 @@ impl Provider for TestProvider { async fn complete_with_model( &self, + session_id: &str, _model_config: &ModelConfig, system: &str, messages: &[Message], @@ -129,7 +130,7 @@ impl Provider for TestProvider { let hash = Self::hash_input(messages); if let Some(inner) = &self.inner { - let (message, usage) = inner.complete(system, messages, tools).await?; + let (message, usage) = inner.complete(session_id, system, messages, tools).await?; let record = TestRecord { input: TestInput { @@ -202,6 +203,7 @@ mod tests { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, _messages: &[Message], @@ -244,7 +246,9 @@ mod tests { { let test_provider = TestProvider::new_recording(mock, &temp_file); - let result = test_provider.complete("You are helpful", &[], &[]).await; + let result = test_provider + .complete("test-session-id", "You are helpful", &[], &[]) + .await; assert!(result.is_ok()); let (message, _) = result.unwrap(); @@ -260,7 +264,9 @@ mod tests { { let replay_provider = TestProvider::new_replaying(&temp_file).unwrap(); - let result = replay_provider.complete("You are helpful", &[], &[]).await; + let result = replay_provider + .complete("test-session-id", "You are helpful", &[], &[]) + .await; assert!(result.is_ok()); let (message, _) = result.unwrap(); @@ -284,7 +290,7 @@ mod tests { let replay_provider = TestProvider::new_replaying(&temp_file).unwrap(); let result = replay_provider - .complete("Different system prompt", &[], &[]) + .complete("test-session-id", "Different system prompt", &[], &[]) .await; assert!(result.is_err()); diff --git a/crates/goose/src/providers/tetrate.rs 
b/crates/goose/src/providers/tetrate.rs index ab3e7af0caf9..73d753d5e72b 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -63,10 +63,10 @@ impl TetrateProvider { }) } - async fn post(&self, payload: &Value) -> Result { + async fn post(&self, session_id: &str, payload: &Value) -> Result { let response = self .api_client - .response_post("v1/chat/completions", payload) + .response_post(session_id, "v1/chat/completions", payload) .await?; // Handle Google-compatible model responses differently @@ -158,6 +158,7 @@ impl Provider for TetrateProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -177,7 +178,7 @@ impl Provider for TetrateProvider { let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post(&payload_clone).await + self.post(session_id, &payload_clone).await }) .await?; @@ -194,6 +195,7 @@ impl Provider for TetrateProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -213,7 +215,7 @@ impl Provider for TetrateProvider { .with_retry(|| async { let resp = self .api_client - .response_post("v1/chat/completions", &payload) + .response_post(session_id, "v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) @@ -226,9 +228,12 @@ impl Provider for TetrateProvider { } /// Fetch supported models from Tetrate Agent Router Service API (only models with tool support) - async fn fetch_supported_models(&self) -> Result>, ProviderError> { + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { // Use the existing api_client which already has authentication configured - let response = match self.api_client.response_get("v1/models").await { + let response = match self.api_client.response_get(session_id, "v1/models").await { Ok(response) => response, Err(e) => { tracing::warn!("Failed to fetch models from Tetrate Agent Router Service API: {}, falling back to manual model entry", e); diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 9f95c298343c..1ca28cba6090 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -113,8 +113,16 @@ impl VeniceProvider { Ok(instance) } - async fn post(&self, path: &str, payload: &Value) -> Result { - let response = self.api_client.response_post(path, payload).await?; + async fn post( + &self, + session_id: &str, + path: &str, + payload: &Value, + ) -> Result { + let response = self + .api_client + .response_post(session_id, path, payload) + .await?; let status = response.status(); tracing::debug!("Venice response status: {}", status); @@ -221,8 +229,14 @@ impl Provider for VeniceProvider { self.model.clone() } - async fn fetch_supported_models(&self) -> Result>, ProviderError> { - let response = self.api_client.response_get(&self.models_path).await?; + async fn fetch_supported_models( + &self, + session_id: &str, + ) -> Result>, ProviderError> { + let response = self + .api_client + .response_get(session_id, &self.models_path) + .await?; let json: serde_json::Value = response.json().await?; let mut models = json["data"] @@ -251,6 +265,7 @@ impl Provider for VeniceProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -431,7 +446,7 @@ impl Provider for VeniceProvider { // Send request with retry let response = self - 
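
The Tetrate and OpenRouter fetch_supported_models hunks keep their graceful-degradation behavior while adding session_id to the listing request: a failed models call logs a warning and returns Ok(None) so the caller can fall back to manual model entry. A sketch of that pattern with a stand-in client that always fails:

struct ModelLister;

impl ModelLister {
    // Stand-in for api_client.response_get(session_id, "v1/models").
    fn response_get(&self, _session_id: &str, _path: &str) -> Result<Vec<String>, String> {
        Err("network unreachable".to_string())
    }

    fn fetch_supported_models(&self, session_id: &str) -> Result<Option<Vec<String>>, String> {
        match self.response_get(session_id, "v1/models") {
            Ok(models) => Ok(Some(models)),
            Err(e) => {
                eprintln!("Failed to fetch models: {}, falling back to manual model entry", e);
                Ok(None)
            }
        }
    }
}

fn main() {
    let lister = ModelLister;
    assert_eq!(lister.fetch_supported_models("test-session-id"), Ok(None));
    println!("fallback to manual entry");
}
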
.with_retry(|| self.post(&self.base_path, &payload)) + .with_retry(|| self.post(session_id, &self.base_path, &payload)) .await?; // Parse the response - response is already a Value from our post method diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index b151aaf18434..38d980b5fb14 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -69,10 +69,10 @@ impl XaiProvider { }) } - async fn post(&self, payload: Value) -> Result { + async fn post(&self, session_id: &str, payload: Value) -> Result { let response = self .api_client - .response_post("chat/completions", &payload) + .response_post(session_id, "chat/completions", &payload) .await?; handle_response_openai_compat(response).await @@ -110,6 +110,7 @@ impl Provider for XaiProvider { )] async fn complete_with_model( &self, + session_id: &str, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -125,7 +126,9 @@ impl Provider for XaiProvider { )?; let mut log = RequestLog::start(&self.model, &payload)?; - let response = self.with_retry(|| self.post(payload.clone())).await?; + let response = self + .with_retry(|| self.post(session_id, payload.clone())) + .await?; let message = response_to_message(&response)?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { @@ -143,6 +146,7 @@ impl Provider for XaiProvider { async fn stream( &self, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -161,7 +165,7 @@ impl Provider for XaiProvider { .with_retry(|| async { let resp = self .api_client - .response_post("chat/completions", &payload) + .response_post(session_id, "chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index b4520d80d2a5..c3d2a46a0ab7 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -794,13 +794,9 @@ async fn execute_job( retry_config: None, }; - let session_id = session_config.id.clone(); - let stream = crate::session_context::with_session_id(Some(session_id.clone()), async { - agent - .reply(user_message, session_config, Some(cancel_token)) - .await - }) - .await?; + let stream = agent + .reply(user_message, session_config, Some(cancel_token)) + .await?; use futures::StreamExt; let mut stream = std::pin::pin!(stream); diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index ce5bf1d9b5d0..7523ccc33720 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -344,7 +344,7 @@ impl SessionManager { .count(); if user_message_count <= MSG_COUNT_FOR_SESSION_NAME_GENERATION { - let name = provider.generate_session_name(&conversation).await?; + let name = provider.generate_session_name(id, &conversation).await?; self.update(id).system_generated_name(name).apply().await } else { Ok(()) diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 9ba0a903c4ec..356010d59b2f 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -355,6 +355,7 @@ mod tests { impl Provider for MockToolProvider { async fn complete( &self, + _session_id: &str, _system_prompt: &str, _messages: &[Message], _tools: &[Tool], @@ -376,12 +377,14 @@ mod tests { async fn complete_with_model( &self, + session_id: &str, _model_config: &ModelConfig, system_prompt: &str, messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { - 
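
The scheduler hunk is the motivating cleanup: instead of wrapping agent.reply in a with_session_id task-local scope, the session id travels through explicit arguments, so the ambient-context layer can be dropped. A sketch contrasting the implicit style being removed with the explicit style that replaces it (the thread_local is only an illustration of ambient context, not goose's implementation):

use std::cell::RefCell;

thread_local! {
    static CURRENT_SESSION: RefCell<Option<String>> = RefCell::new(None);
}

// Old style: the callee reads ambient state, and the caller must remember to set it.
fn complete_implicit(prompt: &str) -> String {
    let sid = CURRENT_SESSION.with(|s| s.borrow().clone()).unwrap_or_default();
    format!("[{}] {}", sid, prompt)
}

// New style: the dependency is visible in the signature and cannot be forgotten.
fn complete_explicit(session_id: &str, prompt: &str) -> String {
    format!("[{}] {}", session_id, prompt)
}

fn main() {
    CURRENT_SESSION.with(|s| *s.borrow_mut() = Some("job-session".into()));
    println!("{}", complete_implicit("run scheduled job"));
    println!("{}", complete_explicit("job-session", "run scheduled job"));
}
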
self.complete(system_prompt, messages, tools).await + self.complete(session_id, system_prompt, messages, tools) + .await } fn get_model_config(&self) -> ModelConfig { @@ -496,7 +499,7 @@ mod tests { use goose::config::GooseMode; use goose::session::SessionManager; - async fn setup_agent_with_extension_manager() -> Agent { + async fn setup_agent_with_extension_manager() -> (Agent, String) { // Add the TODO extension to the config so it can be discovered by search_available_extensions // Set it as disabled initially so tests can enable it let todo_extension_entry = ExtensionEntry { @@ -515,6 +518,7 @@ mod tests { // Create agent with session_id from the start let temp_dir = tempfile::tempdir().unwrap(); let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let session_id = "test-session-id".to_string(); let config = AgentConfig::new( session_manager, PermissionManager::instance(), @@ -536,13 +540,13 @@ mod tests { .add_extension(ext_config) .await .expect("Failed to add extension manager"); - agent + (agent, session_id) } #[tokio::test] async fn test_extension_manager_tools_available() { - let agent = setup_agent_with_extension_manager().await; - let tools = agent.list_tools("test-session-id", None).await; + let (agent, session_id) = setup_agent_with_extension_manager().await; + let tools = agent.list_tools(&session_id, None).await; // Note: Tool names are prefixed with the normalized extension name "extensionmanager" // not the display name "Extension Manager" diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 71fae27d8e3f..5a65fd7fa843 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -59,6 +59,7 @@ impl Provider for MockProvider { async fn complete_with_model( &self, + _session_id: &str, _model_config: &ModelConfig, _system: &str, _messages: &[Message], diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 430506f8aff7..7f693f2d57d1 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -100,7 +100,12 @@ impl ProviderTester { let (response, _) = self .provider - .complete("You are a helpful assistant.", &[message], &[]) + .complete( + "test-session-id", + "You are a helpful assistant.", + &[message], + &[], + ) .await?; assert_eq!( @@ -138,6 +143,7 @@ impl ProviderTester { let (response1, _) = self .provider .complete( + "test-session-id", "You are a helpful weather assistant.", std::slice::from_ref(&message), std::slice::from_ref(&weather_tool), @@ -186,6 +192,7 @@ impl ProviderTester { let (response2, _) = self .provider .complete( + "test-session-id", "You are a helpful weather assistant.", &[message, response1, weather], &[weather_tool], @@ -228,7 +235,12 @@ impl ProviderTester { let result = self .provider - .complete("You are a helpful assistant.", &messages, &[]) + .complete( + "test-session-id", + "You are a helpful assistant.", + &messages, + &[], + ) .await; println!("=== {}::context_length_exceeded_error ===", self.name); @@ -286,6 +298,7 @@ impl ProviderTester { let result = self .provider .complete( + "test-session-id", "You are a helpful assistant. 
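
In agent.rs, the test fixture now returns the session id alongside the agent so every assertion reuses the same value instead of repeating a string literal. A tiny sketch of that fixture shape, with a trivial stand-in Agent rather than the real goose type:

struct Agent;

impl Agent {
    fn list_tools(&self, session_id: &str) -> Vec<String> {
        vec![format!("todo tool visible to {}", session_id)]
    }
}

fn setup_agent() -> (Agent, String) {
    (Agent, "test-session-id".to_string())
}

fn main() {
    let (agent, session_id) = setup_agent();
    let tools = agent.list_tools(&session_id);
    assert_eq!(tools.len(), 1);
    println!("{:?}", tools);
}
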
Describe what you see in the image briefly.", &[message_with_image], &[], @@ -338,6 +351,7 @@ impl ProviderTester { let result2 = self .provider .complete( + "test-session-id", "You are a helpful assistant.", &[user_message, tool_request, tool_response], &[screenshot_tool], diff --git a/crates/goose/tests/session_id_propagation_test.rs b/crates/goose/tests/session_id_propagation_test.rs index 10a83a6261f3..5142e18fc5d7 100644 --- a/crates/goose/tests/session_id_propagation_test.rs +++ b/crates/goose/tests/session_id_propagation_test.rs @@ -3,7 +3,6 @@ use goose::model::ModelConfig; use goose::providers::api_client::{ApiClient, AuthMethod}; use goose::providers::base::Provider; use goose::providers::openai::OpenAiProvider; -use goose::session_context; use goose::session_context::SESSION_ID_HEADER; use serde_json::json; use std::sync::Arc; @@ -81,30 +80,19 @@ async fn setup_mock_server() -> (MockServer, HeaderCapture, Box) { (mock_server, capture, provider) } -async fn make_request(provider: &dyn Provider, session_id: Option<&str>) { +async fn make_request(provider: &dyn Provider, session_id: &str) { let message = Message::user().with_text("test message"); - let request_fn = async { - provider - .complete("You are a helpful assistant.", &[message], &[]) - .await - .unwrap() - }; - - match session_id { - Some(id) => { - session_context::with_session_id(Some(id.to_string()), request_fn).await; - } - None => { - request_fn.await; - } - } + let _ = provider + .complete(session_id, "You are a helpful assistant.", &[message], &[]) + .await + .unwrap(); } #[tokio::test] async fn test_session_id_propagation_to_llm() { let (_, capture, provider) = setup_mock_server().await; - make_request(provider.as_ref(), Some("integration-test-session-123")).await; + make_request(provider.as_ref(), "integration-test-session-123").await; assert_eq!( capture.get_captured(), @@ -113,26 +101,29 @@ async fn test_session_id_propagation_to_llm() { } #[tokio::test] -async fn test_no_session_id_when_absent() { +async fn test_session_id_always_present() { let (_, capture, provider) = setup_mock_server().await; - make_request(provider.as_ref(), None).await; + make_request(provider.as_ref(), "test-session-id").await; - assert_eq!(capture.get_captured(), vec![None]); + assert_eq!( + capture.get_captured(), + vec![Some("test-session-id".to_string())] + ); } #[tokio::test] async fn test_session_id_matches_across_calls() { let (_, capture, provider) = setup_mock_server().await; - let test_session_id = "consistent-session-456"; - make_request(provider.as_ref(), Some(test_session_id)).await; - make_request(provider.as_ref(), Some(test_session_id)).await; - make_request(provider.as_ref(), Some(test_session_id)).await; + let session_id = "consistent-session-456"; + make_request(provider.as_ref(), session_id).await; + make_request(provider.as_ref(), session_id).await; + make_request(provider.as_ref(), session_id).await; assert_eq!( capture.get_captured(), - vec![Some(test_session_id.to_string()); 3] + vec![Some(session_id.to_string()); 3] ); } @@ -142,8 +133,8 @@ async fn test_different_sessions_have_different_ids() { let session_id_1 = "session-one"; let session_id_2 = "session-two"; - make_request(provider.as_ref(), Some(session_id_1)).await; - make_request(provider.as_ref(), Some(session_id_2)).await; + make_request(provider.as_ref(), session_id_1).await; + make_request(provider.as_ref(), session_id_2).await; assert_eq!( capture.get_captured(), diff --git a/crates/goose/tests/tetrate_streaming.rs 
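
The rewritten session_id_propagation_test drops the "no session id" case entirely: with the explicit parameter, every request is expected to carry the header, and repeated calls in one session must carry the same value. A std-only sketch of what the capture-and-assert step checks (the header name and HeaderCapture are stand-ins for illustration, not the test's wiremock setup):

use std::sync::{Arc, Mutex};

const SESSION_ID_HEADER: &str = "x-session-id"; // assumed header name

#[derive(Clone, Default)]
struct HeaderCapture {
    seen: Arc<Mutex<Vec<Option<String>>>>,
}

impl HeaderCapture {
    fn record(&self, header: Option<&str>) {
        self.seen.lock().unwrap().push(header.map(str::to_string));
    }
    fn get_captured(&self) -> Vec<Option<String>> {
        self.seen.lock().unwrap().clone()
    }
}

fn make_request(capture: &HeaderCapture, session_id: &str) {
    // In the real test the provider issues an HTTP call and the mock server
    // records the header; here we record the value directly.
    capture.record(Some(session_id));
}

fn main() {
    let capture = HeaderCapture::default();
    let session_id = "consistent-session-456";
    for _ in 0..3 {
        make_request(&capture, session_id);
    }
    assert_eq!(capture.get_captured(), vec![Some(session_id.to_string()); 3]);
    println!("all 3 requests carried {}", SESSION_ID_HEADER);
}
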
b/crates/goose/tests/tetrate_streaming.rs index 784adbba68f6..ab4663f38cee 100644 --- a/crates/goose/tests/tetrate_streaming.rs +++ b/crates/goose/tests/tetrate_streaming.rs @@ -29,6 +29,7 @@ mod tetrate_streaming_tests { let mut stream = provider .stream( + "test-session-id", "You are a helpful assistant that counts numbers.", &messages, &[], @@ -100,6 +101,7 @@ mod tetrate_streaming_tests { let mut stream = provider .stream( + "test-session-id", "You are a helpful assistant with access to weather information.", &messages, &[weather_tool], @@ -146,7 +148,12 @@ mod tetrate_streaming_tests { let messages = vec![Message::user().with_text("")]; let mut stream = provider - .stream("You are a helpful assistant.", &messages, &[]) + .stream( + "test-session-id", + "You are a helpful assistant.", + &messages, + &[], + ) .await?; let mut chunk_count = 0; @@ -177,6 +184,7 @@ mod tetrate_streaming_tests { let mut stream = provider .stream( + "test-session-id", "You are a helpful assistant that writes detailed essays.", &messages, &[], @@ -235,7 +243,12 @@ mod tetrate_streaming_tests { let messages = vec![Message::user().with_text("Hello")]; let result = provider - .stream("You are a helpful assistant.", &messages, &[]) + .stream( + "test-session-id", + "You are a helpful assistant.", + &messages, + &[], + ) .await; // We expect this to fail with an authentication error @@ -258,11 +271,21 @@ mod tetrate_streaming_tests { let messages2 = vec![Message::user().with_text("Say 'Stream 2'")]; let stream1 = provider - .stream("You are a helpful assistant.", &messages1, &[]) + .stream( + "test-session-id", + "You are a helpful assistant.", + &messages1, + &[], + ) .await?; let stream2 = provider - .stream("You are a helpful assistant.", &messages2, &[]) + .stream( + "test-session-id", + "You are a helpful assistant.", + &messages2, + &[], + ) .await?; // Process both streams concurrently
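
Finally, the Tetrate streaming tests exercise the new stream() signature, where session_id is the first argument and the two concurrent streams at the end share one id. A synchronous, iterator-based sketch of that signature change with stand-in types (the real method is async and yields message/usage chunks):

trait Provider {
    fn stream(&self, session_id: &str, system: &str, user: &str) -> Box<dyn Iterator<Item = String>>;
}

struct Chunky;

impl Provider for Chunky {
    fn stream(&self, session_id: &str, _system: &str, user: &str) -> Box<dyn Iterator<Item = String>> {
        let sid = session_id.to_string();
        let words: Vec<String> = user.split_whitespace().map(str::to_string).collect();
        // Each emitted chunk is tagged with the session id it was produced for.
        Box::new(words.into_iter().map(move |w| format!("[{}] {}", sid, w)))
    }
}

fn main() {
    let p = Chunky;
    for chunk in p.stream("test-session-id", "You are a helpful assistant.", "Say Stream 1") {
        println!("{}", chunk);
    }
}
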