diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index dd44f28850dd..444c7666f13a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -825,11 +825,9 @@ impl Agent { } } Err(e) => { - yield AgentEvent::Message( - Message::assistant().with_text( - format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session") - ) - ); + yield AgentEvent::Message(Message::assistant().with_text( + format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session") + )); } } })) diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index 88c017a2e402..d42d62488d9b 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -1,4 +1,5 @@ use crate::agents::types::SharedProvider; +use crate::session_context::SESSION_ID_HEADER; use rmcp::model::{Content, ErrorCode, JsonObject}; /// MCP client implementation for Goose use rmcp::{ @@ -334,7 +335,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListResourcesRequest(ListResourcesRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -358,7 +359,7 @@ impl McpClientTrait for McpClient { uri: uri.to_string(), }, method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -380,7 +381,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListToolsRequest(ListToolsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -406,7 +407,7 @@ impl McpClientTrait for McpClient { arguments, }, method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -428,7 +429,7 @@ impl McpClientTrait for McpClient { ClientRequest::ListPromptsRequest(ListPromptsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -458,7 +459,7 @@ impl McpClientTrait for McpClient { arguments, }, method: Default::default(), - extensions: Default::default(), + extensions: inject_session_into_extensions(Default::default()), }), cancel_token, ) @@ -476,3 +477,118 @@ impl McpClientTrait for McpClient { rx } } + +/// Replaces session ID, case-insensitively, in Extensions._meta. +fn inject_session_into_extensions( + mut extensions: rmcp::model::Extensions, +) -> rmcp::model::Extensions { + use rmcp::model::Meta; + + if let Some(session_id) = crate::session_context::current_session_id() { + let mut meta_map = extensions + .get::() + .map(|meta| meta.0.clone()) + .unwrap_or_default(); + + // JsonObject is case-sensitive, so we use retain for case-insensitive removal + meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER)); + + meta_map.insert(SESSION_ID_HEADER.to_string(), Value::String(session_id)); + + extensions.insert(Meta(meta_map)); + } + + extensions +} + +#[cfg(test)] +mod tests { + use super::*; + use rmcp::model::Meta; + + #[tokio::test] + async fn test_session_id_in_mcp_meta() { + use serde_json::json; + + let session_id = "test-session-789"; + crate::session_context::with_session_id(Some(session_id.to_string()), async { + let extensions = inject_session_into_extensions(Default::default()); + let meta = extensions.get::().unwrap(); + + assert_eq!( + &meta.0, + json!({ + SESSION_ID_HEADER: session_id + }) + .as_object() + .unwrap() + ); + }) + .await; + } + + #[tokio::test] + async fn test_no_session_id_in_mcp_when_absent() { + let extensions = inject_session_into_extensions(Default::default()); + let meta = extensions.get::(); + + assert!(meta.is_none()); + } + + #[tokio::test] + async fn test_all_mcp_operations_include_session() { + use serde_json::json; + + let session_id = "consistent-session-id"; + crate::session_context::with_session_id(Some(session_id.to_string()), async { + let ext1 = inject_session_into_extensions(Default::default()); + let ext2 = inject_session_into_extensions(Default::default()); + let ext3 = inject_session_into_extensions(Default::default()); + + for ext in [&ext1, &ext2, &ext3] { + assert_eq!( + &ext.get::().unwrap().0, + json!({ + SESSION_ID_HEADER: session_id + }) + .as_object() + .unwrap() + ); + } + }) + .await; + } + + #[tokio::test] + async fn test_session_id_case_insensitive_replacement() { + use rmcp::model::{Extensions, Meta}; + use serde_json::{from_value, json}; + + let session_id = "new-session-id"; + crate::session_context::with_session_id(Some(session_id.to_string()), async { + let mut extensions = Extensions::new(); + extensions.insert( + from_value::(json!({ + "GOOSE-SESSION-ID": "old-session-1", + "Goose-Session-Id": "old-session-2", + "other-key": "preserve-me" + })) + .unwrap(), + ); + + let extensions = inject_session_into_extensions(extensions); + let meta = extensions.get::().unwrap(); + + assert_eq!( + &meta.0, + json!({ + SESSION_ID_HEADER: session_id, + "other-key": "preserve-me" + }) + .as_object() + .unwrap() + ); + }) + .await; + } +} diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 5359573fe248..47ee99d08879 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -164,10 +164,15 @@ fn get_agent_messages( } else { None }; - let mut stream = agent - .reply(conversation.clone(), session_config, None) - .await - .map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?; + + let session_id = session_config.as_ref().map(|s| s.id.clone()); + let mut stream = crate::session_context::with_session_id(session_id, async { + agent + .reply(conversation.clone(), session_config, None) + .await + }) + .await + .map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?; while let Some(message_result) = stream.next().await { match message_result { Ok(AgentEvent::Message(msg)) => conversation.push(msg), diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 9e8fdad19002..f5383c8be583 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -17,6 +17,7 @@ pub mod scheduler_factory; pub mod scheduler_trait; pub mod security; pub mod session; +pub mod session_context; pub mod token_counter; pub mod tool_inspection; pub mod tool_monitor; diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs index 821148bae757..449a74e2086c 100644 --- a/crates/goose/src/providers/api_client.rs +++ b/crates/goose/src/providers/api_client.rs @@ -1,3 +1,4 @@ +use crate::session_context::SESSION_ID_HEADER; use anyhow::Result; use async_trait::async_trait; use reqwest::{ @@ -369,6 +370,10 @@ impl<'a> ApiRequestBuilder<'a> { let mut request = request_builder(url, &self.client.client); request = request.headers(self.headers.clone()); + if let Some(session_id) = crate::session_context::current_session_id() { + request = request.header(SESSION_ID_HEADER, session_id); + } + request = match &self.client.auth { AuthMethod::BearerToken(token) => { request.header("Authorization", format!("Bearer {}", token)) @@ -398,3 +403,55 @@ impl fmt::Debug for ApiClient { .finish_non_exhaustive() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_session_id_header_injection() { + let client = ApiClient::new( + "http://localhost:8080".to_string(), + AuthMethod::BearerToken("test-token".to_string()), + ) + .unwrap(); + + // Execute request within session context + crate::session_context::with_session_id(Some("test-session-456".to_string()), async { + let builder = client.request("/test"); + let request = builder + .send_request(|url, client| client.get(url)) + .await + .unwrap(); + + let headers = request.build().unwrap().headers().clone(); + + assert!(headers.contains_key(SESSION_ID_HEADER)); + assert_eq!( + headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), + "test-session-456" + ); + }) + .await; + } + + #[tokio::test] + async fn test_no_session_id_header_when_absent() { + let client = ApiClient::new( + "http://localhost:8080".to_string(), + AuthMethod::BearerToken("test-token".to_string()), + ) + .unwrap(); + + // Build a request without session context + let builder = client.request("/test"); + let request = builder + .send_request(|url, client| client.get(url)) + .await + .unwrap(); + + let headers = request.build().unwrap().headers().clone(); + + assert!(!headers.contains_key(SESSION_ID_HEADER)); + } +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index d50502a42b92..5658bdaff854 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,5 +1,5 @@ pub mod anthropic; -mod api_client; +pub mod api_client; pub mod azure; pub mod azureauth; pub mod base; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index d869d48d2eb4..8bb9f1a9e780 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -112,6 +112,20 @@ impl OpenAiProvider { }) } + #[doc(hidden)] + pub fn new(api_client: ApiClient, model: ModelConfig) -> Self { + Self { + api_client, + base_path: "v1/chat/completions".to_string(), + organization: None, + project: None, + model, + custom_headers: None, + supports_streaming: true, + name: Self::metadata().name, + } + } + pub fn from_custom_config( model: ModelConfig, config: DeclarativeProviderConfig, diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 0fd040817e1b..45eaeb13fa3b 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1216,9 +1216,13 @@ async fn run_scheduled_job_internal( retry_config: None, }; - match agent - .reply(conversation.clone(), Some(session_config.clone()), None) - .await + let session_id = Some(session_config.id.clone()); + match crate::session_context::with_session_id(session_id, async { + agent + .reply(conversation.clone(), Some(session_config.clone()), None) + .await + }) + .await { Ok(mut stream) => { use futures::StreamExt; diff --git a/crates/goose/src/session_context.rs b/crates/goose/src/session_context.rs new file mode 100644 index 000000000000..7379e348df62 --- /dev/null +++ b/crates/goose/src/session_context.rs @@ -0,0 +1,80 @@ +use tokio::task_local; + +pub const SESSION_ID_HEADER: &str = "goose-session-id"; + +task_local! { + pub static SESSION_ID: Option; +} + +pub async fn with_session_id(session_id: Option, f: F) -> F::Output +where + F: std::future::Future, +{ + if let Some(id) = session_id { + SESSION_ID.scope(Some(id), f).await + } else { + f.await + } +} + +pub fn current_session_id() -> Option { + SESSION_ID.try_with(|id| id.clone()).ok().flatten() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_session_id_available_when_set() { + with_session_id(Some("test-session-123".to_string()), async { + assert_eq!(current_session_id(), Some("test-session-123".to_string())); + }) + .await; + } + + #[tokio::test] + async fn test_session_id_none_when_not_set() { + let id = current_session_id(); + assert_eq!(id, None); + } + + #[tokio::test] + async fn test_session_id_none_when_explicitly_none() { + with_session_id(None, async { + assert_eq!(current_session_id(), None); + }) + .await; + } + + #[tokio::test] + async fn test_session_id_scoped_correctly() { + assert_eq!(current_session_id(), None); + + with_session_id(Some("outer-session".to_string()), async { + assert_eq!(current_session_id(), Some("outer-session".to_string())); + + with_session_id(Some("inner-session".to_string()), async { + assert_eq!(current_session_id(), Some("inner-session".to_string())); + }) + .await; + + assert_eq!(current_session_id(), Some("outer-session".to_string())); + }) + .await; + + assert_eq!(current_session_id(), None); + } + + #[tokio::test] + async fn test_session_id_across_await_points() { + with_session_id(Some("persistent-session".to_string()), async { + assert_eq!(current_session_id(), Some("persistent-session".to_string())); + + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + assert_eq!(current_session_id(), Some("persistent-session".to_string())); + }) + .await; + } +} diff --git a/crates/goose/tests/session_id_propagation_test.rs b/crates/goose/tests/session_id_propagation_test.rs new file mode 100644 index 000000000000..10a83a6261f3 --- /dev/null +++ b/crates/goose/tests/session_id_propagation_test.rs @@ -0,0 +1,155 @@ +use goose::conversation::message::Message; +use goose::model::ModelConfig; +use goose::providers::api_client::{ApiClient, AuthMethod}; +use goose::providers::base::Provider; +use goose::providers::openai::OpenAiProvider; +use goose::session_context; +use goose::session_context::SESSION_ID_HEADER; +use serde_json::json; +use std::sync::Arc; +use std::sync::Mutex; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, Request, ResponseTemplate}; + +#[derive(Clone, Default)] +struct HeaderCapture { + captured_headers: Arc>>>, +} + +impl HeaderCapture { + fn new() -> Self { + Self { + captured_headers: Arc::new(Mutex::new(Vec::new())), + } + } + + fn capture_session_header(&self, req: &Request) { + let session_id = req + .headers + .get(SESSION_ID_HEADER) + .map(|v| v.to_str().unwrap().to_string()); + self.captured_headers.lock().unwrap().push(session_id); + } + + fn get_captured(&self) -> Vec> { + self.captured_headers.lock().unwrap().clone() + } +} + +fn create_test_provider(mock_server_url: &str) -> Box { + let api_client = ApiClient::new( + mock_server_url.to_string(), + AuthMethod::BearerToken("test-key".to_string()), + ) + .unwrap(); + let model = ModelConfig::new_or_fail("gpt-5-nano"); + Box::new(OpenAiProvider::new(api_client, model)) +} + +async fn setup_mock_server() -> (MockServer, HeaderCapture, Box) { + let mock_server = MockServer::start().await; + let capture = HeaderCapture::new(); + let capture_clone = capture.clone(); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(move |req: &Request| { + capture_clone.capture_session_header(req); + ResponseTemplate::new(200).set_body_json(json!({ + "choices": [{ + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Hi there! How can I help you today?", + "role": "assistant" + } + }], + "created": 1755133833, + "id": "chatcmpl-test", + "model": "gpt-5-nano", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 8, + "total_tokens": 18 + } + })) + }) + .mount(&mock_server) + .await; + + let provider = create_test_provider(&mock_server.uri()); + (mock_server, capture, provider) +} + +async fn make_request(provider: &dyn Provider, session_id: Option<&str>) { + let message = Message::user().with_text("test message"); + let request_fn = async { + provider + .complete("You are a helpful assistant.", &[message], &[]) + .await + .unwrap() + }; + + match session_id { + Some(id) => { + session_context::with_session_id(Some(id.to_string()), request_fn).await; + } + None => { + request_fn.await; + } + } +} + +#[tokio::test] +async fn test_session_id_propagation_to_llm() { + let (_, capture, provider) = setup_mock_server().await; + + make_request(provider.as_ref(), Some("integration-test-session-123")).await; + + assert_eq!( + capture.get_captured(), + vec![Some("integration-test-session-123".to_string())] + ); +} + +#[tokio::test] +async fn test_no_session_id_when_absent() { + let (_, capture, provider) = setup_mock_server().await; + + make_request(provider.as_ref(), None).await; + + assert_eq!(capture.get_captured(), vec![None]); +} + +#[tokio::test] +async fn test_session_id_matches_across_calls() { + let (_, capture, provider) = setup_mock_server().await; + + let test_session_id = "consistent-session-456"; + make_request(provider.as_ref(), Some(test_session_id)).await; + make_request(provider.as_ref(), Some(test_session_id)).await; + make_request(provider.as_ref(), Some(test_session_id)).await; + + assert_eq!( + capture.get_captured(), + vec![Some(test_session_id.to_string()); 3] + ); +} + +#[tokio::test] +async fn test_different_sessions_have_different_ids() { + let (_, capture, provider) = setup_mock_server().await; + + let session_id_1 = "session-one"; + let session_id_2 = "session-two"; + make_request(provider.as_ref(), Some(session_id_1)).await; + make_request(provider.as_ref(), Some(session_id_2)).await; + + assert_eq!( + capture.get_captured(), + vec![ + Some(session_id_1.to_string()), + Some(session_id_2.to_string()) + ] + ); +}