diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 48c9c8c54883..02ee76862aa3 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -17,7 +17,6 @@ use goose::config::{ configure_tetrate, Config, ConfigError, ExperimentManager, ExtensionEntry, GooseMode, PermissionManager, }; -use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::posthog::{get_telemetry_choice, TELEMETRY_ENABLED_KEY}; use goose::providers::provider_test::test_provider_configuration; @@ -25,7 +24,6 @@ use goose::providers::{create, providers, retry_operation, RetryConfig}; use goose::session::SessionType; use serde_json::Value; use std::collections::HashMap; -use uuid::Uuid; // useful for light themes where there is no dicernible colour contrast between // cursor-selected and cursor-unselected items. @@ -683,10 +681,8 @@ pub async fn configure_provider_dialog() -> anyhow::Result { let models_res = { let temp_model_config = ModelConfig::new(&provider_meta.default_model)?; let temp_provider = create(provider_name, temp_model_config).await?; - // Provider setup runs before any user session exists; use an ephemeral id. - let session_id = Uuid::new_v4().to_string(); retry_operation(&RetryConfig::default(), || async { - temp_provider.fetch_recommended_models(&session_id).await + temp_provider.fetch_recommended_models().await }) .await }; @@ -1658,11 +1654,11 @@ pub async fn handle_openrouter_auth() -> anyhow::Result<()> { match create("openrouter", model_config).await { Ok(provider) => { - // Config verification runs before any user session exists; use an ephemeral id. 
- let session_id = Uuid::new_v4().to_string(); + let model_config = provider.get_model_config(); let test_result = provider - .complete( - &session_id, + .complete_with_model( + None, + &model_config, "You are goose, an AI assistant.", &[Message::user().with_text("Say 'Configuration test successful!'")], &[], @@ -1738,16 +1734,7 @@ pub async fn handle_tetrate_auth() -> anyhow::Result<()> { match create("tetrate", model_config).await { Ok(provider) => { - // Config verification runs before any user session exists; use an ephemeral id. - let session_id = Uuid::new_v4().to_string(); - let test_result = provider - .complete( - &session_id, - "You are goose, an AI assistant.", - &[Message::user().with_text("Say 'Configuration test successful!'")], - &[], - ) - .await; + let test_result = provider.fetch_supported_models().await; match test_result { Ok(_) => { diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 3dbc23d4f8fe..a42a2e97a000 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -28,7 +28,6 @@ use serde_json::Value; use serde_yaml; use std::{collections::HashMap, sync::Arc}; use utoipa::ToSchema; -use uuid::Uuid; #[derive(Serialize, ToSchema)] pub struct ExtensionResponse { @@ -409,10 +408,8 @@ pub async fn get_provider_models( .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - // Config endpoints have no user session; use an ephemeral id for the probe. - let session_id = Uuid::new_v4().to_string(); let models_result = retry_operation(&RetryConfig::default(), || async { - provider.fetch_recommended_models(&session_id).await + provider.fetch_recommended_models().await }) .await; @@ -592,10 +589,8 @@ pub async fn detect_provider( Json(detect_request): Json, ) -> Result, StatusCode> { let api_key = detect_request.api_key.trim(); - // Provider detection runs without a user session; use an ephemeral id. 
- let session_id = Uuid::new_v4().to_string(); - match detect_provider_from_api_key(&session_id, api_key).await { + match detect_provider_from_api_key(api_key).await { Some((provider_name, models)) => Ok(Json(DetectProviderResponse { provider_name, models, diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index c241a9eb8c13..9d602f77da96 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -1,10 +1,8 @@ use anyhow::Result; use dotenvy::dotenv; use goose::conversation::message::Message; +use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; -use goose::providers::{base::Usage, create_with_named_model}; -use tokio_stream::StreamExt; -use uuid::Uuid; #[tokio::main] async fn main() -> Result<()> { @@ -20,25 +18,25 @@ async fn main() -> Result<()> { let message = Message::user().with_text("Tell me a short joke about programming."); // Get a response - let session_id = Uuid::new_v4().to_string(); - let mut stream = provider - .stream(&session_id, "You are a helpful assistant.", &[message], &[]) + let (response, usage) = provider + .complete_with_model( + None, + &provider.get_model_config(), + "You are a helpful assistant.", + &[message], + &[], + ) .await?; println!("\nResponse from AI:"); println!("---------------"); - let mut usage = Usage::default(); - while let Some(Ok((msg, usage_part))) = stream.next().await { - dbg!(msg); - if let Some(u) = usage_part { - usage += u.usage; - } - } + println!("{:?}", response); + println!("\nToken Usage:"); println!("------------"); - println!("Input tokens: {:?}", usage.input_tokens); - println!("Output tokens: {:?}", usage.output_tokens); - println!("Total tokens: {:?}", usage.total_tokens); + println!("Input tokens: {:?}", usage.usage.input_tokens); + println!("Output tokens: {:?}", usage.usage.output_tokens); + println!("Total tokens: {:?}", usage.usage.total_tokens); Ok(()) 
} diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 62d2457352f8..96eb4c48078c 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -10,7 +10,6 @@ use rmcp::model::{CallToolRequestParams, Content, Tool}; use rmcp::object; use std::fs; use std::sync::Arc; -use uuid::Uuid; #[tokio::main] async fn main() -> Result<()> { @@ -63,10 +62,10 @@ async fn main() -> Result<()> { }, } }); - let session_id = Uuid::new_v4().to_string(); let (response, usage) = provider - .complete( - &session_id, + .complete_with_model( + None, + &provider.get_model_config(), "You are a helpful assistant. Please describe any text you see in the image.", &messages, &[Tool::new("view_image", "View an image", input_schema)], diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index 3a09e3affe83..3b4d7d59230e 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -27,17 +27,12 @@ use rmcp::{ ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt, }; use serde_json::Value; -use std::{ - sync::{Arc, OnceLock}, - time::Duration, -}; +use std::{sync::Arc, time::Duration}; use tokio::sync::{ mpsc::{self, Sender}, Mutex, }; use tokio_util::sync::CancellationToken; -use uuid::Uuid; - pub type BoxError = Box; pub type Error = rmcp::ServiceError; @@ -112,8 +107,6 @@ pub struct GooseClient { provider: SharedProvider, // Single-slot because calls are serialized per MCP client; see send_request_with_session. current_session_id: Arc>>, - // Connection-scoped fallback for server-initiated sampling. 
- client_session_id: OnceLock, } impl GooseClient { @@ -125,7 +118,6 @@ impl GooseClient { notification_handlers: handlers, provider, current_session_id: Arc::new(Mutex::new(None)), - client_session_id: OnceLock::new(), } } @@ -144,22 +136,10 @@ impl GooseClient { slot.clone() } - async fn resolve_session_id(&self, extensions: &Extensions) -> String { + async fn resolve_session_id(&self, extensions: &Extensions) -> Option { // Prefer explicit MCP metadata, then the active request scope. - if let Some(session_id) = Self::session_id_from_extensions(extensions) { - return session_id; - } - if let Some(session_id) = self.current_session_id().await { - return session_id; - } - // Fallback for server-initiated sampling not tied to a request session. - self.client_session_id() - } - - fn client_session_id(&self) -> String { - self.client_session_id - .get_or_init(|| Uuid::new_v4().to_string()) - .clone() + let current_session_id = self.current_session_id().await; + Self::session_id_from_extensions(extensions).or(current_session_id) } fn session_id_from_extensions(extensions: &Extensions) -> Option { @@ -255,7 +235,13 @@ impl ClientHandler for GooseClient { .unwrap_or("You are a general-purpose AI agent called goose"); let (response, usage) = provider - .complete(&session_id, system_prompt, &provider_ready_messages, &[]) + .complete_with_model( + session_id.as_deref(), + &provider.get_model_config(), + system_prompt, + &provider_ready_messages, + &[], + ) .await .map_err(|e| { ErrorData::new( @@ -387,7 +373,7 @@ impl McpClient { request: ClientRequest, cancel_token: CancellationToken, ) -> Result { - let request = inject_session_id_into_request(request, session_id); + let request = inject_session_id_into_request(request, Some(session_id)); // ExtensionManager serializes calls per MCP connection, so one current_session_id slot // is sufficient for mapping callbacks to the active request session. 
let handle = { @@ -629,7 +615,12 @@ impl McpClientTrait for McpClient { } /// Injects the given session_id into Extensions._meta. -fn inject_session_id_into_extensions(mut extensions: Extensions, session_id: &str) -> Extensions { +/// None (or empty) removes any existing session id. +fn inject_session_id_into_extensions( + mut extensions: Extensions, + session_id: Option<&str>, +) -> Extensions { + let session_id = session_id.filter(|id| !id.is_empty()); let mut meta_map = extensions .get::() .map(|meta| meta.0.clone()) @@ -638,16 +629,21 @@ fn inject_session_id_into_extensions(mut extensions: Extensions, session_id: &st // JsonObject is case-sensitive, so we use retain for case-insensitive removal meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER)); - meta_map.insert( - SESSION_ID_HEADER.to_string(), - Value::String(session_id.to_string()), - ); + if let Some(session_id) = session_id { + meta_map.insert( + SESSION_ID_HEADER.to_string(), + Value::String(session_id.to_string()), + ); + } extensions.insert(Meta(meta_map)); extensions } -fn inject_session_id_into_request(request: ClientRequest, session_id: &str) -> ClientRequest { +fn inject_session_id_into_request( + request: ClientRequest, + session_id: Option<&str>, +) -> ClientRequest { match request { ClientRequest::ListResourcesRequest(mut req) => { req.extensions = inject_session_id_into_extensions(req.extensions, session_id); @@ -680,6 +676,7 @@ fn inject_session_id_into_request(request: ClientRequest, session_id: &str) -> C #[cfg(test)] mod tests { use super::*; + use serde_json::json; use test_case::test_case; fn new_client() -> GooseClient { @@ -770,45 +767,39 @@ mod tests { #[test_case( Some("ext-session"), Some("current-session"), - "ext-session"; + Some("ext-session"); "extensions win" )] #[test_case( None, Some("current-session"), - "current-session"; + Some("current-session"); "current when no extensions" )] #[test_case( None, None, - "client-session"; - "client fallback when no session" 
+ None; + "no session when no extensions or current" )] fn test_resolve_session_id( ext_session: Option<&str>, current_session: Option<&str>, - expected: &str, + expected: Option<&str>, ) { let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(async { let client = new_client(); - // Make the fallback deterministic so the expected value can live in the test_case row. - client - .client_session_id - .get_or_init(|| "client-session".to_string()); if let Some(session_id) = current_session { let mut slot = client.current_session_id.lock().await; *slot = Some(session_id.to_string()); } - let mut extensions = Extensions::new(); - if let Some(session_id) = ext_session { - extensions = inject_session_id_into_extensions(extensions, session_id); - } + let extensions = inject_session_id_into_extensions(Extensions::new(), ext_session); let resolved = client.resolve_session_id(&extensions).await; + let expected = expected.map(str::to_string); assert_eq!(resolved, expected); }); } @@ -833,7 +824,7 @@ mod tests { ); let request = request_builder(extensions); - let request = inject_session_id_into_request(request, session_id); + let request = inject_session_id_into_request(request, Some(session_id)); let extensions = request_extensions(&request).expect("request should have extensions"); let meta = extensions .get::() @@ -854,7 +845,7 @@ mod tests { use serde_json::json; let session_id = "test-session-789"; - let extensions = inject_session_id_into_extensions(Default::default(), session_id); + let extensions = inject_session_id_into_extensions(Default::default(), Some(session_id)); let mcp_meta = extensions.get::().unwrap(); assert_eq!( @@ -867,12 +858,35 @@ mod tests { ); } - #[test] - fn test_session_id_case_insensitive_replacement() { + #[test_case( + Some("new-session-id"), + json!({ + SESSION_ID_HEADER: "new-session-id", + "other-key": "preserve-me" + }); + "replace" + )] + #[test_case( + None, + json!({ + "other-key": "preserve-me" + }); + "remove" + )] + 
#[test_case( + Some(""), + json!({ + "other-key": "preserve-me" + }); + "empty removes" + )] + fn test_session_id_case_insensitive_replacement( + session_id: Option<&str>, + expected_meta: serde_json::Value, + ) { use rmcp::model::Extensions; use serde_json::{from_value, json}; - let session_id = "new-session-id"; let mut extensions = Extensions::new(); extensions.insert( from_value::(json!({ @@ -886,14 +900,6 @@ mod tests { let extensions = inject_session_id_into_extensions(extensions, session_id); let mcp_meta = extensions.get::().unwrap(); - assert_eq!( - &mcp_meta.0, - json!({ - SESSION_ID_HEADER: session_id, - "other-key": "preserve-me" - }) - .as_object() - .unwrap() - ); + assert_eq!(&mcp_meta.0, expected_meta.as_object().unwrap()); } } diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 54d62e423642..a6f34002307d 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -435,7 +435,7 @@ mod tests { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, _messages: &[Message], diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs index d3fea4b889ab..e06eb37f93d5 100644 --- a/crates/goose/src/context_mgmt/mod.rs +++ b/crates/goose/src/context_mgmt/mod.rs @@ -473,7 +473,7 @@ mod tests { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index ff5a8c6b6e5a..9de2060b3ebb 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -116,7 +116,11 @@ impl AnthropicProvider { headers } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> 
Result { let mut request = self.api_client.request(session_id, "v1/messages"); for (key, value) in self.get_conditional_headers() { @@ -198,7 +202,7 @@ impl Provider for AnthropicProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -228,11 +232,8 @@ impl Provider for AnthropicProvider { Ok((message, provider_usage)) } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { - let response = self.api_client.api_get(session_id, "v1/models").await?; + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self.api_client.request(None, "v1/models").api_get().await?; if response.status != StatusCode::OK { return Err(map_http_error_to_provider_error( @@ -268,7 +269,7 @@ impl Provider for AnthropicProvider { .unwrap() .insert("stream".to_string(), Value::Bool(true)); - let mut request = self.api_client.request(session_id, "v1/messages"); + let mut request = self.api_client.request(Some(session_id), "v1/messages"); let mut log = RequestLog::start(&self.model, &payload)?; for (key, value) in self.get_conditional_headers() { diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs index 163257f5b6c8..e92fd3229137 100644 --- a/crates/goose/src/providers/api_client.rs +++ b/crates/goose/src/providers/api_client.rs @@ -196,7 +196,7 @@ pub struct ApiRequestBuilder<'a> { client: &'a ApiClient, path: &'a str, headers: HeaderMap, - session_id: &'a str, + session_id: Option<&'a str>, } impl ApiClient { @@ -273,10 +273,15 @@ impl ApiClient { Ok(self) } - pub fn request<'a>(&'a self, session_id: &'a str, path: &'a str) -> ApiRequestBuilder<'a> { + /// - `session_id`: Use `None` only for configuration or pre-session tasks. 
+ pub fn request<'a>( + &'a self, + session_id: Option<&'a str>, + path: &'a str, + ) -> ApiRequestBuilder<'a> { ApiRequestBuilder { client: self, - session_id, + session_id: session_id.filter(|id| !id.is_empty()), path, headers: HeaderMap::new(), } @@ -284,7 +289,7 @@ impl ApiClient { pub async fn api_post( &self, - session_id: &str, + session_id: Option<&str>, path: &str, payload: &Value, ) -> Result { @@ -293,18 +298,18 @@ impl ApiClient { pub async fn response_post( &self, - session_id: &str, + session_id: Option<&str>, path: &str, payload: &Value, ) -> Result { self.request(session_id, path).response_post(payload).await } - pub async fn api_get(&self, session_id: &str, path: &str) -> Result { + pub async fn api_get(&self, session_id: Option<&str>, path: &str) -> Result { self.request(session_id, path).api_get().await } - pub async fn response_get(&self, session_id: &str, path: &str) -> Result { + pub async fn response_get(&self, session_id: Option<&str>, path: &str) -> Result { self.request(session_id, path).response_get().await } @@ -373,10 +378,16 @@ impl<'a> ApiRequestBuilder<'a> { F: FnOnce(url::Url, &Client) -> reqwest::RequestBuilder, { let url = self.client.build_url(self.path)?; - let mut request = request_builder(url, &self.client.client); - request = request.headers(self.headers.clone()); + let mut headers = self.headers.clone(); + headers.remove(SESSION_ID_HEADER); + if let Some(session_id) = self.session_id { + let header_name = HeaderName::from_static(SESSION_ID_HEADER); + let header_value = HeaderValue::from_str(session_id)?; + headers.insert(header_name, header_value); + } - request = request.header(SESSION_ID_HEADER, self.session_id); + let mut request = request_builder(url, &self.client.client); + request = request.headers(headers); request = match &self.client.auth { AuthMethod::BearerToken(token) => { @@ -411,69 +422,40 @@ impl fmt::Debug for ApiClient { #[cfg(test)] mod tests { use super::*; - - #[tokio::test] - async fn 
test_session_id_header_injection() { - let client = ApiClient::new( - "http://localhost:8080".to_string(), - AuthMethod::BearerToken("test-token".to_string()), - ) - .unwrap(); - - let builder = client.request("test-session_id-456", "/test"); - let request = builder - .send_request(|url, client| client.get(url)) - .await + use test_case::test_case; + + #[test_case(Some("test-session_id-456"), None, Some("test-session_id-456"); "header set")] + #[test_case(Some("new-session"), Some(("Agent-Session-Id", "old-session")), Some("new-session"); "replaces existing")] + #[test_case(None, Some(("Agent-Session-Id", "old-session")), None; "removes existing on none")] + #[test_case(Some(""), Some(("agent-session-id", "old-session")), None; "removes existing on empty")] + fn test_session_id_header( + session_id: Option<&str>, + existing_header: Option<(&str, &str)>, + expected: Option<&str>, + ) { + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let client = ApiClient::new( + "http://localhost:8080".to_string(), + AuthMethod::BearerToken("test-token".to_string()), + ) .unwrap(); - let headers = request.build().unwrap().headers().clone(); - - assert!(headers.contains_key(SESSION_ID_HEADER)); - assert_eq!( - headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), - "test-session_id-456" - ); - } - - #[tokio::test] - async fn test_session_id_header_with_different_id() { - let client = ApiClient::new( - "http://localhost:8080".to_string(), - AuthMethod::BearerToken("test-token".to_string()), - ) - .unwrap(); - - let builder = client.request("another-session_id-789", "/test"); - let request = builder - .send_request(|url, client| client.get(url)) - .await - .unwrap(); - - let headers = request.build().unwrap().headers().clone(); - - assert!(headers.contains_key(SESSION_ID_HEADER)); - assert_eq!( - headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(), - "another-session_id-789" - ); - } - - #[tokio::test] - async fn 
test_session_id_header_always_present() { - let client = ApiClient::new( - "http://localhost:8080".to_string(), - AuthMethod::BearerToken("test-token".to_string()), - ) - .unwrap(); - - let builder = client.request("required-session_id", "/test"); - let request = builder - .send_request(|url, client| client.get(url)) - .await - .unwrap(); + let mut builder = client.request(session_id, "/test"); + if let Some((key, value)) = existing_header { + builder = builder.header(key, value).unwrap(); + } + let request = builder + .send_request(|url, client| client.get(url)) + .await + .unwrap(); - let headers = request.build().unwrap().headers().clone(); + let headers = request.build().unwrap().headers().clone(); - assert!(headers.contains_key(SESSION_ID_HEADER)); + let actual = headers + .get(SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()); + assert_eq!(actual, expected); + }); } } diff --git a/crates/goose/src/providers/auto_detect.rs b/crates/goose/src/providers/auto_detect.rs index 36236b634f31..0513fd928be7 100644 --- a/crates/goose/src/providers/auto_detect.rs +++ b/crates/goose/src/providers/auto_detect.rs @@ -1,10 +1,7 @@ use crate::model::ModelConfig; use crate::providers::retry::{retry_operation, RetryConfig}; -pub async fn detect_provider_from_api_key( - session_id: &str, - api_key: &str, -) -> Option<(String, Vec)> { +pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec)> { let provider_tests = vec![ ("anthropic", "ANTHROPIC_API_KEY"), ("openai", "OPENAI_API_KEY"), @@ -18,7 +15,6 @@ pub async fn detect_provider_from_api_key( .into_iter() .map(|(provider_name, env_key)| { let api_key = api_key.to_string(); - let session_id = session_id.to_string(); tokio::spawn(async move { let original_value = std::env::var(env_key).ok(); std::env::set_var(env_key, &api_key); @@ -31,7 +27,7 @@ pub async fn detect_provider_from_api_key( { Ok(provider) => { match retry_operation(&RetryConfig::default(), || async { - 
provider.fetch_supported_models(&session_id).await + provider.fetch_supported_models().await }) .await { diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 1e307e5c0781..cbecb88dd171 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -99,7 +99,11 @@ impl AzureProvider { }) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { // Build the path for Azure OpenAI let path = format!( "openai/deployments/{}/chat/completions?api-version={}", @@ -147,7 +151,7 @@ impl Provider for AzureProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index c1ee2bcb75cd..cbd982ed0bca 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -368,9 +368,12 @@ pub trait Provider: Send + Sync { // Internal implementation of complete, used by complete_fast and complete // Providers should override this to implement their actual completion logic + // + /// # Parameters + /// - `session_id`: Use `None` only for configuration or pre-session tasks. 
async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -386,7 +389,7 @@ pub trait Provider: Send + Sync { tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { let model_config = self.get_model_config(); - self.complete_with_model(session_id, &model_config, system, messages, tools) + self.complete_with_model(Some(session_id), &model_config, system, messages, tools) .await } @@ -402,7 +405,7 @@ pub trait Provider: Send + Sync { let fast_config = model_config.use_fast_model(); match self - .complete_with_model(session_id, &fast_config, system, messages, tools) + .complete_with_model(Some(session_id), &fast_config, system, messages, tools) .await { Ok(result) => Ok(result), @@ -414,8 +417,14 @@ pub trait Provider: Send + Sync { e, model_config.model_name ); - self.complete_with_model(session_id, &model_config, system, messages, tools) - .await + self.complete_with_model( + Some(session_id), + &model_config, + system, + messages, + tools, + ) + .await } else { Err(e) } @@ -430,19 +439,13 @@ pub trait Provider: Send + Sync { RetryConfig::default() } - async fn fetch_supported_models( - &self, - _session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { Ok(None) } /// Fetch models filtered by canonical registry and usability - async fn fetch_recommended_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { - let all_models = match self.fetch_supported_models(session_id).await? { + async fn fetch_recommended_models(&self) -> Result>, ProviderError> { + let all_models = match self.fetch_supported_models().await? 
{ Some(models) => models, None => return Ok(None), }; @@ -490,7 +493,7 @@ pub trait Provider: Send + Sync { false } - async fn supports_cache_control(&self, _session_id: &str) -> bool { + async fn supports_cache_control(&self) -> bool { false } diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 174a534edbbb..84d7474310c2 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -11,6 +11,7 @@ use async_trait::async_trait; use aws_sdk_bedrockruntime::config::ProvideCredentials; use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; +use reqwest::header::HeaderValue; use rmcp::model::Tool; use serde_json::Value; @@ -18,6 +19,7 @@ use serde_json::Value; use super::formats::bedrock::{ from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config, }; +use crate::session_context::SESSION_ID_HEADER; pub const BEDROCK_DOC_LINK: &str = "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html"; @@ -130,6 +132,7 @@ impl BedrockProvider { async fn converse( &self, + session_id: Option<&str>, system: &str, messages: &[Message], tools: &[Tool], @@ -153,6 +156,17 @@ impl BedrockProvider { request = request.tool_config(to_bedrock_tool_config(tools)?); } + let mut request = request.customize(); + + if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { + let session_id = session_id.to_string(); + request = request.mutate_request(move |req| { + if let Ok(value) = HeaderValue::from_str(&session_id) { + req.headers_mut().insert(SESSION_ID_HEADER, value); + } + }); + } + let response = request .send() .await @@ -227,7 +241,7 @@ impl Provider for BedrockProvider { )] async fn complete_with_model( &self, - _session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -236,7 +250,7 @@ impl Provider for BedrockProvider { let model_name = 
model_config.model_name.clone(); let (bedrock_message, bedrock_usage) = self - .with_retry(|| self.converse(system, messages, tools)) + .with_retry(|| self.converse(session_id, system, messages, tools)) .await?; let usage = bedrock_usage diff --git a/crates/goose/src/providers/canonical/build_canonical_models.rs b/crates/goose/src/providers/canonical/build_canonical_models.rs index e8f0dcaf16e6..c3a6cb1db783 100644 --- a/crates/goose/src/providers/canonical/build_canonical_models.rs +++ b/crates/goose/src/providers/canonical/build_canonical_models.rs @@ -550,9 +550,7 @@ async fn check_provider( } }; - // Provider probe runs outside any user session; use an ephemeral id. - let session_id = uuid::Uuid::new_v4().to_string(); - let fetched_models = match provider.fetch_supported_models(&session_id).await { + let fetched_models = match provider.fetch_supported_models().await { Ok(Some(models)) => { println!(" ✓ Fetched {} models", models.len()); models diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs index 974373fdc62b..8a6060531317 100644 --- a/crates/goose/src/providers/chatgpt_codex.rs +++ b/crates/goose/src/providers/chatgpt_codex.rs @@ -7,6 +7,7 @@ use crate::providers::errors::ProviderError; use crate::providers::formats::openai_responses::responses_api_to_streaming_message; use crate::providers::retry::ProviderRetry; use crate::providers::utils::handle_status_openai_compat; +use crate::session_context::SESSION_ID_HEADER; use anyhow::{anyhow, Result}; use async_stream::try_stream; use async_trait::async_trait; @@ -16,6 +17,7 @@ use chrono::{DateTime, Utc}; use futures::{StreamExt, TryStreamExt}; use jsonwebtoken::jwk::JwkSet; use jsonwebtoken::{decode, decode_header, DecodingKey, Validation}; +use reqwest::header::{HeaderName, HeaderValue}; use rmcp::model::{RawContent, Role, Tool}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -789,7 +791,11 @@ impl ChatGptCodexProvider { }) } - async 
fn post_streaming(&self, payload: &Value) -> Result { + async fn post_streaming( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let token_data = self .auth_provider .get_valid_token() @@ -805,6 +811,14 @@ impl ChatGptCodexProvider { ); } + if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { + headers.insert( + HeaderName::from_static(SESSION_ID_HEADER), + HeaderValue::from_str(session_id) + .map_err(|e| ProviderError::ExecutionError(e.to_string()))?, + ); + } + let client = reqwest::Client::new(); let response = client .post(format!("{}/responses", CODEX_API_ENDPOINT)) @@ -856,7 +870,7 @@ impl Provider for ChatGptCodexProvider { )] async fn complete_with_model( &self, - _session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -870,7 +884,7 @@ impl Provider for ChatGptCodexProvider { let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post_streaming(&payload_clone).await + self.post_streaming(session_id, &payload_clone).await }) .await?; @@ -914,7 +928,7 @@ impl Provider for ChatGptCodexProvider { async fn stream( &self, - _session_id: &str, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], @@ -926,7 +940,7 @@ impl Provider for ChatGptCodexProvider { let response = self .with_retry(|| async { let payload_clone = payload.clone(); - self.post_streaming(&payload_clone).await + self.post_streaming(Some(session_id), &payload_clone).await }) .await?; @@ -953,10 +967,7 @@ impl Provider for ChatGptCodexProvider { Ok(()) } - async fn fetch_supported_models( - &self, - _session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { Ok(Some( CHATGPT_CODEX_KNOWN_MODELS .iter() diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 070fee48bd7c..594c0bcf1855 100644 --- a/crates/goose/src/providers/claude_code.rs +++ 
b/crates/goose/src/providers/claude_code.rs @@ -417,7 +417,7 @@ impl Provider for ClaudeCodeProvider { )] async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, // create_session == YYYYMMDD_N, but --session-id requires a UUID model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index f94c1484e18d..87ccf8f2148e 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -507,7 +507,7 @@ impl Provider for CodexProvider { )] async fn complete_with_model( &self, - session_id: &str, + _session_id: Option<&str>, // CLI has no external session-id flag to propagate. model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index 3f7af7002280..10193aa0c59f 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -352,7 +352,7 @@ impl Provider for CursorAgentProvider { )] async fn complete_with_model( &self, - session_id: &str, + _session_id: Option<&str>, // CLI has no external session-id flag to propagate. 
model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index ea9d6fbbeabe..7139a2024b62 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -206,7 +206,7 @@ impl DatabricksProvider { async fn post( &self, - session_id: &str, + session_id: Option<&str>, payload: Value, model_name: Option<&str>, ) -> Result { @@ -257,7 +257,7 @@ impl Provider for DatabricksProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -314,7 +314,7 @@ impl Provider for DatabricksProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, &path, &payload) + .response_post(Some(session_id), &path, &payload) .await?; if !resp.status().is_success() { let status = resp.status(); @@ -352,13 +352,11 @@ impl Provider for DatabricksProvider { .map_err(|e| ProviderError::ExecutionError(e.to_string())) } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let response = match self .api_client - .response_get(session_id, "api/2.0/serving-endpoints") + .request(None, "api/2.0/serving-endpoints") + .response_get() .await { Ok(resp) => resp, @@ -434,7 +432,7 @@ impl EmbeddingCapable for DatabricksProvider { }); let response = self - .with_retry(|| self.post(session_id, request.clone(), None)) + .with_retry(|| self.post(Some(session_id), request.clone(), None)) .await?; let embeddings = response["data"] diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 5e567c6e2043..fafda26f180f 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -25,6 +25,7 @@ use crate::providers::formats::gcpvertexai::{ use 
crate::providers::gcpauth::GcpAuth; use crate::providers::retry::RetryConfig; use crate::providers::utils::RequestLog; +use crate::session_context::SESSION_ID_HEADER; use rmcp::model::Tool; /// Base URL for GCP Vertex AI documentation @@ -262,6 +263,7 @@ impl GcpVertexAIProvider { async fn send_request_with_retry( &self, + session_id: Option<&str>, url: Url, payload: &Value, ) -> Result { @@ -285,11 +287,17 @@ impl GcpVertexAIProvider { .await .map_err(|e| ProviderError::Authentication(e.to_string()))?; - let response = self + let mut request = self .client .post(url.clone()) .json(payload) - .header("Authorization", auth_header) + .header("Authorization", auth_header); + + if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { + request = request.header(SESSION_ID_HEADER, session_id); + } + + let response = request .send() .await .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; @@ -348,6 +356,7 @@ impl GcpVertexAIProvider { async fn post_with_location( &self, + session_id: Option<&str>, payload: &Value, context: &RequestContext, location: &str, @@ -356,7 +365,9 @@ impl GcpVertexAIProvider { .build_request_url(context.provider(), location, false) .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - let response = self.send_request_with_retry(url, payload).await?; + let response = self + .send_request_with_retry(session_id, url, payload) + .await?; response .json::() @@ -366,11 +377,12 @@ impl GcpVertexAIProvider { async fn post( &self, + session_id: Option<&str>, payload: &Value, context: &RequestContext, ) -> Result { let result = self - .post_with_location(payload, context, &self.location) + .post_with_location(session_id, payload, context, &self.location) .await; if self.location == context.model.known_location().to_string() || result.is_ok() { @@ -387,7 +399,7 @@ impl GcpVertexAIProvider { "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}" ); - self.post_with_location(payload, 
context, &known_location) + self.post_with_location(session_id, payload, context, &known_location) .await } _ => result, @@ -396,6 +408,7 @@ impl GcpVertexAIProvider { async fn post_stream_with_location( &self, + session_id: Option<&str>, payload: &Value, context: &RequestContext, location: &str, @@ -404,16 +417,17 @@ impl GcpVertexAIProvider { .build_request_url(context.provider(), location, true) .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - self.send_request_with_retry(url, payload).await + self.send_request_with_retry(session_id, url, payload).await } async fn post_stream( &self, + session_id: Option<&str>, payload: &Value, context: &RequestContext, ) -> Result { let result = self - .post_stream_with_location(payload, context, &self.location) + .post_stream_with_location(session_id, payload, context, &self.location) .await; if self.location == context.model.known_location().to_string() || result.is_ok() { @@ -430,7 +444,7 @@ impl GcpVertexAIProvider { "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}" ); - self.post_stream_with_location(payload, context, &known_location) + self.post_stream_with_location(session_id, payload, context, &known_location) .await } _ => result, @@ -593,7 +607,7 @@ impl Provider for GcpVertexAIProvider { )] async fn complete_with_model( &self, - _session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -603,7 +617,7 @@ impl Provider for GcpVertexAIProvider { let (request, context) = create_request(model_config, system, messages, tools)?; // Send request and process response - let response = self.post(&request, &context).await?; + let response = self.post(session_id, &request, &context).await?; let usage = get_usage(&response, &context)?; let mut log = RequestLog::start(model_config, &request)?; @@ -627,7 +641,7 @@ impl Provider for GcpVertexAIProvider { async fn stream( &self, - _session_id: &str, + session_id: &str, 
system: &str, messages: &[Message], tools: &[Tool], @@ -644,7 +658,7 @@ impl Provider for GcpVertexAIProvider { let mut log = RequestLog::start(&model_config, &request)?; let response = self - .post_stream(&request, &context) + .post_stream(Some(session_id), &request, &context) .await .inspect_err(|e| { let _ = log.error(e); @@ -672,10 +686,7 @@ impl Provider for GcpVertexAIProvider { })) } - async fn fetch_supported_models( - &self, - _session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let models: Vec = KNOWN_MODELS.iter().map(|s| s.to_string()).collect(); let filtered = self.filter_by_org_policy(models).await; Ok(Some(filtered)) diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index c13302ca28e9..dbab2d201f80 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -264,7 +264,7 @@ impl Provider for GeminiCliProvider { )] async fn complete_with_model( &self, - session_id: &str, + _session_id: Option<&str>, // CLI has no external session-id flag to propagate. 
_model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 8f90d0458bb2..7b2d702a4389 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -169,7 +169,11 @@ impl GithubCopilotProvider { }) } - async fn post(&self, session_id: &str, payload: &mut Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &mut Value, + ) -> Result { let (endpoint, token) = self.get_api_info().await?; let auth = AuthMethod::BearerToken(token); let mut headers = self.get_github_headers(); @@ -411,7 +415,7 @@ impl Provider for GithubCopilotProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -469,7 +473,7 @@ impl Provider for GithubCopilotProvider { let response = self .with_retry(|| async { let mut payload_clone = payload.clone(); - let resp = self.post(session_id, &mut payload_clone).await?; + let resp = self.post(Some(session_id), &mut payload_clone).await?; handle_status_openai_compat(resp).await }) .await @@ -480,10 +484,7 @@ impl Provider for GithubCopilotProvider { stream_openai_compat(response, log) } - async fn fetch_supported_models( - &self, - _session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let (endpoint, token) = self.get_api_info().await?; let url = format!("{}/models", endpoint); diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index b091cf2dd2ab..7f560412c39c 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -93,7 +93,7 @@ impl GoogleProvider { async fn post( &self, - session_id: &str, + session_id: Option<&str>, model_name: &str, payload: &Value, ) -> Result { @@ -107,7 +107,7 @@ impl GoogleProvider { async fn 
post_stream( &self, - session_id: &str, + session_id: Option<&str>, model_name: &str, payload: &Value, ) -> Result { @@ -151,7 +151,7 @@ impl Provider for GoogleProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -178,13 +178,11 @@ impl Provider for GoogleProvider { Ok((message, provider_usage)) } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let response = self .api_client - .response_get(session_id, "v1beta/models") + .request(None, "v1beta/models") + .response_get() .await?; let json: serde_json::Value = response.json().await?; let arr = match json.get("models").and_then(|v| v.as_array()) { @@ -216,7 +214,7 @@ impl Provider for GoogleProvider { let response = self .with_retry(|| async { - self.post_stream(session_id, &self.model.model_name, &payload) + self.post_stream(Some(session_id), &self.model.model_name, &payload) .await }) .await diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 15b98ef48922..02a5784e85c4 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -342,7 +342,7 @@ impl Provider for LeadWorkerProvider { async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, _model_config: &ModelConfig, system: &str, messages: &[Message], @@ -393,7 +393,10 @@ impl Provider for LeadWorkerProvider { } // Make the completion request - let result = provider.complete(session_id, system, messages, tools).await; + let model_config = provider.get_model_config(); + let result = provider + .complete_with_model(session_id, &model_config, system, messages, tools) + .await; // For technical failures, try with default model (lead provider) instead let final_result = match &result { @@ -401,9 +404,10 @@ impl 
Provider for LeadWorkerProvider { tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type); // Try with lead provider as the default/fallback for technical failures + let model_config = self.lead_provider.get_model_config(); let default_result = self .lead_provider - .complete(session_id, system, messages, tools) + .complete_with_model(session_id, &model_config, system, messages, tools) .await; match &default_result { @@ -428,19 +432,10 @@ impl Provider for LeadWorkerProvider { final_result } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { // Combine models from both providers - let lead_models = self - .lead_provider - .fetch_supported_models(session_id) - .await?; - let worker_models = self - .worker_provider - .fetch_supported_models(session_id) - .await?; + let lead_models = self.lead_provider.fetch_supported_models().await?; + let worker_models = self.worker_provider.fetch_supported_models().await?; match (lead_models, worker_models) { (Some(lead), Some(worker)) => { @@ -517,7 +512,7 @@ mod tests { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, _messages: &[Message], @@ -703,7 +698,7 @@ mod tests { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, _messages: &[Message], diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 0648ee9ce216..1f31b4cecf8c 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -73,8 +73,12 @@ impl LiteLLMProvider { }) } - async fn fetch_models(&self, session: &str) -> Result, ProviderError> { - let response = self.api_client.response_get(session, "model/info").await?; + async fn fetch_models(&self) -> Result, 
ProviderError> { + let response = self + .api_client + .request(None, "model/info") + .response_get() + .await?; if !response.status().is_success() { return Err(ProviderError::RequestFailed(format!( @@ -112,7 +116,11 @@ impl LiteLLMProvider { Ok(models) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, &self.base_path, payload) @@ -168,7 +176,7 @@ impl Provider for LiteLLMProvider { #[tracing::instrument(skip_all, name = "provider_complete")] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -183,7 +191,7 @@ impl Provider for LiteLLMProvider { false, )?; - if self.supports_cache_control(session_id).await { + if self.supports_cache_control().await { payload = update_request_for_cache_control(&payload); } @@ -206,8 +214,8 @@ impl Provider for LiteLLMProvider { true } - async fn supports_cache_control(&self, session_id: &str) -> bool { - if let Ok(models) = self.fetch_models(session_id).await { + async fn supports_cache_control(&self) -> bool { + if let Ok(models) = self.fetch_models().await { if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) { return model_info.supports_cache_control.unwrap_or(false); } @@ -216,11 +224,8 @@ impl Provider for LiteLLMProvider { self.model.model_name.to_lowercase().contains("claude") } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { - match self.fetch_models(session_id).await { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + match self.fetch_models().await { Ok(models) => { let model_names: Vec = models.into_iter().map(|m| m.name).collect(); Ok(Some(model_names)) @@ -251,7 +256,7 @@ impl EmbeddingCapable for LiteLLMProvider { let response = self .api_client - 
.response_post(session_id, "v1/embeddings", &payload) + .response_post(Some(session_id), "v1/embeddings", &payload) .await?; let response_text = response.text().await?; let response_json: Value = serde_json::from_str(&response_text)?; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 3d3746bedfe9..9aa3a65c27c9 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -119,7 +119,11 @@ impl OllamaProvider { }) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, "v1/chat/completions", payload) @@ -173,7 +177,7 @@ impl Provider for OllamaProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -273,7 +277,7 @@ impl Provider for OllamaProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, "v1/chat/completions", &payload) + .response_post(Some(session_id), "v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) @@ -284,13 +288,11 @@ impl Provider for OllamaProvider { stream_openai_compat(response, log) } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let response = self .api_client - .response_get(session_id, "api/tags") + .request(None, "api/tags") + .response_get() .await .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 20a069fb9c2f..560375d20824 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -192,7 +192,11 @@ impl OpenAiProvider { 
model_name.starts_with("gpt-5-codex") || model_name.starts_with("gpt-5.1-codex") } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, &self.base_path, payload) @@ -202,7 +206,7 @@ impl OpenAiProvider { async fn post_responses( &self, - session_id: &str, + session_id: Option<&str>, payload: &Value, ) -> Result { let response = self @@ -253,7 +257,7 @@ impl Provider for OpenAiProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -323,14 +327,12 @@ impl Provider for OpenAiProvider { } } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { let models_path = self.base_path.replace("v1/chat/completions", "v1/models"); let response = self .api_client - .response_get(session_id, &models_path) + .request(None, &models_path) + .response_get() .await?; let json = handle_response_openai_compat(response).await?; if let Some(err_obj) = json.get("error") { @@ -388,7 +390,7 @@ impl Provider for OpenAiProvider { let payload_clone = payload.clone(); let resp = self .api_client - .response_post(session_id, "v1/responses", &payload_clone) + .response_post(Some(session_id), "v1/responses", &payload_clone) .await?; handle_status_openai_compat(resp).await }) @@ -426,7 +428,7 @@ impl Provider for OpenAiProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, &self.base_path, &payload) + .response_post(Some(session_id), &self.base_path, &payload) .await?; handle_status_openai_compat(resp).await }) @@ -479,7 +481,7 @@ impl EmbeddingCapable for OpenAiProvider { let request_value = serde_json::to_value(request_clone) .map_err(|e| 
ProviderError::ExecutionError(e.to_string()))?; self.api_client - .api_post(session_id, "v1/embeddings", &request_value) + .api_post(Some(session_id), "v1/embeddings", &request_value) .await .map_err(|e| ProviderError::ExecutionError(e.to_string())) }) diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 344206e315d7..a9276f4f87aa 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -69,7 +69,11 @@ impl OpenRouterProvider { }) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, "api/v1/chat/completions", payload) @@ -189,7 +193,7 @@ fn is_gemini_model(model_name: &str) -> bool { async fn create_request_based_on_model( provider: &OpenRouterProvider, - session_id: &str, + session_id: Option<&str>, system: &str, messages: &[Message], tools: &[Tool], @@ -203,7 +207,13 @@ async fn create_request_based_on_model( false, )?; - if provider.supports_cache_control(session_id).await { + if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { + if let Some(obj) = payload.as_object_mut() { + obj.insert("user".to_string(), Value::String(session_id.to_string())); + } + } + + if provider.supports_cache_control().await { payload = update_request_for_anthropic(&payload); } @@ -254,7 +264,7 @@ impl Provider for OpenRouterProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -287,15 +297,13 @@ impl Provider for OpenRouterProvider { } /// Fetch supported models from OpenRouter API (only models with tool support) - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { // Handle request failures 
gracefully // If the request fails, fall back to manual entry let response = match self .api_client - .response_get(session_id, "api/v1/models") + .request(None, "api/v1/models") + .response_get() .await { Ok(response) => response, @@ -370,7 +378,7 @@ impl Provider for OpenRouterProvider { Ok(Some(models)) } - async fn supports_cache_control(&self, _session_id: &str) -> bool { + async fn supports_cache_control(&self) -> bool { self.model .model_name .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) @@ -396,7 +404,7 @@ impl Provider for OpenRouterProvider { true, )?; - if self.supports_cache_control(session_id).await { + if self.supports_cache_control().await { payload = update_request_for_anthropic(&payload); } @@ -414,7 +422,7 @@ impl Provider for OpenRouterProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, "api/v1/chat/completions", &payload) + .response_post(Some(session_id), "api/v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 6cbabdbbdd63..0517836ef117 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -14,6 +14,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::RequestLog; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; +use crate::session_context::SESSION_ID_HEADER; use chrono::Utc; @@ -154,17 +155,27 @@ impl SageMakerTgiProvider { Ok(request) } - async fn invoke_endpoint(&self, payload: Value) -> Result { + async fn invoke_endpoint( + &self, + session_id: Option<&str>, + payload: Value, + ) -> Result { let body = serde_json::to_string(&payload).map_err(|e| { ProviderError::RequestFailed(format!("Failed to serialize request: {}", e)) })?; - let response = self + let mut request = self .sagemaker_client .invoke_endpoint()
.endpoint_name(&self.endpoint_name) .content_type("application/json") - .body(body.into_bytes().into()) + .body(body.into_bytes().into()); + + if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { + request = request.custom_attributes(format!("{SESSION_ID_HEADER}={session_id}")); + } + + let response = request .send() .await .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?; @@ -289,7 +300,7 @@ impl Provider for SageMakerTgiProvider { )] async fn complete_with_model( &self, - _session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -302,7 +313,7 @@ impl Provider for SageMakerTgiProvider { })?; let response = self - .with_retry(|| self.invoke_endpoint(request_payload.clone())) + .with_retry(|| self.invoke_endpoint(session_id, request_payload.clone())) .await?; let message = self.parse_tgi_response(response)?; diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 0e0e452b0ca8..88db4cd213d3 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -107,7 +107,11 @@ impl SnowflakeProvider { }) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, "api/v2/cortex/inference:complete", payload) @@ -319,7 +323,7 @@ impl Provider for SnowflakeProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index 8415cc4d0a71..fa249ea0ec9b 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -121,7 +121,7 @@ impl Provider for TestProvider { async fn complete_with_model( 
&self, - session_id: &str, + session_id: Option<&str>, _model_config: &ModelConfig, system: &str, messages: &[Message], @@ -130,7 +130,10 @@ impl Provider for TestProvider { let hash = Self::hash_input(messages); if let Some(inner) = &self.inner { - let (message, usage) = inner.complete(session_id, system, messages, tools).await?; + let model_config = inner.get_model_config(); + let (message, usage) = inner + .complete_with_model(session_id, &model_config, system, messages, tools) + .await?; let record = TestRecord { input: TestInput { @@ -203,7 +206,7 @@ mod tests { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, _messages: &[Message], diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index 73d753d5e72b..4124445dee64 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -63,7 +63,11 @@ impl TetrateProvider { }) } - async fn post(&self, session_id: &str, payload: &Value) -> Result { + async fn post( + &self, + session_id: Option<&str>, + payload: &Value, + ) -> Result { let response = self .api_client .response_post(session_id, "v1/chat/completions", payload) @@ -158,7 +162,7 @@ impl Provider for TetrateProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -215,7 +219,7 @@ impl Provider for TetrateProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, "v1/chat/completions", &payload) + .response_post(Some(session_id), "v1/chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) @@ -228,12 +232,14 @@ impl Provider for TetrateProvider { } /// Fetch supported models from Tetrate Agent Router Service API (only models with tool support) - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result>, ProviderError> { + 
async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> { // Use the existing api_client which already has authentication configured - let response = match self.api_client.response_get(session_id, "v1/models").await { + let response = match self + .api_client + .request(None, "v1/models") + .response_get() + .await + { Ok(response) => response, Err(e) => { tracing::warn!("Failed to fetch models from Tetrate Agent Router Service API: {}, falling back to manual model entry", e); diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 14229948b09b..935590d6da1c 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -115,7 +115,7 @@ impl VeniceProvider { async fn post( &self, - session_id: &str, + session_id: Option<&str>, path: &str, payload: &Value, ) -> Result<Value, ProviderError> { @@ -229,13 +229,11 @@ impl Provider for VeniceProvider { self.model.clone() } - async fn fetch_supported_models( - &self, - session_id: &str, - ) -> Result<Option<Vec<String>>, ProviderError> { + async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> { let response = self .api_client - .response_get(session_id, &self.models_path) + .request(None, &self.models_path) + .response_get() .await?; let json: serde_json::Value = response.json().await?; @@ -265,7 +263,7 @@ impl Provider for VeniceProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 38d980b5fb14..6b53c8efecc2 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -69,7 +69,7 @@ impl XaiProvider { }) } - async fn post(&self, session_id: &str, payload: Value) -> Result<Value, ProviderError> { + async fn post(&self, session_id: Option<&str>, payload: Value) -> Result<Value, ProviderError> { let response = self .api_client .response_post(session_id, "chat/completions", &payload) @@ -110,7 
+110,7 @@ impl Provider for XaiProvider { )] async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -165,7 +165,7 @@ impl Provider for XaiProvider { .with_retry(|| async { let resp = self .api_client - .response_post(session_id, "chat/completions", &payload) + .response_post(Some(session_id), "chat/completions", &payload) .await?; handle_status_openai_compat(resp).await }) diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index b7d292d69d22..ab7374fcc96a 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -378,12 +378,14 @@ mod tests { async fn complete_with_model( &self, - session_id: &str, + session_id: Option<&str>, _model_config: &ModelConfig, system_prompt: &str, messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { + // Test-only: coerce missing session_id to empty so complete() can be reused. 
+ let session_id = session_id.unwrap_or(""); self.complete(session_id, system_prompt, messages, tools) .await } diff --git a/crates/goose/tests/compaction.rs b/crates/goose/tests/compaction.rs index 50acb2e660e9..4081db928ff7 100644 --- a/crates/goose/tests/compaction.rs +++ b/crates/goose/tests/compaction.rs @@ -94,9 +94,10 @@ impl MockCompactionProvider { #[async_trait] impl Provider for MockCompactionProvider { - async fn complete( + async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, + _model_config: &ModelConfig, system_prompt: &str, messages: &[Message], _tools: &[Tool], @@ -165,30 +166,6 @@ impl Provider for MockCompactionProvider { Ok((message, usage)) } - async fn complete_with_model( - &self, - session_id: &str, - _model_config: &ModelConfig, - system_prompt: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - self.complete(session_id, system_prompt, messages, tools) - .await - } - - async fn complete_fast( - &self, - session_id: &str, - system_prompt: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - // Compaction uses complete_fast, so delegate to complete - self.complete(session_id, system_prompt, messages, tools) - .await - } - fn get_model_config(&self) -> ModelConfig { ModelConfig::new("mock-model").unwrap() } diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 28d43ece795e..6ec9eea5f9fa 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -59,7 +59,7 @@ impl Provider for MockProvider { async fn complete_with_model( &self, - _session_id: &str, + _session_id: Option<&str>, _model_config: &ModelConfig, _system: &str, _messages: &[Message],