Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,9 @@ impl Agent {
}
}
Err(e) => {
yield AgentEvent::Message(
Message::assistant().with_text(
format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session")
)
);
yield AgentEvent::Message(Message::assistant().with_text(
format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session")
));
}
}
}))
Expand Down
128 changes: 122 additions & 6 deletions crates/goose/src/agents/mcp_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::agents::types::SharedProvider;
use crate::session_context::SESSION_ID_HEADER;
use rmcp::model::{Content, ErrorCode, JsonObject};
/// MCP client implementation for Goose
use rmcp::{
Expand Down Expand Up @@ -334,7 +335,7 @@ impl McpClientTrait for McpClient {
ClientRequest::ListResourcesRequest(ListResourcesRequest {
params: Some(PaginatedRequestParam { cursor }),
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand All @@ -358,7 +359,7 @@ impl McpClientTrait for McpClient {
uri: uri.to_string(),
},
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand All @@ -380,7 +381,7 @@ impl McpClientTrait for McpClient {
ClientRequest::ListToolsRequest(ListToolsRequest {
params: Some(PaginatedRequestParam { cursor }),
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand All @@ -406,7 +407,7 @@ impl McpClientTrait for McpClient {
arguments,
},
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand All @@ -428,7 +429,7 @@ impl McpClientTrait for McpClient {
ClientRequest::ListPromptsRequest(ListPromptsRequest {
params: Some(PaginatedRequestParam { cursor }),
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand Down Expand Up @@ -458,7 +459,7 @@ impl McpClientTrait for McpClient {
arguments,
},
method: Default::default(),
extensions: Default::default(),
extensions: inject_session_into_extensions(Default::default()),
}),
cancel_token,
)
Expand All @@ -476,3 +477,118 @@ impl McpClientTrait for McpClient {
rx
}
}

/// Replaces session ID, case-insensitively, in Extensions._meta.
fn inject_session_into_extensions(
mut extensions: rmcp::model::Extensions,
) -> rmcp::model::Extensions {
use rmcp::model::Meta;

if let Some(session_id) = crate::session_context::current_session_id() {
let mut meta_map = extensions
.get::<Meta>()
.map(|meta| meta.0.clone())
.unwrap_or_default();

// JsonObject is case-sensitive, so we use retain for case-insensitive removal
meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER));

meta_map.insert(SESSION_ID_HEADER.to_string(), Value::String(session_id));

extensions.insert(Meta(meta_map));
}

extensions
}

#[cfg(test)]
mod tests {
use super::*;
use rmcp::model::Meta;

#[tokio::test]
async fn test_session_id_in_mcp_meta() {
use serde_json::json;

let session_id = "test-session-789";
crate::session_context::with_session_id(Some(session_id.to_string()), async {
let extensions = inject_session_into_extensions(Default::default());
let meta = extensions.get::<Meta>().unwrap();

assert_eq!(
&meta.0,
json!({
SESSION_ID_HEADER: session_id
})
.as_object()
.unwrap()
);
})
.await;
}

#[tokio::test]
async fn test_no_session_id_in_mcp_when_absent() {
let extensions = inject_session_into_extensions(Default::default());
let meta = extensions.get::<Meta>();

assert!(meta.is_none());
}

#[tokio::test]
async fn test_all_mcp_operations_include_session() {
use serde_json::json;

let session_id = "consistent-session-id";
crate::session_context::with_session_id(Some(session_id.to_string()), async {
let ext1 = inject_session_into_extensions(Default::default());
let ext2 = inject_session_into_extensions(Default::default());
let ext3 = inject_session_into_extensions(Default::default());

for ext in [&ext1, &ext2, &ext3] {
assert_eq!(
&ext.get::<Meta>().unwrap().0,
json!({
SESSION_ID_HEADER: session_id
})
.as_object()
.unwrap()
);
}
})
.await;
}

#[tokio::test]
async fn test_session_id_case_insensitive_replacement() {
use rmcp::model::{Extensions, Meta};
use serde_json::{from_value, json};

let session_id = "new-session-id";
crate::session_context::with_session_id(Some(session_id.to_string()), async {
let mut extensions = Extensions::new();
extensions.insert(
from_value::<Meta>(json!({
"GOOSE-SESSION-ID": "old-session-1",
"Goose-Session-Id": "old-session-2",
"other-key": "preserve-me"
}))
.unwrap(),
);

let extensions = inject_session_into_extensions(extensions);
let meta = extensions.get::<Meta>().unwrap();

assert_eq!(
&meta.0,
json!({
SESSION_ID_HEADER: session_id,
"other-key": "preserve-me"
})
.as_object()
.unwrap()
);
})
.await;
}
}
13 changes: 9 additions & 4 deletions crates/goose/src/agents/subagent_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,15 @@ fn get_agent_messages(
} else {
None
};
let mut stream = agent
.reply(conversation.clone(), session_config, None)
.await
.map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?;

let session_id = session_config.as_ref().map(|s| s.id.clone());
let mut stream = crate::session_context::with_session_id(session_id, async {
agent
.reply(conversation.clone(), session_config, None)
.await
})
.await
.map_err(|e| anyhow!("Failed to get reply from agent: {}", e))?;
while let Some(message_result) = stream.next().await {
match message_result {
Ok(AgentEvent::Message(msg)) => conversation.push(msg),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod scheduler_factory;
pub mod scheduler_trait;
pub mod security;
pub mod session;
pub mod session_context;
pub mod token_counter;
pub mod tool_inspection;
pub mod tool_monitor;
Expand Down
57 changes: 57 additions & 0 deletions crates/goose/src/providers/api_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::session_context::SESSION_ID_HEADER;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{
Expand Down Expand Up @@ -369,6 +370,10 @@ impl<'a> ApiRequestBuilder<'a> {
let mut request = request_builder(url, &self.client.client);
request = request.headers(self.headers.clone());

if let Some(session_id) = crate::session_context::current_session_id() {
request = request.header(SESSION_ID_HEADER, session_id);
}

request = match &self.client.auth {
AuthMethod::BearerToken(token) => {
request.header("Authorization", format!("Bearer {}", token))
Expand Down Expand Up @@ -398,3 +403,55 @@ impl fmt::Debug for ApiClient {
.finish_non_exhaustive()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_session_id_header_injection() {
let client = ApiClient::new(
"http://localhost:8080".to_string(),
AuthMethod::BearerToken("test-token".to_string()),
)
.unwrap();

// Execute request within session context
crate::session_context::with_session_id(Some("test-session-456".to_string()), async {
let builder = client.request("/test");
let request = builder
.send_request(|url, client| client.get(url))
.await
.unwrap();

let headers = request.build().unwrap().headers().clone();

assert!(headers.contains_key(SESSION_ID_HEADER));
assert_eq!(
headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap(),
"test-session-456"
);
})
.await;
}

#[tokio::test]
async fn test_no_session_id_header_when_absent() {
let client = ApiClient::new(
"http://localhost:8080".to_string(),
AuthMethod::BearerToken("test-token".to_string()),
)
.unwrap();

// Build a request without session context
let builder = client.request("/test");
let request = builder
.send_request(|url, client| client.get(url))
.await
.unwrap();

let headers = request.build().unwrap().headers().clone();

assert!(!headers.contains_key(SESSION_ID_HEADER));
}
}
2 changes: 1 addition & 1 deletion crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod anthropic;
mod api_client;
pub mod api_client;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed to expose this for tests in order to write them without ENV dependencies

pub mod azure;
pub mod azureauth;
pub mod base;
Expand Down
14 changes: 14 additions & 0 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ impl OpenAiProvider {
})
}

#[doc(hidden)]
pub fn new(api_client: ApiClient, model: ModelConfig) -> Self {
Self {
api_client,
base_path: "v1/chat/completions".to_string(),
organization: None,
project: None,
model,
custom_headers: None,
supports_streaming: true,
name: Self::metadata().name,
}
}

pub fn from_custom_config(
model: ModelConfig,
config: DeclarativeProviderConfig,
Expand Down
10 changes: 7 additions & 3 deletions crates/goose/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1216,9 +1216,13 @@ async fn run_scheduled_job_internal(
retry_config: None,
};

match agent
.reply(conversation.clone(), Some(session_config.clone()), None)
.await
let session_id = Some(session_config.id.clone());
match crate::session_context::with_session_id(session_id, async {
agent
.reply(conversation.clone(), Some(session_config.clone()), None)
.await
})
.await
{
Ok(mut stream) => {
use futures::StreamExt;
Expand Down
Loading