diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 62d1e55e985e..46e7a1b05b8a 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -831,6 +831,13 @@ enum Command { /// Authentication token for both Basic Auth (password) and Bearer token #[arg(long, help = "Authentication token to secure the web interface")] auth_token: Option, + + /// Allow running without authentication when exposed on the network (unsafe) + #[arg( + long, + help = "Skip auth requirement when exposed on the network (unsafe)" + )] + no_auth: bool, }, /// Terminal-integrated session (one session per terminal) @@ -1507,7 +1514,8 @@ pub async fn cli() -> anyhow::Result<()> { host, open, auth_token, - }) => crate::commands::web::handle_web(port, host, open, auth_token).await, + no_auth, + }) => crate::commands::web::handle_web(port, host, open, auth_token, no_auth).await, Some(Command::Term { command }) => handle_term_subcommand(command).await, None => handle_default_session().await, } diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index c6aa47f6b82e..d55dd28c4624 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -19,7 +19,7 @@ use goose::session::session_manager::SessionType; use goose::session::SessionManager; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{net::SocketAddr, sync::Arc}; +use std::{net::ToSocketAddrs, sync::Arc}; use tokio::sync::{Mutex, RwLock}; use tower_http::cors::{AllowOrigin, Any, CorsLayer}; use tracing::error; @@ -125,14 +125,28 @@ async fn auth_middleware( Ok(response) } -pub async fn handle_web( - port: u16, - host: String, - open: bool, - auth_token: Option, -) -> Result<()> { - crate::logging::setup_logging(Some("goose-web"), None)?; +fn is_loopback_address(host: &str) -> bool { + (host, 0) + .to_socket_addrs() + .map(|mut addrs| addrs.any(|addr| addr.ip().is_loopback())) + .unwrap_or(false) +} +fn validate_network_auth(host: &str, auth_token: &Option, no_auth: bool) { + if !is_loopback_address(host) && auth_token.is_none() && !no_auth { + eprintln!( + "Error: --auth-token is required when the server is exposed on the network ({}).", + host + ); + eprintln!( + "For security, use --auth-token or bind to a local address (e.g., localhost)." + ); + eprintln!("To skip this check, use --no-auth (unsafe)."); + std::process::exit(1); + } +} + +fn get_provider_and_model() -> (String, String) { let config = goose::config::Config::global(); let provider_name: String = match config.get_goose_provider() { @@ -151,7 +165,11 @@ pub async fn handle_web( } }; - let model_config = goose::model::ModelConfig::new(&model)?; + (provider_name, model) +} + +async fn create_agent(provider_name: &str, model: &str) -> Result { + let model_config = goose::model::ModelConfig::new(model)?; let init_session = SessionManager::create_session( std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), @@ -161,7 +179,7 @@ pub async fn handle_web( .await?; let agent = Agent::new(); - let provider = goose::providers::create(&provider_name, model_config).await?; + let provider = goose::providers::create(provider_name, model_config).await?; agent.update_provider(provider, &init_session.id).await?; let enabled_configs = goose::config::get_enabled_extensions(); @@ -171,20 +189,11 @@ pub async fn handle_web( } } - let ws_token = if auth_token.is_none() { - uuid::Uuid::new_v4().to_string() - } else { - String::new() - }; - - let state = AppState { - agent: Arc::new(agent), - cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())), - auth_token: auth_token.clone(), - ws_token, - }; + Ok(agent) +} - let cors_layer = if auth_token.is_none() { +fn build_cors_layer(auth_token: &Option, host: &str, port: u16) -> CorsLayer { + if auth_token.is_none() { let allowed_origins = [ "http://localhost:3000".parse().unwrap(), "http://127.0.0.1:3000".parse().unwrap(), @@ -199,9 +208,11 @@ pub async fn handle_web( .allow_origin(Any) .allow_methods(Any) .allow_headers(Any) - }; + } +} - let app = Router::new() +fn build_router(state: AppState, cors_layer: CorsLayer) -> Router { + Router::new() .route("/", get(serve_index)) .route("/session/{session_name}", get(serve_session)) .route("/ws", get(websocket_handler)) @@ -214,9 +225,42 @@ pub async fn handle_web( auth_middleware, )) .layer(cors_layer) - .with_state(state); + .with_state(state) +} - let addr: SocketAddr = format!("{}:{}", host, port).parse()?; +pub async fn handle_web( + port: u16, + host: String, + open: bool, + auth_token: Option, + no_auth: bool, +) -> Result<()> { + validate_network_auth(&host, &auth_token, no_auth); + crate::logging::setup_logging(Some("goose-web"), None)?; + + let (provider_name, model) = get_provider_and_model(); + let agent = create_agent(&provider_name, &model).await?; + + let ws_token = if auth_token.is_none() { + uuid::Uuid::new_v4().to_string() + } else { + String::new() + }; + + let state = AppState { + agent: Arc::new(agent), + cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())), + auth_token: auth_token.clone(), + ws_token, + }; + + let cors_layer = build_cors_layer(&auth_token, &host, port); + let app = build_router(state, cors_layer); + + let addr = (host.as_str(), port) + .to_socket_addrs()? + .next() + .ok_or_else(|| anyhow::anyhow!("Could not resolve address: {}", host))?; println!("\n🪿 Starting goose web server"); println!(" Provider: {} | Model: {}", provider_name, model); @@ -373,98 +417,7 @@ async fn handle_socket(socket: WebSocket, state: AppState) { if let Ok(msg) = msg { match msg { Message::Text(text) => { - match serde_json::from_str::(&text.to_string()) { - Ok(WebSocketMessage::Message { - content, - session_id, - .. - }) => { - let sender_clone = sender.clone(); - let agent = state.agent.clone(); - let session_id_clone = session_id.clone(); - - let task_handle = tokio::spawn(async move { - let result = process_message_streaming( - &agent, - session_id_clone, - content, - sender_clone, - ) - .await; - - if let Err(e) = result { - error!("Error processing message: {}", e); - } - }); - - { - let mut cancellations = state.cancellations.write().await; - cancellations - .insert(session_id.clone(), task_handle.abort_handle()); - } - - // Handle task completion and cleanup - let sender_for_abort = sender.clone(); - let session_id_for_cleanup = session_id.clone(); - let cancellations_for_cleanup = state.cancellations.clone(); - - tokio::spawn(async move { - match task_handle.await { - Ok(_) => {} - Err(e) if e.is_cancelled() => { - let mut sender = sender_for_abort.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string( - &WebSocketMessage::Cancelled { - message: "Operation cancelled by user" - .to_string(), - }, - ) - .unwrap() - .into(), - )) - .await; - } - Err(e) => { - error!("Task error: {}", e); - } - } - - let mut cancellations = cancellations_for_cleanup.write().await; - cancellations.remove(&session_id_for_cleanup); - }); - } - Ok(WebSocketMessage::Cancel { session_id }) => { - // Cancel the active operation for this session - let abort_handle = { - let mut cancellations = state.cancellations.write().await; - cancellations.remove(&session_id) - }; - - if let Some(handle) = abort_handle { - handle.abort(); - - // Send cancellation confirmation - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string(&WebSocketMessage::Cancelled { - message: "Operation cancelled".to_string(), - }) - .unwrap() - .into(), - )) - .await; - } - } - Ok(_) => { - // Ignore other message types - } - Err(e) => { - error!("Failed to parse WebSocket message: {}", e); - } - } + handle_text_message(&text.to_string(), &sender, &state).await; } Message::Close(_) => break, _ => {} @@ -475,15 +428,101 @@ async fn handle_socket(socket: WebSocket, state: AppState) { } } +async fn handle_text_message( + text: &str, + sender: &Arc>>, + state: &AppState, +) { + match serde_json::from_str::(text) { + Ok(WebSocketMessage::Message { + content, + session_id, + .. + }) => { + handle_user_message(content, session_id, sender.clone(), state).await; + } + Ok(WebSocketMessage::Cancel { session_id }) => { + handle_cancel_message(session_id, sender, state).await; + } + Ok(_) => {} + Err(e) => { + error!("Failed to parse WebSocket message: {}", e); + } + } +} + +async fn handle_user_message( + content: String, + session_id: String, + sender: Arc>>, + state: &AppState, +) { + let agent = state.agent.clone(); + let session_id_clone = session_id.clone(); + + let task_handle = tokio::spawn(async move { + let result = process_message_streaming(&agent, session_id_clone, content, sender).await; + + if let Err(e) = result { + error!("Error processing message: {}", e); + } + }); + + { + let mut cancellations = state.cancellations.write().await; + cancellations.insert(session_id.clone(), task_handle.abort_handle()); + } + + let cancellations_for_cleanup = state.cancellations.clone(); + let session_id_for_cleanup = session_id; + + tokio::spawn(async move { + if let Err(e) = task_handle.await { + if e.is_cancelled() { + tracing::debug!("Task was cancelled"); + } else { + error!("Task error: {}", e); + } + } + + let mut cancellations = cancellations_for_cleanup.write().await; + cancellations.remove(&session_id_for_cleanup); + }); +} + +async fn handle_cancel_message( + session_id: String, + sender: &Arc>>, + state: &AppState, +) { + let abort_handle = { + let mut cancellations = state.cancellations.write().await; + cancellations.remove(&session_id) + }; + + if let Some(handle) = abort_handle { + handle.abort(); + + let mut sender = sender.lock().await; + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::Cancelled { + message: "Operation cancelled".to_string(), + }) + .unwrap() + .into(), + )) + .await; + } +} + async fn process_message_streaming( agent: &Agent, session_id: String, content: String, sender: Arc>>, ) -> Result<()> { - use futures::StreamExt; use goose::agents::SessionConfig; - use goose::conversation::message::MessageContent; let user_message = GooseMessage::user().with_text(content.clone()); @@ -521,93 +560,12 @@ async fn process_message_streaming( while let Some(result) = stream.next().await { match result { Ok(AgentEvent::Message(message)) => { - for content in &message.content { - match content { - MessageContent::Text(text) => { - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string(&WebSocketMessage::Response { - content: text.text.clone(), - role: "assistant".to_string(), - timestamp: chrono::Utc::now().timestamp_millis(), - }) - .unwrap() - .into(), - )) - .await; - } - MessageContent::ToolRequest(req) => { - let mut sender = sender.lock().await; - if let Ok(tool_call) = &req.tool_call { - let _ = sender - .send(Message::Text( - serde_json::to_string( - &WebSocketMessage::ToolRequest { - id: req.id.clone(), - tool_name: tool_call.name.to_string(), - arguments: Value::from( - tool_call.arguments.clone(), - ), - }, - ) - .unwrap() - .into(), - )) - .await; - } - } - MessageContent::ToolResponse(_resp) => {} - MessageContent::ToolConfirmationRequest(confirmation) => { - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string( - &WebSocketMessage::ToolConfirmation { - id: confirmation.id.clone(), - tool_name: confirmation - .tool_name - .to_string() - .clone(), - arguments: Value::from( - confirmation.arguments.clone(), - ), - needs_confirmation: true, - }, - ) - .unwrap() - .into(), - )) - .await; - - agent.handle_confirmation( - confirmation.id.clone(), - goose::permission::PermissionConfirmation { - principal_type: goose::permission::permission_confirmation::PrincipalType::Tool, - permission: goose::permission::Permission::AllowOnce, - } - ).await; - } - MessageContent::Thinking(thinking) => { - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string(&WebSocketMessage::Thinking { - message: thinking.thinking.clone(), - }) - .unwrap() - .into(), - )) - .await; - } - _ => {} - } - } + process_agent_message(&message, &sender, agent).await; } - Ok(AgentEvent::HistoryReplaced(_new_messages)) => { + Ok(AgentEvent::HistoryReplaced(_)) => { tracing::info!("History replaced, compacting happened in reply"); } - Ok(AgentEvent::McpNotification(_notification)) => { + Ok(AgentEvent::McpNotification(_)) => { tracing::info!("Received MCP notification in web interface"); } Ok(AgentEvent::ModelChange { model, mode }) => { @@ -615,16 +573,7 @@ async fn process_message_streaming( } Err(e) => { error!("Error in message stream: {}", e); - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string(&WebSocketMessage::Error { - message: format!("Error: {}", e), - }) - .unwrap() - .into(), - )) - .await; + send_error(&sender, &format!("Error: {}", e)).await; break; } } @@ -632,16 +581,7 @@ async fn process_message_streaming( } Err(e) => { error!("Error calling agent: {}", e); - let mut sender = sender.lock().await; - let _ = sender - .send(Message::Text( - serde_json::to_string(&WebSocketMessage::Error { - message: format!("Error: {}", e), - }) - .unwrap() - .into(), - )) - .await; + send_error(&sender, &format!("Error: {}", e)).await; } } @@ -658,3 +598,104 @@ async fn process_message_streaming( Ok(()) } + +async fn process_agent_message( + message: &GooseMessage, + sender: &Arc>>, + agent: &Agent, +) { + use goose::conversation::message::MessageContent; + + for content in &message.content { + match content { + MessageContent::Text(text) => { + let mut sender = sender.lock().await; + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::Response { + content: text.text.clone(), + role: "assistant".to_string(), + timestamp: chrono::Utc::now().timestamp_millis(), + }) + .unwrap() + .into(), + )) + .await; + } + MessageContent::ToolRequest(req) => { + let mut sender = sender.lock().await; + if let Ok(tool_call) = &req.tool_call { + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::ToolRequest { + id: req.id.clone(), + tool_name: tool_call.name.to_string(), + arguments: Value::from(tool_call.arguments.clone()), + }) + .unwrap() + .into(), + )) + .await; + } + } + MessageContent::ToolResponse(_) => {} + MessageContent::ToolConfirmationRequest(confirmation) => { + { + let mut sender = sender.lock().await; + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::ToolConfirmation { + id: confirmation.id.clone(), + tool_name: confirmation.tool_name.to_string(), + arguments: Value::from(confirmation.arguments.clone()), + needs_confirmation: true, + }) + .unwrap() + .into(), + )) + .await; + } + + agent + .handle_confirmation( + confirmation.id.clone(), + goose::permission::PermissionConfirmation { + principal_type: + goose::permission::permission_confirmation::PrincipalType::Tool, + permission: goose::permission::Permission::AllowOnce, + }, + ) + .await; + } + MessageContent::Thinking(thinking) => { + let mut sender = sender.lock().await; + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::Thinking { + message: thinking.thinking.clone(), + }) + .unwrap() + .into(), + )) + .await; + } + _ => {} + } + } +} + +async fn send_error( + sender: &Arc>>, + message: &str, +) { + let mut sender = sender.lock().await; + let _ = sender + .send(Message::Text( + serde_json::to_string(&WebSocketMessage::Error { + message: message.to_string(), + }) + .unwrap() + .into(), + )) + .await; +}