Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-acp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConf
timeout: None,
bundled: Some(false),
available_tools: vec![],
allowed_headers: vec![],
}),
McpServer::Sse(_) => Err("SSE is unsupported, migrate to streamable_http".to_string()),
_ => Err("Unknown MCP server type".to_string()),
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ fn configure_streamable_http_extension() -> anyhow::Result<()> {
timeout: Some(timeout),
bundled: None,
available_tools: Vec::new(),
allowed_headers: Vec::new(),
},
});

Expand Down
288 changes: 197 additions & 91 deletions crates/goose-cli/src/commands/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use axum::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, Request, State,
},
http::{StatusCode, Uri},
http::{HeaderMap, StatusCode, Uri},
middleware::{self, Next},
response::{Html, IntoResponse, Response},
routing::get,
Expand Down Expand Up @@ -41,7 +41,10 @@ enum WebSocketMessage {
Message {
content: String,
session_id: String,
#[serde(default)]
timestamp: i64,
#[serde(default)]
headers: Option<std::collections::HashMap<String, String>>,
},
#[serde(rename = "cancel")]
Cancel { session_id: String },
Expand Down Expand Up @@ -244,7 +247,8 @@ pub async fn handle_web(
let agent = create_agent(&provider_name, &model).await?;

let ws_token = if auth_token.is_none() {
uuid::Uuid::new_v4().to_string()
// uuid::Uuid::new_v4().to_string()
String::new() // Disable WS token for now
} else {
String::new()
};
Expand Down Expand Up @@ -413,6 +417,7 @@ async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(query): Query<WsQuery>,
headers: HeaderMap,
) -> Result<impl IntoResponse, StatusCode> {
if state.auth_token.is_none() {
let provided_token = query.token.as_deref().unwrap_or("");
Expand All @@ -422,18 +427,205 @@ async fn websocket_handler(
}
}

Ok(ws.on_upgrade(|socket| handle_socket(socket, state)))
eprintln!("[WEBSOCKET] WebSocket upgrade request received with {} headers", headers.len());

// Extract headers from HTTP request and convert to HashMap
let mut request_headers = std::collections::HashMap::new();
for (key, value) in headers.iter() {
if let Ok(value_str) = value.to_str() {
let key_str = key.as_str().to_string();
request_headers.insert(key_str, value_str.to_string());
eprintln!("[WEBSOCKET] Header from upgrade request: {} = {}", key.as_str(), value_str);
}
}

// Store headers in a shared location that can be accessed when messages arrive
// We'll store them per connection and associate with session when message arrives
let headers_arc = Arc::new(Mutex::new(request_headers));

Ok(ws.on_upgrade(move |socket| handle_socket(socket, state, headers_arc)))
}

async fn handle_socket(socket: WebSocket, state: AppState) {
async fn handle_socket(
socket: WebSocket,
state: AppState,
request_headers: Arc<Mutex<std::collections::HashMap<String, String>>>,
) {
let (sender, mut receiver) = socket.split();
let sender = Arc::new(Mutex::new(sender));

// Log headers from upgrade request
let headers = request_headers.lock().await;
eprintln!("[WEBSOCKET] Connection established with {} headers from upgrade request", headers.len());
if !headers.is_empty() {
eprintln!("[WEBSOCKET] Headers from upgrade: {:?}", headers);
}
drop(headers);

while let Some(msg) = receiver.next().await {
if let Ok(msg) = msg {
match msg {
Message::Text(text) => {
handle_text_message(&text.to_string(), &sender, &state).await;
let text_str = text.to_string();
eprintln!("[WEBSOCKET] Raw message received: {}", text_str);
match serde_json::from_str::<WebSocketMessage>(&text_str) {
Ok(WebSocketMessage::Message {
content,
session_id,
headers: _json_headers,
..
}) => {
eprintln!("[WEBSOCKET] Parsed message - session_id: {}, has_json_headers: {}",
session_id, _json_headers.is_some());

// Ensure session exists
if goose::session::SessionManager::instance().get_session(&session_id, false).await.is_err() {
eprintln!("[WEBSOCKET] Session {} not found, creating new session", session_id);
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
if let Err(e) = goose::session::SessionManager::instance().create_session_with_id(
&session_id,
cwd,
"Web Session".to_string(),
goose::session::SessionType::User
).await {
error!("Failed to auto-create session {}: {}", session_id, e);
}
}

// Get headers from HTTP upgrade request (preferred) or from JSON message (fallback)
let headers_to_store = {
let upgrade_headers = request_headers.lock().await;
if !upgrade_headers.is_empty() {
eprintln!("[WEBSOCKET] Using headers from HTTP upgrade request: {:?}", upgrade_headers);
Some(upgrade_headers.clone())
} else if let Some(ref json_headers) = _json_headers {
eprintln!("[WEBSOCKET] Using headers from JSON message: {:?}", json_headers);
Some(json_headers.clone())
} else {
None
}
};

if let Some(ref headers) = headers_to_store {
eprintln!("[WEBSOCKET] Received headers for session {}: {:?}", session_id, headers);
tracing::info!("[WEBSOCKET] Received headers for session {}: {:?}", session_id, headers);
if let Err(e) = goose::session::SessionManager::instance().update(&session_id)
.extension_data({
let mut ext_data = goose::session::SessionManager::instance().get_session(&session_id, false)
.await
.map(|s| s.extension_data)
.unwrap_or_default();
ext_data.set_extension_state(
"websocket_headers",
"v0",
serde_json::to_value(headers).unwrap_or_default(),
);
eprintln!("[WEBSOCKET] Stored headers in session extension_data: {:?}", headers);
tracing::info!("[WEBSOCKET] Stored headers in session extension_data: {:?}", headers);
ext_data
})
.apply()
.await
{
eprintln!("[WEBSOCKET] ERROR: Failed to store headers in session: {}", e);
error!("Failed to store headers in session: {}", e);
} else {
eprintln!("[WEBSOCKET] Successfully stored headers in session {}", session_id);
tracing::info!("[WEBSOCKET] Successfully stored headers in session {}", session_id);
}
} else {
eprintln!("[WEBSOCKET] No headers in websocket message for session {}", session_id);
tracing::debug!("[WEBSOCKET] No headers in websocket message for session {}", session_id);
}

let sender_clone = sender.clone();
let agent = state.agent.clone();
let session_id_clone = session_id.clone();

let task_handle = tokio::spawn(async move {
let result = process_message_streaming(
&agent,
session_id_clone,
content,
sender_clone,
)
.await;

if let Err(e) = result {
error!("Error processing message: {}", e);
}
});

{
let mut cancellations = state.cancellations.write().await;
cancellations
.insert(session_id.clone(), task_handle.abort_handle());
}

// Handle task completion and cleanup
let sender_for_abort = sender.clone();
let session_id_for_cleanup = session_id.clone();
let cancellations_for_cleanup = state.cancellations.clone();

tokio::spawn(async move {
match task_handle.await {
Ok(_) => {}
Err(e) if e.is_cancelled() => {
let mut sender = sender_for_abort.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(
&WebSocketMessage::Cancelled {
message: "Operation cancelled by user"
.to_string(),
},
)
.unwrap()
.into(),
))
.await;
}
Err(e) => {
error!("Task error: {}", e);
}
}

let mut cancellations = cancellations_for_cleanup.write().await;
cancellations.remove(&session_id_for_cleanup);
});
}
Ok(WebSocketMessage::Cancel { session_id }) => {
// Cancel the active operation for this session
let abort_handle = {
let mut cancellations = state.cancellations.write().await;
cancellations.remove(&session_id)
};

if let Some(handle) = abort_handle {
handle.abort();

// Send cancellation confirmation
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Cancelled {
message: "Operation cancelled".to_string(),
})
.unwrap()
.into(),
))
.await;
}
}
Ok(_) => {
// Ignore other message types
}
Err(e) => {
eprintln!("[WEBSOCKET] ERROR: Failed to parse WebSocket message: {}", e);
eprintln!("[WEBSOCKET] Raw message was: {}", text_str);
error!("Failed to parse WebSocket message: {}", e);
}
}
}
Message::Close(_) => break,
_ => {}
Expand All @@ -444,93 +636,7 @@ async fn handle_socket(socket: WebSocket, state: AppState) {
}
}

