-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Cors and token #5850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cors and token #5850
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ use axum::response::Redirect; | |
| use axum::{ | ||
| extract::{ | ||
| ws::{Message, WebSocket, WebSocketUpgrade}, | ||
| Request, State, | ||
| Query, Request, State, | ||
| }, | ||
| http::StatusCode, | ||
| middleware::{self, Next}, | ||
|
|
@@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize}; | |
| use serde_json::Value; | ||
| use std::{net::SocketAddr, sync::Arc}; | ||
| use tokio::sync::{Mutex, RwLock}; | ||
| use tower_http::cors::{Any, CorsLayer}; | ||
| use tower_http::cors::{AllowOrigin, Any, CorsLayer}; | ||
| use tracing::error; | ||
| use webbrowser; | ||
|
|
||
|
|
@@ -32,6 +32,7 @@ struct AppState { | |
| agent: Arc<Agent>, | ||
| cancellations: CancellationStore, | ||
| auth_token: Option<String>, | ||
| ws_token: String, | ||
| } | ||
|
|
||
| #[derive(Serialize, Deserialize)] | ||
|
|
@@ -87,17 +88,14 @@ async fn auth_middleware( | |
| req: Request, | ||
| next: Next, | ||
| ) -> Result<Response, StatusCode> { | ||
| // Skip auth for health check | ||
| if req.uri().path() == "/api/health" { | ||
| return Ok(next.run(req).await); | ||
| } | ||
|
|
||
| // If no auth token is configured, skip authentication entirely | ||
| let Some(ref expected_token) = state.auth_token else { | ||
| return Ok(next.run(req).await); | ||
| }; | ||
|
|
||
| // Check for Bearer token first | ||
| if let Some(auth_header) = req.headers().get("authorization") { | ||
| if let Ok(auth_str) = auth_header.to_str() { | ||
| if let Some(token) = auth_str.strip_prefix("Bearer ") { | ||
|
|
@@ -106,7 +104,6 @@ async fn auth_middleware( | |
| } | ||
| } | ||
|
|
||
| // Check for Basic auth (password-only, ignore username) | ||
| if let Some(basic_token) = auth_str.strip_prefix("Basic ") { | ||
| if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(basic_token) { | ||
| if let Ok(credentials) = String::from_utf8(decoded) { | ||
|
|
@@ -119,7 +116,6 @@ async fn auth_middleware( | |
| } | ||
| } | ||
|
|
||
| // Authentication failed - return 401 with WWW-Authenticate header | ||
| let mut response = Response::new("Authentication required".into()); | ||
| *response.status_mut() = StatusCode::UNAUTHORIZED; | ||
| response.headers_mut().insert( | ||
|
|
@@ -135,7 +131,6 @@ pub async fn handle_web( | |
| open: bool, | ||
| auth_token: Option<String>, | ||
| ) -> Result<()> { | ||
| // Setup logging | ||
| crate::logging::setup_logging(Some("goose-web"), None)?; | ||
|
|
||
| let config = goose::config::Config::global(); | ||
|
|
@@ -176,10 +171,34 @@ pub async fn handle_web( | |
| } | ||
| } | ||
|
|
||
| let ws_token = if auth_token.is_none() { | ||
| uuid::Uuid::new_v4().to_string() | ||
| } else { | ||
| String::new() | ||
| }; | ||
|
|
||
| let state = AppState { | ||
| agent: Arc::new(agent), | ||
| cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())), | ||
| auth_token, | ||
| auth_token: auth_token.clone(), | ||
| ws_token, | ||
| }; | ||
|
|
||
| let cors_layer = if auth_token.is_none() { | ||
| let allowed_origins = [ | ||
| "http://localhost:3000".parse().unwrap(), | ||
| "http://127.0.0.1:3000".parse().unwrap(), | ||
| format!("http://{}:{}", host, port).parse().unwrap(), | ||
| ]; | ||
| CorsLayer::new() | ||
| .allow_origin(AllowOrigin::list(allowed_origins)) | ||
| .allow_methods(Any) | ||
| .allow_headers(Any) | ||
| } else { | ||
| CorsLayer::new() | ||
| .allow_origin(Any) | ||
| .allow_methods(Any) | ||
| .allow_headers(Any) | ||
| }; | ||
|
|
||
| let app = Router::new() | ||
|
|
@@ -194,12 +213,7 @@ pub async fn handle_web( | |
| state.clone(), | ||
| auth_middleware, | ||
| )) | ||
| .layer( | ||
| CorsLayer::new() | ||
| .allow_origin(Any) | ||
| .allow_methods(Any) | ||
| .allow_headers(Any), | ||
| ) | ||
| .layer(cors_layer) | ||
| .with_state(state); | ||
|
|
||
| let addr: SocketAddr = format!("{}:{}", host, port).parse()?; | ||
|
|
@@ -214,7 +228,6 @@ pub async fn handle_web( | |
| println!(" Press Ctrl+C to stop\n"); | ||
|
|
||
| if open { | ||
| // Open browser | ||
| let url = format!("http://{}", addr); | ||
| if let Err(e) = webbrowser::open(&url) { | ||
| eprintln!("Failed to open browser: {}", e); | ||
|
|
@@ -241,14 +254,15 @@ async fn serve_index() -> Result<Redirect, (http::StatusCode, String)> { | |
|
|
||
| async fn serve_session( | ||
| axum::extract::Path(session_name): axum::extract::Path<String>, | ||
| State(state): State<AppState>, | ||
| ) -> Html<String> { | ||
| let html = include_str!("../../static/index.html"); | ||
| // Inject the session name into the HTML so JavaScript can use it | ||
| let html_with_session = html.replace( | ||
| "<script src=\"/static/script.js\"></script>", | ||
| &format!( | ||
| "<script>window.GOOSE_SESSION_NAME = '{}';</script>\n <script src=\"/static/script.js\"></script>", | ||
| session_name | ||
| "<script>window.GOOSE_SESSION_NAME = '{}'; window.GOOSE_WS_TOKEN = '{}';</script>\n <script src=\"/static/script.js\"></script>", | ||
| session_name, | ||
| state.ws_token | ||
| ) | ||
| ); | ||
| Html(html_with_session) | ||
|
|
@@ -324,11 +338,25 @@ async fn get_session( | |
| } | ||
| } | ||
|
|
||
| #[derive(Deserialize)] | ||
| struct WsQuery { | ||
| token: Option<String>, | ||
| } | ||
|
|
||
| async fn websocket_handler( | ||
| ws: WebSocketUpgrade, | ||
| State(state): State<AppState>, | ||
| ) -> impl IntoResponse { | ||
| ws.on_upgrade(|socket| handle_socket(socket, state)) | ||
| Query(query): Query<WsQuery>, | ||
| ) -> Result<impl IntoResponse, StatusCode> { | ||
| if state.auth_token.is_none() { | ||
| let provided_token = query.token.as_deref().unwrap_or(""); | ||
| if provided_token != state.ws_token { | ||
| tracing::warn!("WebSocket connection rejected: invalid token"); | ||
| return Err(StatusCode::FORBIDDEN); | ||
| } | ||
| } | ||
|
Comment on lines
+351
to
+357
|
||
|
|
||
| Ok(ws.on_upgrade(|socket| handle_socket(socket, state))) | ||
| } | ||
|
|
||
| async fn handle_socket(socket: WebSocket, state: AppState) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is actually a helpful and IMO necessary comment, I don't like just blindly deleting oneline coments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interestingly enough, goose deleted this comment. which it probably did because I have general settings that it should delete useless comments.
I would still argue that the comment says the same thing as the code below though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it never ever wants to with mine - I guess it tends to follow "house style" over what system prompt says! @DOsinga sometimes I wonder if we need to put some logic in the editor tool for it to to return an error if it detects single line comment (error can be "this is an inane comment, please either remove and try again or consider if it is really needed) or something?