From 4ae071f0b0cde577718969028cf4f7948d1def7d Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:39:38 -0800 Subject: [PATCH 1/9] - Cache worker lookup in send_typed_request to avoid 4+ redundant get_by_url() calls per request - Replace sliding window search with SIMD-optimized memchr::memmem for stream end marker detection - Pre-filter headers before worker loop in route_simple_request to avoid repeated lowercasing per worker - Use eq_ignore_ascii_case() instead of to_lowercase() to avoid string allocations in header comparisons - Add conditional check for debug logging to avoid JSON serialization when debug is disabled - Cache get_all() result in router_manager to avoid duplicate calls --- sgl-model-gateway/Cargo.toml | 1 + sgl-model-gateway/src/routers/header_utils.rs | 57 ++++++------ sgl-model-gateway/src/routers/http/router.rs | 91 ++++++++++--------- .../src/routers/router_manager.rs | 8 +- 4 files changed, 84 insertions(+), 73 deletions(-) diff --git a/sgl-model-gateway/Cargo.toml b/sgl-model-gateway/Cargo.toml index 3f27186458c9..323c654b8f5f 100644 --- a/sgl-model-gateway/Cargo.toml +++ b/sgl-model-gateway/Cargo.toml @@ -70,6 +70,7 @@ parking_lot = "0.12.4" rayon = "1.10" thiserror = "2.0.12" regex = "1.10" +memchr = "2.7" # SIMD-optimized byte pattern searching url = "2.5.4" validator = { version = "0.20.0", features = ["derive"] } tokio-stream = { version = "0.1", features = ["sync"] } diff --git a/sgl-model-gateway/src/routers/header_utils.rs b/sgl-model-gateway/src/routers/header_utils.rs index f705a3fc97ed..b7ca861f6a49 100644 --- a/sgl-model-gateway/src/routers/header_utils.rs +++ b/sgl-model-gateway/src/routers/header_utils.rs @@ -26,8 +26,8 @@ pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap { for (name, value) in reqwest_headers.iter() { // Skip hop-by-hop headers that shouldn't be forwarded - let name_str = name.as_str().to_lowercase(); - if should_forward_header(&name_str) { + // Use eq_ignore_ascii_case to avoid string allocation + if should_forward_header_no_alloc(name.as_str()) { // The original name and value are already valid, so we can just clone them headers.insert(name.clone(), value.clone()); } @@ -36,22 +36,20 @@ pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap { headers } -/// Determine if a header should be forwarded from backend to client -fn should_forward_header(name: &str) -> bool { +/// Determine if a header should be forwarded without allocating (case-insensitive) +fn should_forward_header_no_alloc(name: &str) -> bool { // List of headers that should NOT be forwarded (hop-by-hop headers) - !matches!( - name, - "connection" | - "keep-alive" | - "proxy-authenticate" | - "proxy-authorization" | - "te" | - "trailers" | - "transfer-encoding" | - "upgrade" | - "content-encoding" | // Let axum/hyper handle encoding - "host" // Should not forward the backend's host header - ) + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation + !(name.eq_ignore_ascii_case("connection") + || name.eq_ignore_ascii_case("keep-alive") + || name.eq_ignore_ascii_case("proxy-authenticate") + || name.eq_ignore_ascii_case("proxy-authorization") + || name.eq_ignore_ascii_case("te") + || name.eq_ignore_ascii_case("trailers") + || name.eq_ignore_ascii_case("transfer-encoding") + || name.eq_ignore_ascii_case("upgrade") + || name.eq_ignore_ascii_case("content-encoding") + || name.eq_ignore_ascii_case("host")) } /// Apply headers to a reqwest request builder, filtering out headers that shouldn't be forwarded @@ -70,24 +68,27 @@ pub fn apply_request_headers( } // Forward other headers, filtering out problematic ones + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation per header for (key, value) in headers.iter() { - let key_str = key.as_str().to_lowercase(); + let key_str = key.as_str(); // Skip headers that: // - Are set automatically by reqwest (content-type, content-length for POST/PUT) // - We already handled (authorization) // - Are hop-by-hop headers (connection, transfer-encoding) // - Should not be forwarded (host) - let should_skip = key_str == "authorization" || // Already handled above - key_str == "host" || - key_str == "connection" || - key_str == "transfer-encoding" || - key_str == "keep-alive" || - key_str == "te" || - key_str == "trailers" || - key_str == "accept-encoding" || - key_str == "upgrade" || - (skip_content_headers && (key_str == "content-type" || key_str == "content-length")); + let should_skip = key_str.eq_ignore_ascii_case("authorization") // Already handled above + || key_str.eq_ignore_ascii_case("host") + || key_str.eq_ignore_ascii_case("connection") + || key_str.eq_ignore_ascii_case("transfer-encoding") + || key_str.eq_ignore_ascii_case("keep-alive") + || key_str.eq_ignore_ascii_case("te") + || key_str.eq_ignore_ascii_case("trailers") + || key_str.eq_ignore_ascii_case("accept-encoding") + || key_str.eq_ignore_ascii_case("upgrade") + || (skip_content_headers + && (key_str.eq_ignore_ascii_case("content-type") + || key_str.eq_ignore_ascii_case("content-length"))); if !should_skip { request_builder = request_builder.header(key.clone(), value.clone()); diff --git a/sgl-model-gateway/src/routers/http/router.rs b/sgl-model-gateway/src/routers/http/router.rs index f98905220128..ffcefde10adf 100644 --- a/sgl-model-gateway/src/routers/http/router.rs +++ b/sgl-model-gateway/src/routers/http/router.rs @@ -11,6 +11,7 @@ use axum::{ Json, }; use futures_util::StreamExt; +use memchr::memmem; use reqwest::Client; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error}; @@ -91,8 +92,10 @@ impl Router { Ok(worker_url) => { let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint)); for (name, value) in headers { - let name_lc = name.to_lowercase(); - if name_lc != "content-type" && name_lc != "content-length" { + // Use eq_ignore_ascii_case to avoid string allocation + if !name.eq_ignore_ascii_case("content-type") + && !name.eq_ignore_ascii_case("content-length") + { request_builder = request_builder.header(name, value); } } @@ -300,6 +303,18 @@ impl Router { return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); } + // Pre-filter headers once before the loop to avoid repeated lowercasing + let filtered_headers: Vec<_> = headers + .map(|hdrs| { + hdrs.iter() + .filter(|(name, _)| { + !name.as_str().eq_ignore_ascii_case("content-type") + && !name.as_str().eq_ignore_ascii_case("content-length") + }) + .collect() + }) + .unwrap_or_default(); + let mut last_response: Option = None; for worker in workers { let worker_url = worker.url(); @@ -323,13 +338,9 @@ impl Router { request_builder.header("Authorization", format!("Bearer {}", api_key)); } - if let Some(hdrs) = headers { - for (name, value) in hdrs { - let name_lc = name.as_str().to_lowercase(); - if name_lc != "content-type" && name_lc != "content-length" { - request_builder = request_builder.header(name, value); - } - } + // Apply pre-filtered headers + for (name, value) in &filtered_headers { + request_builder = request_builder.header(*name, *value); } match request_builder.send().await { @@ -417,11 +428,9 @@ impl Router { is_stream: bool, load_incremented: bool, // Whether load was incremented for this request ) -> Response { - // Get the worker's API key if available - let api_key = self - .worker_registry - .get_by_url(worker_url) - .and_then(|w| w.api_key().clone()); + // Get the worker once and reuse for API key and load tracking + let worker = self.worker_registry.get_by_url(worker_url); + let api_key = worker.as_ref().and_then(|w| w.api_key().clone()); let mut request_builder = if self.dp_aware { let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { @@ -452,10 +461,13 @@ impl Router { String::from("data_parallel_rank"), serde_json::json!(dp_rank), ); - debug!( - "Modified request body: {}", - serde_json::to_string(&json_val).unwrap_or(String::from("ERR")) - ); + // Only serialize if debug logging is enabled to avoid CPU overhead + if tracing::enabled!(tracing::Level::DEBUG) { + debug!( + "Modified request body: {}", + serde_json::to_string(&json_val).unwrap_or_else(|_| String::from("ERR")) + ); + } } else { return ( StatusCode::BAD_REQUEST, @@ -497,9 +509,9 @@ impl Router { // Decrement load on error if it was incremented if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } @@ -528,9 +540,9 @@ impl Router { Err(e) => { // IMPORTANT: Decrement load on error before returning if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } @@ -541,17 +553,18 @@ impl Router { // Decrement load counter for non-streaming requests if it was incremented if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } response } else if load_incremented { // For streaming with load tracking, we need to manually decrement when done - let registry = Arc::clone(&self.worker_registry); - let worker_url = worker_url.to_string(); + // Clone the worker Arc for the async block instead of looking it up again + let stream_worker = worker.clone(); + let worker_url_owned = worker_url.to_string(); // Preserve headers for streaming response let mut response_headers = header_utils::preserve_response_headers(res.headers()); @@ -568,15 +581,11 @@ impl Router { while let Some(chunk) = stream.next().await { match chunk { Ok(bytes) => { - // Check for stream end marker - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - if let Some(worker) = registry.get_by_url(&worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(&worker_url, worker.load()); + // Check for stream end marker using memmem for efficiency + if memmem::find(&bytes, b"data: [DONE]").is_some() { + if let Some(ref w) = stream_worker { + w.decrement_load(); + RouterMetrics::set_running_requests(&worker_url_owned, w.load()); decremented = true; } } @@ -591,9 +600,9 @@ impl Router { } } if !decremented { - if let Some(worker) = registry.get_by_url(&worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(&worker_url, worker.load()); + if let Some(ref w) = stream_worker { + w.decrement_load(); + RouterMetrics::set_running_requests(&worker_url_owned, w.load()); } } }); diff --git a/sgl-model-gateway/src/routers/router_manager.rs b/sgl-model-gateway/src/routers/router_manager.rs index 9ac494ff9b81..ceafc5625d30 100644 --- a/sgl-model-gateway/src/routers/router_manager.rs +++ b/sgl-model-gateway/src/routers/router_manager.rs @@ -269,13 +269,13 @@ impl RouterManager { let mut best_router = None; let mut best_score = 0.0; - let num_regular_workers = self - .worker_registry - .get_all() + // Cache worker list to avoid duplicate get_all() calls + let all_workers = self.worker_registry.get_all(); + let num_regular_workers = all_workers .iter() .filter(|w| matches!(w.worker_type(), WorkerType::Regular)) .count(); - let num_pd_workers = self.worker_registry.get_all().len() - num_regular_workers; + let num_pd_workers = all_workers.len() - num_regular_workers; for router in candidate_routers { let mut score = 1.0; From 7004b6c138b10b3143ff1350b783776d3407b03d Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:40:17 -0800 Subject: [PATCH 2/9] - Fix request ID generation: use byte array indexing (O(1)) instead of chars().nth() which is O(n) per character. Saves ~1500 operations per request. - Optimize extract_text_for_routing: use single buffer with direct string appending instead of intermediate Vec collection and join. Eliminates multiple allocations per request. - Add MessageContent.append_text_to() method for efficient buffer building without intermediate string allocations. - Add MessageContent.has_text() method for quick existence check without string creation. - Add StringOrArray.iter() method that returns string references without cloning the underlying data. - Add StringOrArray.first() method for quick access to first element. - Improve to_simple_string() to use Vec<&str> instead of Vec for parts concatenation, reducing allocations. --- sgl-model-gateway/src/middleware.rs | 9 +- sgl-model-gateway/src/protocols/chat.rs | 125 ++++++++++++++++++---- sgl-model-gateway/src/protocols/common.rs | 67 +++++++++++- 3 files changed, 174 insertions(+), 27 deletions(-) diff --git a/sgl-model-gateway/src/middleware.rs b/sgl-model-gateway/src/middleware.rs index fbad1188adc5..0f56f17be3b6 100644 --- a/sgl-model-gateway/src/middleware.rs +++ b/sgl-model-gateway/src/middleware.rs @@ -79,6 +79,9 @@ pub async fn auth_middleware( Ok(next.run(request).await) } +/// Alphanumeric characters for request ID generation (as bytes for O(1) indexing) +const REQUEST_ID_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + /// Generate OpenAI-compatible request ID based on endpoint fn generate_request_id(path: &str) -> String { let prefix = if path.contains("/chat/completions") { @@ -94,12 +97,12 @@ fn generate_request_id(path: &str) -> String { }; // Generate a random string similar to OpenAI's format - let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + // Use byte array indexing (O(1)) instead of chars().nth() (O(n)) let mut rng = rand::rng(); let random_part: String = (0..24) .map(|_| { - let idx = rng.random_range(0..chars.len()); - chars.chars().nth(idx).unwrap() + let idx = rng.random_range(0..REQUEST_ID_CHARS.len()); + REQUEST_ID_CHARS[idx] as char }) .collect(); diff --git a/sgl-model-gateway/src/protocols/chat.rs b/sgl-model-gateway/src/protocols/chat.rs index 6c3f2699a8c0..d5e745fe6690 100644 --- a/sgl-model-gateway/src/protocols/chat.rs +++ b/sgl-model-gateway/src/protocols/chat.rs @@ -66,21 +66,67 @@ pub enum MessageContent { } impl MessageContent { + /// Returns the text content, cloning only when necessary. + /// For simple text, returns a clone of the string. + /// For parts, concatenates text parts with spaces. pub fn to_simple_string(&self) -> String { match self { MessageContent::Text(text) => text.clone(), MessageContent::Parts(parts) => { - let texts: Vec = parts + // Pre-count text parts to avoid intermediate Vec allocation + let text_parts: Vec<&str> = parts .iter() .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), + ContentPart::Text { text } => Some(text.as_str()), _ => None, }) .collect(); - texts.join(" ") + text_parts.join(" ") } } } + + /// Appends text content directly to a buffer, avoiding intermediate allocations. + /// Returns true if any content was appended. + #[inline] + pub fn append_text_to(&self, buffer: &mut String) -> bool { + match self { + MessageContent::Text(text) => { + if !text.is_empty() { + buffer.push_str(text); + true + } else { + false + } + } + MessageContent::Parts(parts) => { + let mut appended = false; + for part in parts { + if let ContentPart::Text { text } = part { + if !text.is_empty() { + if appended { + buffer.push(' '); + } + buffer.push_str(text); + appended = true; + } + } + } + appended + } + } + } + + /// Returns true if this content contains any non-empty text. + #[inline] + pub fn has_text(&self) -> bool { + match self { + MessageContent::Text(text) => !text.is_empty(), + MessageContent::Parts(parts) => parts.iter().any(|part| { + matches!(part, ContentPart::Text { text } if !text.is_empty()) + }), + } + } } // ============================================================================ @@ -581,33 +627,66 @@ impl GenerationRequest for ChatCompletionRequest { fn extract_text_for_routing(&self) -> String { // Extract text from messages for routing decisions - self.messages - .iter() - .filter_map(|msg| match msg { - ChatMessage::System { content, .. } => Some(content.to_simple_string()), - ChatMessage::User { content, .. } => Some(content.to_simple_string()), + // Use a single buffer to avoid intermediate Vec allocations + let mut buffer = String::new(); + let mut has_content = false; + + for msg in &self.messages { + match msg { + ChatMessage::System { content, .. } | ChatMessage::User { content, .. } => { + if has_content && content.has_text() { + buffer.push(' '); + } + if content.append_text_to(&mut buffer) { + has_content = true; + } + } ChatMessage::Assistant { content, reasoning_content, .. } => { - // Combine content and reasoning content for routing decisions - let main_content = content - .as_ref() - .map(|c| c.to_simple_string()) - .unwrap_or_default(); - let reasoning = reasoning_content.clone().unwrap_or_default(); - if main_content.is_empty() && reasoning.is_empty() { - None - } else { - Some(format!("{} {}", main_content, reasoning).trim().to_string()) + // Append main content + if let Some(c) = content { + if has_content && c.has_text() { + buffer.push(' '); + } + if c.append_text_to(&mut buffer) { + has_content = true; + } + } + // Append reasoning content + if let Some(reasoning) = reasoning_content { + if !reasoning.is_empty() { + if has_content { + buffer.push(' '); + } + buffer.push_str(reasoning); + has_content = true; + } } } - ChatMessage::Tool { content, .. } => Some(content.to_simple_string()), - ChatMessage::Function { content, .. } => Some(content.clone()), - }) - .collect::>() - .join(" ") + ChatMessage::Tool { content, .. } => { + if has_content && content.has_text() { + buffer.push(' '); + } + if content.append_text_to(&mut buffer) { + has_content = true; + } + } + ChatMessage::Function { content, .. } => { + if !content.is_empty() { + if has_content { + buffer.push(' '); + } + buffer.push_str(content); + has_content = true; + } + } + } + } + + buffer } } diff --git a/sgl-model-gateway/src/protocols/common.rs b/sgl-model-gateway/src/protocols/common.rs index ac51e85ea6a3..4e3ae494bd5e 100644 --- a/sgl-model-gateway/src/protocols/common.rs +++ b/sgl-model-gateway/src/protocols/common.rs @@ -65,15 +65,80 @@ impl StringOrArray { } } - /// Convert to a vector of strings + /// Convert to a vector of strings (clones the data) pub fn to_vec(&self) -> Vec { match self { StringOrArray::String(s) => vec![s.clone()], StringOrArray::Array(arr) => arr.clone(), } } + + /// Returns an iterator over string references without cloning. + /// Use this instead of `to_vec()` when you only need to iterate. + pub fn iter(&self) -> StringOrArrayIter<'_> { + StringOrArrayIter { + inner: self, + index: 0, + } + } + + /// Returns the first string, or None if empty + pub fn first(&self) -> Option<&str> { + match self { + StringOrArray::String(s) => { + if s.is_empty() { + None + } else { + Some(s) + } + } + StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()), + } + } } +/// Iterator over StringOrArray that yields string references without cloning +pub struct StringOrArrayIter<'a> { + inner: &'a StringOrArray, + index: usize, +} + +impl<'a> Iterator for StringOrArrayIter<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + match self.inner { + StringOrArray::String(s) => { + if self.index == 0 { + self.index = 1; + Some(s.as_str()) + } else { + None + } + } + StringOrArray::Array(arr) => { + if self.index < arr.len() { + let item = &arr[self.index]; + self.index += 1; + Some(item.as_str()) + } else { + None + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = match self.inner { + StringOrArray::String(_) => 1 - self.index, + StringOrArray::Array(arr) => arr.len() - self.index, + }; + (remaining, Some(remaining)) + } +} + +impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {} + /// Validates stop sequences (max 4, non-empty strings) /// Used by both ChatCompletionRequest and ResponsesRequest pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { From e1590a26c6c0799053a32e4b9c668e542e908c58 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:40:53 -0800 Subject: [PATCH 3/9] perf: optimize middleware latency recording and WASM string caching - Replace format!("{:?}", latency) with latency.as_micros() to avoid string allocation on every response - Pre-compute method, path, and query strings before WASM module loop to avoid repeated allocations per module iteration --- sgl-model-gateway/src/middleware.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sgl-model-gateway/src/middleware.rs b/sgl-model-gateway/src/middleware.rs index 0f56f17be3b6..d5c042af5fb2 100644 --- a/sgl-model-gateway/src/middleware.rs +++ b/sgl-model-gateway/src/middleware.rs @@ -266,7 +266,8 @@ impl OnResponse for ResponseLogger { // Record these in the span for structured logging/observability tools span.record("status_code", status.as_u16()); - span.record("latency", format!("{:?}", latency)); + // Use microseconds as integer to avoid format! string allocation + span.record("latency", latency.as_micros() as u64); // Log the response completion let _enter = span.enter(); @@ -632,13 +633,18 @@ pub async fn wasm_middleware( // Process each OnRequest module let mut modified_body = body_bytes; + // Pre-compute strings once before the loop to avoid repeated allocations + let method_str = method.to_string(); + let path_str = uri.path().to_string(); + let query_str = uri.query().unwrap_or("").to_string(); + for module in modules_on_request { // Build WebAssembly request from collected data let wasm_headers = build_wasm_headers_from_axum_headers(&headers); let wasm_request = WasmRequest { - method: method.to_string(), - path: uri.path().to_string(), - query: uri.query().unwrap_or("").to_string(), + method: method_str.clone(), + path: path_str.clone(), + query: query_str.clone(), headers: wasm_headers, body: modified_body.clone(), request_id: request_id.clone(), From b6348ad1eddec76d3e17da9bcf2fe56de57db858 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:41:24 -0800 Subject: [PATCH 4/9] perf: optimize worker and worker registry CPU overhead - RuntimeType::from_str: Use eq_ignore_ascii_case() instead of to_lowercase() to avoid String allocation per parse - worker_to_info: Use Display format (.to_string()) instead of Debug format (format!("{:?}"...)) for connection_mode - avoids unnecessary formatting overhead - WorkerRegistry::stats(): Iterate DashMap directly instead of calling get_all() which clones ALL workers into a Vec first. This avoids N Arc clones per stats() call where N = worker count --- sgl-model-gateway/src/core/worker.rs | 16 ++++++++++------ sgl-model-gateway/src/core/worker_registry.rs | 4 +++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/sgl-model-gateway/src/core/worker.rs b/sgl-model-gateway/src/core/worker.rs index 05307428605c..088c8116005d 100644 --- a/sgl-model-gateway/src/core/worker.rs +++ b/sgl-model-gateway/src/core/worker.rs @@ -355,11 +355,15 @@ impl std::str::FromStr for RuntimeType { type Err = String; fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "sglang" => Ok(RuntimeType::Sglang), - "vllm" => Ok(RuntimeType::Vllm), - "external" => Ok(RuntimeType::External), - _ => Err(format!("Unknown runtime type: {}", s)), + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation + if s.eq_ignore_ascii_case("sglang") { + Ok(RuntimeType::Sglang) + } else if s.eq_ignore_ascii_case("vllm") { + Ok(RuntimeType::Vllm) + } else if s.eq_ignore_ascii_case("external") { + Ok(RuntimeType::External) + } else { + Err(format!("Unknown runtime type: {}", s)) } } } @@ -1111,7 +1115,7 @@ pub fn worker_to_info(worker: &Arc) -> WorkerInfo { worker_type: worker_type_str.to_string(), is_healthy: worker.is_healthy(), load: worker.load(), - connection_mode: format!("{:?}", worker.connection_mode()), + connection_mode: worker.connection_mode().to_string(), runtime_type, tokenizer_path: worker.tokenizer_path(model_id).map(String::from), reasoning_parser: worker.reasoning_parser(model_id).map(String::from), diff --git a/sgl-model-gateway/src/core/worker_registry.rs b/sgl-model-gateway/src/core/worker_registry.rs index 1bf9f2bd1e5c..751dd842d17c 100644 --- a/sgl-model-gateway/src/core/worker_registry.rs +++ b/sgl-model-gateway/src/core/worker_registry.rs @@ -350,7 +350,9 @@ impl WorkerRegistry { let mut prefill_count = 0; let mut decode_count = 0; - for worker in self.get_all() { + // Iterate DashMap directly to avoid cloning all workers via get_all() + for entry in self.workers.iter() { + let worker = entry.value(); if worker.is_healthy() { healthy_count += 1; } From fc9cd4d7d547eb41d8d1604c4078af712183eade Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:41:54 -0800 Subject: [PATCH 5/9] perf: additional worker layer optimizations - worker_to_info(): Cache worker_type(), connection_mode(), url(), and model_id() to avoid redundant clones and allocations. Previously called worker_type() twice, connection_mode() twice, and url() twice per invocation. - normalised_url(): Remove redundant contains("@") check before rfind('@'). The rfind already returns None if '@' is not found, so the contains() call was an extra O(n) scan. - WorkerRegistry::stats(): Count models by iterating directly instead of calling get_models().len() which allocates a Vec just to count elements. --- sgl-model-gateway/src/core/worker.rs | 47 ++++++++++--------- sgl-model-gateway/src/core/worker_registry.rs | 7 ++- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sgl-model-gateway/src/core/worker.rs b/sgl-model-gateway/src/core/worker.rs index 088c8116005d..77c260056fe9 100644 --- a/sgl-model-gateway/src/core/worker.rs +++ b/sgl-model-gateway/src/core/worker.rs @@ -520,22 +520,18 @@ impl fmt::Debug for BasicWorker { impl BasicWorker { pub fn normalised_url(&self) -> WorkerResult<&str> { - if self.url().contains("@") { - // Use rfind to split from the right, handling IPv6 addresses with brackets - // e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0" - if let Some(at_pos) = self.url().rfind('@') { - let base_url = &self.url()[..at_pos]; - let rank_str = &self.url()[at_pos + 1..]; - - // Validate that the rank part is actually a number - match rank_str.parse::() { - Ok(_) => Ok(base_url), - Err(_) => { - // The '@' is not a DP rank separator, return full URL - Ok(self.url()) - } - } + // Use rfind directly - no need for redundant contains() check + // rfind already returns None if '@' is not found + // e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0" + if let Some(at_pos) = self.url().rfind('@') { + let base_url = &self.url()[..at_pos]; + let rank_str = &self.url()[at_pos + 1..]; + + // Validate that the rank part is actually a number + if rank_str.parse::().is_ok() { + Ok(base_url) } else { + // The '@' is not a DP rank separator, return full URL Ok(self.url()) } } else { @@ -1089,33 +1085,38 @@ impl HealthChecker { /// Helper to convert Worker trait object to WorkerInfo struct pub fn worker_to_info(worker: &Arc) -> WorkerInfo { - let worker_type_str = match worker.worker_type() { + // Cache values that are used multiple times to avoid redundant clones/allocations + let worker_type = worker.worker_type(); + let connection_mode = worker.connection_mode(); + let url = worker.url(); + let model_id = worker.model_id(); + + let worker_type_str = match &worker_type { WorkerType::Regular => "regular", WorkerType::Prefill { .. } => "prefill", WorkerType::Decode => "decode", }; - let bootstrap_port = match worker.worker_type() { - WorkerType::Prefill { bootstrap_port } => bootstrap_port, + let bootstrap_port = match &worker_type { + WorkerType::Prefill { bootstrap_port } => *bootstrap_port, _ => None, }; - let runtime_type = match worker.connection_mode() { + let runtime_type = match &connection_mode { ConnectionMode::Grpc { .. } => Some(worker.metadata().runtime_type.to_string()), ConnectionMode::Http => None, }; - let model_id = worker.model_id(); WorkerInfo { - id: worker.url().to_string(), - url: worker.url().to_string(), + id: url.to_string(), + url: url.to_string(), model_id: model_id.to_string(), priority: worker.priority(), cost: worker.cost(), worker_type: worker_type_str.to_string(), is_healthy: worker.is_healthy(), load: worker.load(), - connection_mode: worker.connection_mode().to_string(), + connection_mode: connection_mode.to_string(), runtime_type, tokenizer_path: worker.tokenizer_path(model_id).map(String::from), reasoning_parser: worker.reasoning_parser(model_id).map(String::from), diff --git a/sgl-model-gateway/src/core/worker_registry.rs b/sgl-model-gateway/src/core/worker_registry.rs index 751dd842d17c..6a3d9a7fec52 100644 --- a/sgl-model-gateway/src/core/worker_registry.rs +++ b/sgl-model-gateway/src/core/worker_registry.rs @@ -342,7 +342,12 @@ impl WorkerRegistry { /// Get worker statistics pub fn stats(&self) -> WorkerRegistryStats { let total_workers = self.workers.len(); - let total_models = self.get_models().len(); + // Count models directly instead of allocating Vec via get_models() + let total_models = self + .model_workers + .iter() + .filter(|entry| !entry.value().is_empty()) + .count(); let mut healthy_count = 0; let mut total_load = 0; From 30dcd239777b75fdedde25322db8f4bfeac41ee1 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:42:15 -0800 Subject: [PATCH 6/9] perf: fix double lock acquisition in PolicyRegistry In get_all_power_of_two_policies(), the code was acquiring the prefill_policy read lock twice: 1. First at line 284 to check if prefill policy is power_of_two 2. Again at lines 293-298 inside the decode policy check This caused unnecessary lock contention. Fixed by caching both prefill_policy and decode_policy at the start and reusing the cached values for subsequent checks. --- sgl-model-gateway/src/policies/registry.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sgl-model-gateway/src/policies/registry.rs b/sgl-model-gateway/src/policies/registry.rs index 5fe5b24ae7da..df2c15bb12cd 100644 --- a/sgl-model-gateway/src/policies/registry.rs +++ b/sgl-model-gateway/src/policies/registry.rs @@ -281,19 +281,20 @@ impl PolicyRegistry { power_of_two_policies.push(Arc::clone(&self.default_policy)); } - if let Some(ref policy) = *self.prefill_policy.read().unwrap() { + // Cache prefill and decode policies to avoid double-locking prefill_policy + let prefill_policy_opt = self.prefill_policy.read().unwrap().clone(); + let decode_policy_opt = self.decode_policy.read().unwrap().clone(); + + if let Some(ref policy) = prefill_policy_opt { if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) { power_of_two_policies.push(Arc::clone(policy)); } } - if let Some(ref policy) = *self.decode_policy.read().unwrap() { + if let Some(ref policy) = decode_policy_opt { if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) - && !self - .prefill_policy - .read() - .unwrap() + && !prefill_policy_opt .as_ref() .is_some_and(|p| Arc::ptr_eq(p, policy)) { From f2a42c832ae0fd775a9ee6463e9403869f713675 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:43:03 -0800 Subject: [PATCH 7/9] perf: optimize load balancing policies CPU overhead --- sgl-model-gateway/src/policies/bucket.rs | 8 ++++---- sgl-model-gateway/src/policies/cache_aware.rs | 10 ++++++---- sgl-model-gateway/src/policies/power_of_two.rs | 14 +++++--------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/sgl-model-gateway/src/policies/bucket.rs b/sgl-model-gateway/src/policies/bucket.rs index e16daac70943..56d8401df5c5 100644 --- a/sgl-model-gateway/src/policies/bucket.rs +++ b/sgl-model-gateway/src/policies/bucket.rs @@ -7,7 +7,7 @@ use std::{ use dashmap::DashMap; use rand::Rng; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use uuid::Uuid; use super::{get_healthy_worker_indices, BucketConfig, LoadBalancingPolicy}; @@ -259,14 +259,14 @@ impl LoadBalancingPolicy for BucketPolicy { let rel_threshold = self.config.balance_rel_threshold * min_load as f32; let is_imbalanced = abs_diff > self.config.balance_abs_threshold && max_load as f32 > rel_threshold; - info!( + debug!( "Current PD instance status | is_imbalanced={}", is_imbalanced ); let mut rng = rand::rng(); let prefill_url = if is_imbalanced { - info!("select prefill instance by Load Balance policy"); + debug!("select prefill instance by Load Balance policy"); let min_url = chars_per_url_snapshot .iter() .min_by_key(|(_, &chars)| chars) @@ -279,7 +279,7 @@ impl LoadBalancingPolicy for BucketPolicy { }); min_url } else { - info!("select prefill instance by Bucket policy"); + debug!("select prefill instance by Bucket policy"); match choiced_url { Some(url) if !url.is_empty() => url, _ => { diff --git a/sgl-model-gateway/src/policies/cache_aware.rs b/sgl-model-gateway/src/policies/cache_aware.rs index 781f413e7ba8..0b5400148df8 100644 --- a/sgl-model-gateway/src/policies/cache_aware.rs +++ b/sgl-model-gateway/src/policies/cache_aware.rs @@ -233,10 +233,12 @@ impl LoadBalancingPolicy for CacheAwarePolicy { first_model }; - // Get current load statistics - let loads: Vec = workers.iter().map(|w| w.load()).collect(); - let max_load = *loads.iter().max().unwrap_or(&0); - let min_load = *loads.iter().min().unwrap_or(&0); + // Get current load statistics - compute min/max in single pass without allocation + let (min_load, max_load) = workers.iter().fold((usize::MAX, 0usize), |(min, max), w| { + let load = w.load(); + (min.min(load), max.max(load)) + }); + let min_load = if min_load == usize::MAX { 0 } else { min_load }; // Check if load is imbalanced let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold diff --git a/sgl-model-gateway/src/policies/power_of_two.rs b/sgl-model-gateway/src/policies/power_of_two.rs index 851c83cb15bb..32c053851dd6 100644 --- a/sgl-model-gateway/src/policies/power_of_two.rs +++ b/sgl-model-gateway/src/policies/power_of_two.rs @@ -6,7 +6,7 @@ use std::{ }; use rand::Rng; -use tracing::info; +use tracing::debug; use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use crate::{core::Worker, observability::metrics::RouterMetrics}; @@ -57,15 +57,11 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { return Some(healthy_indices[0]); } - // Select two random workers + // Select two random workers - use offset to guarantee different selection in O(1) let mut rng = rand::rng(); let idx1 = rng.random_range(0..healthy_indices.len()); - let mut idx2 = rng.random_range(0..healthy_indices.len()); - - // Ensure we pick two different workers - while idx2 == idx1 { - idx2 = rng.random_range(0..healthy_indices.len()); - } + // Pick idx2 from remaining indices: offset by 1 + random from (len-1) to guarantee different + let idx2 = (idx1 + 1 + rng.random_range(0..healthy_indices.len() - 1)) % healthy_indices.len(); let worker_idx1 = healthy_indices[idx1]; let worker_idx2 = healthy_indices[idx2]; @@ -81,7 +77,7 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { worker_idx2 }; - info!( + debug!( "Power-of-two selection: {}={} vs {}={} -> selected {}", workers[worker_idx1].url(), load1, From c9dfe3bbf4eecd376498cb4b4cfe92ac7fa6dbb3 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:44:51 -0800 Subject: [PATCH 8/9] [model-gateway] fmy --- sgl-model-gateway/src/policies/power_of_two.rs | 3 ++- sgl-model-gateway/src/protocols/chat.rs | 6 +++--- sgl-model-gateway/src/routers/http/router.rs | 5 ++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sgl-model-gateway/src/policies/power_of_two.rs b/sgl-model-gateway/src/policies/power_of_two.rs index 32c053851dd6..18a2cdca2a17 100644 --- a/sgl-model-gateway/src/policies/power_of_two.rs +++ b/sgl-model-gateway/src/policies/power_of_two.rs @@ -61,7 +61,8 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { let mut rng = rand::rng(); let idx1 = rng.random_range(0..healthy_indices.len()); // Pick idx2 from remaining indices: offset by 1 + random from (len-1) to guarantee different - let idx2 = (idx1 + 1 + rng.random_range(0..healthy_indices.len() - 1)) % healthy_indices.len(); + let idx2 = + (idx1 + 1 + rng.random_range(0..healthy_indices.len() - 1)) % healthy_indices.len(); let worker_idx1 = healthy_indices[idx1]; let worker_idx2 = healthy_indices[idx2]; diff --git a/sgl-model-gateway/src/protocols/chat.rs b/sgl-model-gateway/src/protocols/chat.rs index d5e745fe6690..c1f54828e378 100644 --- a/sgl-model-gateway/src/protocols/chat.rs +++ b/sgl-model-gateway/src/protocols/chat.rs @@ -122,9 +122,9 @@ impl MessageContent { pub fn has_text(&self) -> bool { match self { MessageContent::Text(text) => !text.is_empty(), - MessageContent::Parts(parts) => parts.iter().any(|part| { - matches!(part, ContentPart::Text { text } if !text.is_empty()) - }), + MessageContent::Parts(parts) => parts + .iter() + .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())), } } } diff --git a/sgl-model-gateway/src/routers/http/router.rs b/sgl-model-gateway/src/routers/http/router.rs index ffcefde10adf..b7b6ee8c7696 100644 --- a/sgl-model-gateway/src/routers/http/router.rs +++ b/sgl-model-gateway/src/routers/http/router.rs @@ -585,7 +585,10 @@ impl Router { if memmem::find(&bytes, b"data: [DONE]").is_some() { if let Some(ref w) = stream_worker { w.decrement_load(); - RouterMetrics::set_running_requests(&worker_url_owned, w.load()); + RouterMetrics::set_running_requests( + &worker_url_owned, + w.load(), + ); decremented = true; } } From 99a00c06ca0f2df1cea14208bc872696a2e8791b Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 8 Dec 2025 08:49:15 -0800 Subject: [PATCH 9/9] trigger build