async fn handle_text_message(
text: &str,
sender: &Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
state: &AppState,
) {
match serde_json::from_str::<WebSocketMessage>(text) {
Ok(WebSocketMessage::Message {
content,
session_id,
..
}) => {
handle_user_message(content, session_id, sender.clone(), state).await;
}
Ok(WebSocketMessage::Cancel { session_id }) => {
handle_cancel_message(session_id, sender, state).await;
}
Ok(_) => {}
Err(e) => {
error!("Failed to parse WebSocket message: {}", e);
}
}
}

async fn handle_user_message(
content: String,
session_id: String,
sender: Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
state: &AppState,
) {
let agent = state.agent.clone();
let session_id_clone = session_id.clone();

let task_handle = tokio::spawn(async move {
let result = process_message_streaming(&agent, session_id_clone, content, sender).await;

if let Err(e) = result {
error!("Error processing message: {}", e);
}
});

{
let mut cancellations = state.cancellations.write().await;
cancellations.insert(session_id.clone(), task_handle.abort_handle());
}

let cancellations_for_cleanup = state.cancellations.clone();
let session_id_for_cleanup = session_id;

tokio::spawn(async move {
if let Err(e) = task_handle.await {
if e.is_cancelled() {
tracing::debug!("Task was cancelled");
} else {
error!("Task error: {}", e);
}
}

let mut cancellations = cancellations_for_cleanup.write().await;
cancellations.remove(&session_id_for_cleanup);
});
}

