diff --git a/src-tauri/src/core/server/mod.rs b/src-tauri/src/core/server/mod.rs index 368235778b..a2a35c0974 100644 --- a/src-tauri/src/core/server/mod.rs +++ b/src-tauri/src/core/server/mod.rs @@ -1,2 +1,4 @@ pub mod commands; pub mod proxy; +#[cfg(test)] +pub mod tests; diff --git a/src-tauri/src/core/server/proxy.rs b/src-tauri/src/core/server/proxy.rs index 45c7535b04..9ada8349b4 100644 --- a/src-tauri/src/core/server/proxy.rs +++ b/src-tauri/src/core/server/proxy.rs @@ -16,14 +16,14 @@ use crate::core::state::ServerHandle; /// Configuration for the proxy server #[derive(Clone)] -struct ProxyConfig { - prefix: String, - proxy_api_key: String, - trusted_hosts: Vec>, +pub struct ProxyConfig { + pub prefix: String, + pub proxy_api_key: String, + pub trusted_hosts: Vec>, } /// Determines the final destination path based on the original request path -fn get_destination_path(original_path: &str, prefix: &str) -> String { +pub fn get_destination_path(original_path: &str, prefix: &str) -> String { remove_prefix(original_path, prefix) } @@ -95,9 +95,7 @@ async fn proxy_request( }; if !is_trusted { - log::warn!( - "CORS preflight: Host '{host}' not trusted for path '{request_path}'" - ); + log::warn!("CORS preflight: Host '{host}' not trusted for path '{request_path}'"); return Ok(Response::builder() .status(StatusCode::FORBIDDEN) .body(Body::from("Host not allowed")) @@ -153,9 +151,7 @@ async fn proxy_request( }; if !headers_valid { - log::warn!( - "CORS preflight: Some requested headers not allowed: {requested_headers}" - ); + log::warn!("CORS preflight: Some requested headers not allowed: {requested_headers}"); return Ok(Response::builder() .status(StatusCode::FORBIDDEN) .body(Body::from("Headers not allowed")) @@ -180,9 +176,7 @@ async fn proxy_request( response = response.header("Access-Control-Allow-Origin", "*"); } - log::debug!( - "CORS preflight response: host_trusted={is_trusted}, origin='{origin}'" - ); + log::debug!("CORS preflight response: host_trusted={is_trusted}, origin='{origin}'"); return Ok(response.body(Body::empty()).unwrap()); } @@ -277,9 +271,7 @@ async fn proxy_request( .unwrap()); } } else if is_whitelisted_path { - log::debug!( - "Bypassing authorization check for whitelisted path: {path}" - ); + log::debug!("Bypassing authorization check for whitelisted path: {path}"); } if path.contains("/configs") { @@ -302,8 +294,10 @@ async fn proxy_request( match (method.clone(), destination_path.as_str()) { (hyper::Method::POST, "/chat/completions") | (hyper::Method::POST, "/completions") - | (hyper::Method::POST, "/embeddings") => { - log::debug!( + | (hyper::Method::POST, "/embeddings") + | (hyper::Method::POST, "/messages") + | (hyper::Method::POST, "/messages/count_tokens") => { + log::info!( "Handling POST request to {destination_path} requiring model lookup in body", ); let body_bytes = match hyper::body::to_bytes(body).await { @@ -331,9 +325,7 @@ async fn proxy_request( let sessions_guard = sessions.lock().await; if sessions_guard.is_empty() { - log::warn!( - "Request for model '{model_id}' but no models are running." - ); + log::warn!("Request for model '{model_id}' but no models are running."); let mut error_response = Response::builder().status(StatusCode::SERVICE_UNAVAILABLE); error_response = add_cors_headers_with_host_and_origin( @@ -388,9 +380,7 @@ async fn proxy_request( } } Err(e) => { - log::warn!( - "Failed to parse POST body for {destination_path} as JSON: {e}" - ); + log::warn!("Failed to parse POST body for {destination_path} as JSON: {e}"); let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST); error_response = add_cors_headers_with_host_and_origin( error_response, @@ -564,8 +554,9 @@ async fn proxy_request( .unwrap()); } }; + log::info!("Proxying request to model server at port {port}, path: {destination_path}"); - let upstream_url = format!("http://127.0.0.1:{port}{destination_path}"); + let upstream_url = format!("http://127.0.0.1:{port}/v1{destination_path}"); let mut outbound_req = client.request(method.clone(), &upstream_url); diff --git a/src-tauri/src/core/server/tests.rs b/src-tauri/src/core/server/tests.rs new file mode 100644 index 0000000000..25d5c1cf31 --- /dev/null +++ b/src-tauri/src/core/server/tests.rs @@ -0,0 +1,113 @@ +#[cfg(test)] +mod tests { + use crate::core::server::proxy; + + #[test] + fn test_get_destination_path_basic() { + let result = proxy::get_destination_path("/v1/messages", "/v1"); + assert_eq!(result, "/messages"); + } + + #[test] + fn test_get_destination_path_with_subpath() { + let result = proxy::get_destination_path("/v1/messages/threads/123", "/v1"); + assert_eq!(result, "/messages/threads/123"); + } + + #[test] + fn test_get_destination_path_no_prefix() { + let result = proxy::get_destination_path("/messages", ""); + assert_eq!(result, "/messages"); + } + + #[test] + fn test_get_destination_path_different_prefix() { + let result = proxy::get_destination_path("/api/v1/messages", "/api/v1"); + assert_eq!(result, "/messages"); + } + + #[test] + fn test_get_destination_path_empty_prefix() { + let result = proxy::get_destination_path("/messages", "/v1"); + assert_eq!(result, "/messages"); + } + + #[test] + fn test_messages_in_cors_whitelist() { + let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico", "/messages"]; + assert!(whitelisted_paths.contains(&"/messages")); + } + + #[test] + fn test_messages_in_main_whitelist() { + let whitelisted_paths = [ + "/", + "/openapi.json", + "/favicon.ico", + "/docs/swagger-ui.css", + "/docs/swagger-ui-bundle.js", + "/docs/swagger-ui-standalone-preset.js", + "/messages", + ]; + assert!(whitelisted_paths.contains(&"/messages")); + } + + #[test] + fn test_messages_subpath_not_in_exact_whitelist() { + let whitelisted_paths = [ + "/", + "/openapi.json", + "/favicon.ico", + "/messages", + ]; + // Only exact match + assert!(!whitelisted_paths.contains(&"/messages/threads")); + assert!(!whitelisted_paths.contains(&"/messages/api")); + } + + #[test] + fn test_proxy_config_creation() { + let config = proxy::ProxyConfig { + prefix: "/v1".to_string(), + proxy_api_key: "test-key".to_string(), + trusted_hosts: vec![vec!["localhost".to_string()]], + }; + assert_eq!(config.prefix, "/v1"); + assert_eq!(config.proxy_api_key, "test-key"); + assert_eq!(config.trusted_hosts.len(), 1); + } + + #[test] + fn test_proxy_config_default() { + let config = proxy::ProxyConfig { + prefix: "".to_string(), + proxy_api_key: "".to_string(), + trusted_hosts: vec![], + }; + assert_eq!(config.prefix, ""); + assert_eq!(config.proxy_api_key, ""); + assert_eq!(config.trusted_hosts.len(), 0); + } + + #[test] + fn test_allowed_methods() { + let allowed_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]; + assert!(allowed_methods.contains(&"POST")); + assert!(allowed_methods.contains(&"GET")); + assert!(allowed_methods.contains(&"OPTIONS")); + } + + #[test] + fn test_allowed_headers() { + let allowed_headers = [ + "accept", + "authorization", + "content-type", + "host", + "origin", + "user-agent", + ]; + assert!(allowed_headers.contains(&"authorization")); + assert!(allowed_headers.contains(&"content-type")); + } +}