Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src-tauri/src/core/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pub mod commands;
pub mod proxy;
#[cfg(test)]
pub mod tests;
43 changes: 17 additions & 26 deletions src-tauri/src/core/server/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ use crate::core::state::ServerHandle;

/// Configuration for the proxy server
#[derive(Clone)]
struct ProxyConfig {
prefix: String,
proxy_api_key: String,
trusted_hosts: Vec<Vec<String>>,
pub struct ProxyConfig {
pub prefix: String,
pub proxy_api_key: String,
pub trusted_hosts: Vec<Vec<String>>,
}

/// Determines the final destination path based on the original request path
fn get_destination_path(original_path: &str, prefix: &str) -> String {
pub fn get_destination_path(original_path: &str, prefix: &str) -> String {
remove_prefix(original_path, prefix)
}

Expand Down Expand Up @@ -95,9 +95,7 @@ async fn proxy_request(
};

if !is_trusted {
log::warn!(
"CORS preflight: Host '{host}' not trusted for path '{request_path}'"
);
log::warn!("CORS preflight: Host '{host}' not trusted for path '{request_path}'");
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Host not allowed"))
Expand Down Expand Up @@ -153,9 +151,7 @@ async fn proxy_request(
};

if !headers_valid {
log::warn!(
"CORS preflight: Some requested headers not allowed: {requested_headers}"
);
log::warn!("CORS preflight: Some requested headers not allowed: {requested_headers}");
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Headers not allowed"))
Expand All @@ -180,9 +176,7 @@ async fn proxy_request(
response = response.header("Access-Control-Allow-Origin", "*");
}

log::debug!(
"CORS preflight response: host_trusted={is_trusted}, origin='{origin}'"
);
log::debug!("CORS preflight response: host_trusted={is_trusted}, origin='{origin}'");
return Ok(response.body(Body::empty()).unwrap());
}

Expand Down Expand Up @@ -277,9 +271,7 @@ async fn proxy_request(
.unwrap());
}
} else if is_whitelisted_path {
log::debug!(
"Bypassing authorization check for whitelisted path: {path}"
);
log::debug!("Bypassing authorization check for whitelisted path: {path}");
}

if path.contains("/configs") {
Expand All @@ -302,8 +294,10 @@ async fn proxy_request(
match (method.clone(), destination_path.as_str()) {
(hyper::Method::POST, "/chat/completions")
| (hyper::Method::POST, "/completions")
| (hyper::Method::POST, "/embeddings") => {
log::debug!(
| (hyper::Method::POST, "/embeddings")
| (hyper::Method::POST, "/messages")
| (hyper::Method::POST, "/messages/count_tokens") => {
log::info!(
"Handling POST request to {destination_path} requiring model lookup in body",
);
let body_bytes = match hyper::body::to_bytes(body).await {
Expand Down Expand Up @@ -331,9 +325,7 @@ async fn proxy_request(
let sessions_guard = sessions.lock().await;

if sessions_guard.is_empty() {
log::warn!(
"Request for model '{model_id}' but no models are running."
);
log::warn!("Request for model '{model_id}' but no models are running.");
let mut error_response =
Response::builder().status(StatusCode::SERVICE_UNAVAILABLE);
error_response = add_cors_headers_with_host_and_origin(
Expand Down Expand Up @@ -388,9 +380,7 @@ async fn proxy_request(
}
}
Err(e) => {
log::warn!(
"Failed to parse POST body for {destination_path} as JSON: {e}"
);
log::warn!("Failed to parse POST body for {destination_path} as JSON: {e}");
let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST);
error_response = add_cors_headers_with_host_and_origin(
error_response,
Expand Down Expand Up @@ -564,8 +554,9 @@ async fn proxy_request(
.unwrap());
}
};
log::info!("Proxying request to model server at port {port}, path: {destination_path}");

let upstream_url = format!("http://127.0.0.1:{port}{destination_path}");
let upstream_url = format!("http://127.0.0.1:{port}/v1{destination_path}");

let mut outbound_req = client.request(method.clone(), &upstream_url);

Expand Down
113 changes: 113 additions & 0 deletions src-tauri/src/core/server/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#[cfg(test)]
mod tests {
use crate::core::server::proxy;

#[test]
fn test_get_destination_path_basic() {
let result = proxy::get_destination_path("/v1/messages", "/v1");
assert_eq!(result, "/messages");
}

#[test]
fn test_get_destination_path_with_subpath() {
let result = proxy::get_destination_path("/v1/messages/threads/123", "/v1");
assert_eq!(result, "/messages/threads/123");
}

#[test]
fn test_get_destination_path_no_prefix() {
let result = proxy::get_destination_path("/messages", "");
assert_eq!(result, "/messages");
}

#[test]
fn test_get_destination_path_different_prefix() {
let result = proxy::get_destination_path("/api/v1/messages", "/api/v1");
assert_eq!(result, "/messages");
}

#[test]
fn test_get_destination_path_empty_prefix() {
let result = proxy::get_destination_path("/messages", "/v1");
assert_eq!(result, "/messages");
}

#[test]
fn test_messages_in_cors_whitelist() {
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico", "/messages"];
assert!(whitelisted_paths.contains(&"/messages"));
}

#[test]
fn test_messages_in_main_whitelist() {
let whitelisted_paths = [
"/",
"/openapi.json",
"/favicon.ico",
"/docs/swagger-ui.css",
"/docs/swagger-ui-bundle.js",
"/docs/swagger-ui-standalone-preset.js",
"/messages",
];
assert!(whitelisted_paths.contains(&"/messages"));
}

#[test]
fn test_messages_subpath_not_in_exact_whitelist() {
let whitelisted_paths = [
"/",
"/openapi.json",
"/favicon.ico",
"/messages",
];
// Only exact match
assert!(!whitelisted_paths.contains(&"/messages/threads"));
assert!(!whitelisted_paths.contains(&"/messages/api"));
}

#[test]
fn test_proxy_config_creation() {
let config = proxy::ProxyConfig {
prefix: "/v1".to_string(),
proxy_api_key: "test-key".to_string(),
trusted_hosts: vec![vec!["localhost".to_string()]],
};
assert_eq!(config.prefix, "/v1");
assert_eq!(config.proxy_api_key, "test-key");
assert_eq!(config.trusted_hosts.len(), 1);
}

#[test]
fn test_proxy_config_default() {
let config = proxy::ProxyConfig {
prefix: "".to_string(),
proxy_api_key: "".to_string(),
trusted_hosts: vec![],
};
assert_eq!(config.prefix, "");
assert_eq!(config.proxy_api_key, "");
assert_eq!(config.trusted_hosts.len(), 0);
}

#[test]
fn test_allowed_methods() {
let allowed_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"];
assert!(allowed_methods.contains(&"POST"));
assert!(allowed_methods.contains(&"GET"));
assert!(allowed_methods.contains(&"OPTIONS"));
}

#[test]
fn test_allowed_headers() {
let allowed_headers = [
"accept",
"authorization",
"content-type",
"host",
"origin",
"user-agent",
];
assert!(allowed_headers.contains(&"authorization"));
assert!(allowed_headers.contains(&"content-type"));
}
}
Loading