async fn handle_cancel_message(
session_id: String,
sender: &Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
state: &AppState,
) {
let abort_handle = {
let mut cancellations = state.cancellations.write().await;
cancellations.remove(&session_id)
};

if let Some(handle) = abort_handle {
handle.abort();

let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Cancelled {
message: "Operation cancelled".to_string(),
})
.unwrap()
.into(),
))
.await;
}
}

async fn process_message_streaming(
agent: &Agent,
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/scenario_tests/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl McpClientTrait for MockClient {
name: &str,
arguments: Option<serde_json::Map<String, Value>>,
_cancel_token: CancellationToken,
_allowed_headers: Option<Vec<String>>,
) -> Result<CallToolResult, Error> {
if let Some(handler) = self.handlers.get(name) {
match handler(&Value::Object(arguments.unwrap_or_default())) {
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ impl CliSession {
envs: Envs::new(HashMap::new()),
env_keys: Vec::new(),
headers: HashMap::new(),
allowed_headers: Vec::new(),
description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(),
timeout: Some(timeout),
bundled: None,
Expand Down
19 changes: 11 additions & 8 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,17 @@ impl Agent {
)))
} else {
// Clone the result to ensure no references to extension_manager are returned
let result = self
.extension_manager
.dispatch_tool_call(
&session.id,
tool_call.clone(),
cancellation_token.unwrap_or_default(),
)
.await;
// Wrap in session context so that session_id is available for header filtering
let result = crate::session_context::with_session_id(Some(session.id.clone()), async {
self
.extension_manager
.dispatch_tool_call(
&session.id,
tool_call.clone(),
cancellation_token.unwrap_or_default(),
)
.await
}).await;
result.unwrap_or_else(|e| {
crate::posthog::emit_error(
"tool_execution_failed",
Expand Down
Loading