diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 36fdfd4b4381..82b7b1a310b0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -873,14 +873,17 @@ impl Agent { let reply_span = tracing::Span::current(); self.reset_retry_attempts().await; + let working_dir = session.working_dir.clone(); let provider = self.provider().await?; let session_id = session_config.id.clone(); - let working_dir = session.working_dir.clone(); - tokio::spawn(async move { - if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await { - warn!("Failed to generate session description: {}", e); - } - }); + let naming_handle = tokio::spawn(crate::session_context::with_session_id( + Some(session_id.clone()), + async move { + if let Err(e) = SessionManager::maybe_update_name(&session_id, provider).await { + warn!("Failed to generate session description: {}", e); + } + }, + )); Ok(Box::pin(async_stream::try_stream! { let _ = reply_span.enter(); @@ -1220,6 +1223,8 @@ impl Agent { tokio::task::yield_now().await; } + + let _ = naming_handle.await; })) } @@ -1528,7 +1533,13 @@ impl Agent { #[cfg(test)] mod tests { use super::*; + use crate::model::ModelConfig; + use crate::providers::base::ProviderUsage; use crate::recipe::Response; + use crate::session::session_manager::SessionType; + use crate::session::SessionManager; + use async_trait::async_trait; + use test_case::test_case; #[tokio::test] async fn test_add_final_output_tool() -> Result<()> { @@ -1587,4 +1598,144 @@ mod tests { Ok(()) } + + enum NamingBehavior { + Success, + Error, + Panic, + } + + #[derive(Clone)] + struct MockNamingProvider { + model_config: ModelConfig, + behavior: Arc>, + captured_session_id: Arc>>, + } + + impl MockNamingProvider { + fn new(behavior: NamingBehavior) -> Self { + Self { + model_config: ModelConfig::new_or_fail("test-model"), + behavior: Arc::new(std::sync::Mutex::new(behavior)), + captured_session_id: Arc::new(std::sync::Mutex::new(None)), + } + } + + fn get_captured_session_id(&self) -> Option { + self.captured_session_id.lock().unwrap().clone() + } + } + + #[async_trait] + impl Provider for MockNamingProvider { + fn metadata() -> crate::providers::base::ProviderMetadata { + crate::providers::base::ProviderMetadata::empty() + } + + fn get_name(&self) -> &str { + "mock" + } + + async fn complete_with_model( + &self, + _model_config: &ModelConfig, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { + Ok(( + Message::assistant().with_text("Response"), + ProviderUsage::new("mock".to_string(), Default::default()), + )) + } + + async fn generate_session_name( + &self, + _messages: &Conversation, + ) -> Result { + *self.captured_session_id.lock().unwrap() = + crate::session_context::current_session_id(); + + let behavior = self.behavior.lock().unwrap(); + match *behavior { + NamingBehavior::Success => Ok("Generated Name".to_string()), + NamingBehavior::Error => { + Err(ProviderError::RequestFailed("naming failed".to_string())) + } + NamingBehavior::Panic => panic!("naming panicked"), + } + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + } + + // Verifies session ID is visible in generate_session_name when user hasn't provided a name. + // When user provides name, maybe_update_name early-returns and session ID isn't captured. + #[test_case(NamingBehavior::Success, None, "Generated Name", true)] + #[test_case(NamingBehavior::Error, None, "initial", true)] + #[test_case(NamingBehavior::Panic, None, "initial", true)] + #[test_case( + NamingBehavior::Success, + Some("my-custom-name"), + "my-custom-name", + false + )] + #[tokio::test] + async fn test_session_naming( + behavior: NamingBehavior, + user_provided_name: Option<&str>, + expected_name: &str, + should_capture_session_id: bool, + ) { + let provider = Arc::new(MockNamingProvider::new(behavior)); + let agent = Agent::new(); + agent.update_provider(provider.clone()).await.unwrap(); + + let session = SessionManager::create_session( + std::env::current_dir().unwrap(), + user_provided_name.unwrap_or("initial").to_string(), + SessionType::User, + ) + .await + .unwrap(); + + if let Some(name) = user_provided_name { + SessionManager::update_session(&session.id) + .user_provided_name(name) + .apply() + .await + .unwrap(); + } + + let stream = agent + .reply( + Message::user().with_text("test"), + SessionConfig { + id: session.id.clone(), + schedule_id: None, + max_turns: Some(1), + retry_config: None, + }, + None, + ) + .await + .unwrap(); + tokio::pin!(stream); + while stream.next().await.is_some() {} + + let session = SessionManager::get_session(&session.id, false) + .await + .unwrap(); + assert_eq!(session.name, expected_name); + + if should_capture_session_id { + assert_eq!(provider.get_captured_session_id(), Some(session.id.clone())); + } else { + assert_eq!(provider.get_captured_session_id(), None); + } + + SessionManager::delete_session(&session.id).await.unwrap(); + } } diff --git a/crates/goose/tests/session_id_propagation_test.rs b/crates/goose/tests/session_id_propagation_test.rs index 10a83a6261f3..724be463a563 100644 --- a/crates/goose/tests/session_id_propagation_test.rs +++ b/crates/goose/tests/session_id_propagation_test.rs @@ -1,8 +1,11 @@ +use goose::agents::{Agent, SessionConfig}; use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::api_client::{ApiClient, AuthMethod}; use goose::providers::base::Provider; use goose::providers::openai::OpenAiProvider; +use goose::session::session_manager::SessionType; +use goose::session::SessionManager; use goose::session_context; use goose::session_context::SESSION_ID_HEADER; use serde_json::json; @@ -153,3 +156,83 @@ async fn test_different_sessions_have_different_ids() { ] ); } + +#[tokio::test] +async fn test_session_id_propagation_in_rename_task() { + let mock_server = MockServer::start().await; + let capture = HeaderCapture::new(); + let capture_clone = capture.clone(); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(move |req: &Request| { + capture_clone.capture_session_header(req); + ResponseTemplate::new(200).set_body_json(json!({ + "choices": [{ + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Test response", + "role": "assistant" + } + }], + "created": 1755133833, + "id": "chatcmpl-test", + "model": "gpt-5-nano", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 8, + "total_tokens": 18 + } + })) + }) + .mount(&mock_server) + .await; + + let api_client = ApiClient::new( + mock_server.uri(), + AuthMethod::BearerToken("test-key".to_string()), + ) + .unwrap(); + let model = ModelConfig::new_or_fail("gpt-5-nano"); + let provider = Arc::new(OpenAiProvider::new(api_client, model)); + + let agent = Agent::new(); + agent.update_provider(provider).await.unwrap(); + + let session = SessionManager::create_session( + std::env::current_dir().unwrap(), + "initial".to_string(), + SessionType::User, + ) + .await + .unwrap(); + + session_context::with_session_id(Some(session.id.clone()), async { + let stream = agent + .reply( + Message::user().with_text("test"), + SessionConfig { + id: session.id.clone(), + schedule_id: None, + max_turns: Some(1), + retry_config: None, + }, + None, + ) + .await + .unwrap(); + + use futures::StreamExt; + tokio::pin!(stream); + while stream.next().await.is_some() {} + }) + .await; + + let captured = capture.get_captured(); + assert_eq!(captured.len(), 2); + assert_eq!(captured[0], Some(session.id.clone())); + assert_eq!(captured[1], Some(session.id.clone())); + + SessionManager::delete_session(&session.id).await.unwrap(); +}