diff --git a/sgl-model-gateway/Cargo.toml b/sgl-model-gateway/Cargo.toml index 3f27186458c9..323c654b8f5f 100644 --- a/sgl-model-gateway/Cargo.toml +++ b/sgl-model-gateway/Cargo.toml @@ -70,6 +70,7 @@ parking_lot = "0.12.4" rayon = "1.10" thiserror = "2.0.12" regex = "1.10" +memchr = "2.7" # SIMD-optimized byte pattern searching url = "2.5.4" validator = { version = "0.20.0", features = ["derive"] } tokio-stream = { version = "0.1", features = ["sync"] } diff --git a/sgl-model-gateway/src/core/worker.rs b/sgl-model-gateway/src/core/worker.rs index 05307428605c..77c260056fe9 100644 --- a/sgl-model-gateway/src/core/worker.rs +++ b/sgl-model-gateway/src/core/worker.rs @@ -355,11 +355,15 @@ impl std::str::FromStr for RuntimeType { type Err = String; fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "sglang" => Ok(RuntimeType::Sglang), - "vllm" => Ok(RuntimeType::Vllm), - "external" => Ok(RuntimeType::External), - _ => Err(format!("Unknown runtime type: {}", s)), + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation + if s.eq_ignore_ascii_case("sglang") { + Ok(RuntimeType::Sglang) + } else if s.eq_ignore_ascii_case("vllm") { + Ok(RuntimeType::Vllm) + } else if s.eq_ignore_ascii_case("external") { + Ok(RuntimeType::External) + } else { + Err(format!("Unknown runtime type: {}", s)) } } } @@ -516,22 +520,18 @@ impl fmt::Debug for BasicWorker { impl BasicWorker { pub fn normalised_url(&self) -> WorkerResult<&str> { - if self.url().contains("@") { - // Use rfind to split from the right, handling IPv6 addresses with brackets - // e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0" - if let Some(at_pos) = self.url().rfind('@') { - let base_url = &self.url()[..at_pos]; - let rank_str = &self.url()[at_pos + 1..]; - - // Validate that the rank part is actually a number - match rank_str.parse::() { - Ok(_) => Ok(base_url), - Err(_) => { - // The '@' is not a DP rank separator, return full URL - Ok(self.url()) - } - } + // Use rfind directly - no need for redundant contains() check + // rfind already returns None if '@' is not found + // e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0" + if let Some(at_pos) = self.url().rfind('@') { + let base_url = &self.url()[..at_pos]; + let rank_str = &self.url()[at_pos + 1..]; + + // Validate that the rank part is actually a number + if rank_str.parse::().is_ok() { + Ok(base_url) } else { + // The '@' is not a DP rank separator, return full URL Ok(self.url()) } } else { @@ -1085,33 +1085,38 @@ impl HealthChecker { /// Helper to convert Worker trait object to WorkerInfo struct pub fn worker_to_info(worker: &Arc) -> WorkerInfo { - let worker_type_str = match worker.worker_type() { + // Cache values that are used multiple times to avoid redundant clones/allocations + let worker_type = worker.worker_type(); + let connection_mode = worker.connection_mode(); + let url = worker.url(); + let model_id = worker.model_id(); + + let worker_type_str = match &worker_type { WorkerType::Regular => "regular", WorkerType::Prefill { .. } => "prefill", WorkerType::Decode => "decode", }; - let bootstrap_port = match worker.worker_type() { - WorkerType::Prefill { bootstrap_port } => bootstrap_port, + let bootstrap_port = match &worker_type { + WorkerType::Prefill { bootstrap_port } => *bootstrap_port, _ => None, }; - let runtime_type = match worker.connection_mode() { + let runtime_type = match &connection_mode { ConnectionMode::Grpc { .. } => Some(worker.metadata().runtime_type.to_string()), ConnectionMode::Http => None, }; - let model_id = worker.model_id(); WorkerInfo { - id: worker.url().to_string(), - url: worker.url().to_string(), + id: url.to_string(), + url: url.to_string(), model_id: model_id.to_string(), priority: worker.priority(), cost: worker.cost(), worker_type: worker_type_str.to_string(), is_healthy: worker.is_healthy(), load: worker.load(), - connection_mode: format!("{:?}", worker.connection_mode()), + connection_mode: connection_mode.to_string(), runtime_type, tokenizer_path: worker.tokenizer_path(model_id).map(String::from), reasoning_parser: worker.reasoning_parser(model_id).map(String::from), diff --git a/sgl-model-gateway/src/core/worker_registry.rs b/sgl-model-gateway/src/core/worker_registry.rs index 1bf9f2bd1e5c..6a3d9a7fec52 100644 --- a/sgl-model-gateway/src/core/worker_registry.rs +++ b/sgl-model-gateway/src/core/worker_registry.rs @@ -342,7 +342,12 @@ impl WorkerRegistry { /// Get worker statistics pub fn stats(&self) -> WorkerRegistryStats { let total_workers = self.workers.len(); - let total_models = self.get_models().len(); + // Count models directly instead of allocating Vec via get_models() + let total_models = self + .model_workers + .iter() + .filter(|entry| !entry.value().is_empty()) + .count(); let mut healthy_count = 0; let mut total_load = 0; @@ -350,7 +355,9 @@ impl WorkerRegistry { let mut prefill_count = 0; let mut decode_count = 0; - for worker in self.get_all() { + // Iterate DashMap directly to avoid cloning all workers via get_all() + for entry in self.workers.iter() { + let worker = entry.value(); if worker.is_healthy() { healthy_count += 1; } diff --git a/sgl-model-gateway/src/middleware.rs b/sgl-model-gateway/src/middleware.rs index fbad1188adc5..d5c042af5fb2 100644 --- a/sgl-model-gateway/src/middleware.rs +++ b/sgl-model-gateway/src/middleware.rs @@ -79,6 +79,9 @@ pub async fn auth_middleware( Ok(next.run(request).await) } +/// Alphanumeric characters for request ID generation (as bytes for O(1) indexing) +const REQUEST_ID_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + /// Generate OpenAI-compatible request ID based on endpoint fn generate_request_id(path: &str) -> String { let prefix = if path.contains("/chat/completions") { @@ -94,12 +97,12 @@ fn generate_request_id(path: &str) -> String { }; // Generate a random string similar to OpenAI's format - let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + // Use byte array indexing (O(1)) instead of chars().nth() (O(n)) let mut rng = rand::rng(); let random_part: String = (0..24) .map(|_| { - let idx = rng.random_range(0..chars.len()); - chars.chars().nth(idx).unwrap() + let idx = rng.random_range(0..REQUEST_ID_CHARS.len()); + REQUEST_ID_CHARS[idx] as char }) .collect(); @@ -263,7 +266,8 @@ impl OnResponse for ResponseLogger { // Record these in the span for structured logging/observability tools span.record("status_code", status.as_u16()); - span.record("latency", format!("{:?}", latency)); + // Use microseconds as integer to avoid format! string allocation + span.record("latency", latency.as_micros() as u64); // Log the response completion let _enter = span.enter(); @@ -629,13 +633,18 @@ pub async fn wasm_middleware( // Process each OnRequest module let mut modified_body = body_bytes; + // Pre-compute strings once before the loop to avoid repeated allocations + let method_str = method.to_string(); + let path_str = uri.path().to_string(); + let query_str = uri.query().unwrap_or("").to_string(); + for module in modules_on_request { // Build WebAssembly request from collected data let wasm_headers = build_wasm_headers_from_axum_headers(&headers); let wasm_request = WasmRequest { - method: method.to_string(), - path: uri.path().to_string(), - query: uri.query().unwrap_or("").to_string(), + method: method_str.clone(), + path: path_str.clone(), + query: query_str.clone(), headers: wasm_headers, body: modified_body.clone(), request_id: request_id.clone(), diff --git a/sgl-model-gateway/src/policies/bucket.rs b/sgl-model-gateway/src/policies/bucket.rs index e16daac70943..56d8401df5c5 100644 --- a/sgl-model-gateway/src/policies/bucket.rs +++ b/sgl-model-gateway/src/policies/bucket.rs @@ -7,7 +7,7 @@ use std::{ use dashmap::DashMap; use rand::Rng; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use uuid::Uuid; use super::{get_healthy_worker_indices, BucketConfig, LoadBalancingPolicy}; @@ -259,14 +259,14 @@ impl LoadBalancingPolicy for BucketPolicy { let rel_threshold = self.config.balance_rel_threshold * min_load as f32; let is_imbalanced = abs_diff > self.config.balance_abs_threshold && max_load as f32 > rel_threshold; - info!( + debug!( "Current PD instance status | is_imbalanced={}", is_imbalanced ); let mut rng = rand::rng(); let prefill_url = if is_imbalanced { - info!("select prefill instance by Load Balance policy"); + debug!("select prefill instance by Load Balance policy"); let min_url = chars_per_url_snapshot .iter() .min_by_key(|(_, &chars)| chars) @@ -279,7 +279,7 @@ impl LoadBalancingPolicy for BucketPolicy { }); min_url } else { - info!("select prefill instance by Bucket policy"); + debug!("select prefill instance by Bucket policy"); match choiced_url { Some(url) if !url.is_empty() => url, _ => { diff --git a/sgl-model-gateway/src/policies/cache_aware.rs b/sgl-model-gateway/src/policies/cache_aware.rs index 781f413e7ba8..0b5400148df8 100644 --- a/sgl-model-gateway/src/policies/cache_aware.rs +++ b/sgl-model-gateway/src/policies/cache_aware.rs @@ -233,10 +233,12 @@ impl LoadBalancingPolicy for CacheAwarePolicy { first_model }; - // Get current load statistics - let loads: Vec = workers.iter().map(|w| w.load()).collect(); - let max_load = *loads.iter().max().unwrap_or(&0); - let min_load = *loads.iter().min().unwrap_or(&0); + // Get current load statistics - compute min/max in single pass without allocation + let (min_load, max_load) = workers.iter().fold((usize::MAX, 0usize), |(min, max), w| { + let load = w.load(); + (min.min(load), max.max(load)) + }); + let min_load = if min_load == usize::MAX { 0 } else { min_load }; // Check if load is imbalanced let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold diff --git a/sgl-model-gateway/src/policies/power_of_two.rs b/sgl-model-gateway/src/policies/power_of_two.rs index 851c83cb15bb..18a2cdca2a17 100644 --- a/sgl-model-gateway/src/policies/power_of_two.rs +++ b/sgl-model-gateway/src/policies/power_of_two.rs @@ -6,7 +6,7 @@ use std::{ }; use rand::Rng; -use tracing::info; +use tracing::debug; use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use crate::{core::Worker, observability::metrics::RouterMetrics}; @@ -57,15 +57,12 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { return Some(healthy_indices[0]); } - // Select two random workers + // Select two random workers - use offset to guarantee different selection in O(1) let mut rng = rand::rng(); let idx1 = rng.random_range(0..healthy_indices.len()); - let mut idx2 = rng.random_range(0..healthy_indices.len()); - - // Ensure we pick two different workers - while idx2 == idx1 { - idx2 = rng.random_range(0..healthy_indices.len()); - } + // Pick idx2 from remaining indices: offset by 1 + random from (len-1) to guarantee different + let idx2 = + (idx1 + 1 + rng.random_range(0..healthy_indices.len() - 1)) % healthy_indices.len(); let worker_idx1 = healthy_indices[idx1]; let worker_idx2 = healthy_indices[idx2]; @@ -81,7 +78,7 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { worker_idx2 }; - info!( + debug!( "Power-of-two selection: {}={} vs {}={} -> selected {}", workers[worker_idx1].url(), load1, diff --git a/sgl-model-gateway/src/policies/registry.rs b/sgl-model-gateway/src/policies/registry.rs index 5fe5b24ae7da..df2c15bb12cd 100644 --- a/sgl-model-gateway/src/policies/registry.rs +++ b/sgl-model-gateway/src/policies/registry.rs @@ -281,19 +281,20 @@ impl PolicyRegistry { power_of_two_policies.push(Arc::clone(&self.default_policy)); } - if let Some(ref policy) = *self.prefill_policy.read().unwrap() { + // Cache prefill and decode policies to avoid double-locking prefill_policy + let prefill_policy_opt = self.prefill_policy.read().unwrap().clone(); + let decode_policy_opt = self.decode_policy.read().unwrap().clone(); + + if let Some(ref policy) = prefill_policy_opt { if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) { power_of_two_policies.push(Arc::clone(policy)); } } - if let Some(ref policy) = *self.decode_policy.read().unwrap() { + if let Some(ref policy) = decode_policy_opt { if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) - && !self - .prefill_policy - .read() - .unwrap() + && !prefill_policy_opt .as_ref() .is_some_and(|p| Arc::ptr_eq(p, policy)) { diff --git a/sgl-model-gateway/src/protocols/chat.rs b/sgl-model-gateway/src/protocols/chat.rs index 6c3f2699a8c0..c1f54828e378 100644 --- a/sgl-model-gateway/src/protocols/chat.rs +++ b/sgl-model-gateway/src/protocols/chat.rs @@ -66,21 +66,67 @@ pub enum MessageContent { } impl MessageContent { + /// Returns the text content, cloning only when necessary. + /// For simple text, returns a clone of the string. + /// For parts, concatenates text parts with spaces. pub fn to_simple_string(&self) -> String { match self { MessageContent::Text(text) => text.clone(), MessageContent::Parts(parts) => { - let texts: Vec = parts + // Pre-count text parts to avoid intermediate Vec allocation + let text_parts: Vec<&str> = parts .iter() .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), + ContentPart::Text { text } => Some(text.as_str()), _ => None, }) .collect(); - texts.join(" ") + text_parts.join(" ") } } } + + /// Appends text content directly to a buffer, avoiding intermediate allocations. + /// Returns true if any content was appended. + #[inline] + pub fn append_text_to(&self, buffer: &mut String) -> bool { + match self { + MessageContent::Text(text) => { + if !text.is_empty() { + buffer.push_str(text); + true + } else { + false + } + } + MessageContent::Parts(parts) => { + let mut appended = false; + for part in parts { + if let ContentPart::Text { text } = part { + if !text.is_empty() { + if appended { + buffer.push(' '); + } + buffer.push_str(text); + appended = true; + } + } + } + appended + } + } + } + + /// Returns true if this content contains any non-empty text. + #[inline] + pub fn has_text(&self) -> bool { + match self { + MessageContent::Text(text) => !text.is_empty(), + MessageContent::Parts(parts) => parts + .iter() + .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())), + } + } } // ============================================================================ @@ -581,33 +627,66 @@ impl GenerationRequest for ChatCompletionRequest { fn extract_text_for_routing(&self) -> String { // Extract text from messages for routing decisions - self.messages - .iter() - .filter_map(|msg| match msg { - ChatMessage::System { content, .. } => Some(content.to_simple_string()), - ChatMessage::User { content, .. } => Some(content.to_simple_string()), + // Use a single buffer to avoid intermediate Vec allocations + let mut buffer = String::new(); + let mut has_content = false; + + for msg in &self.messages { + match msg { + ChatMessage::System { content, .. } | ChatMessage::User { content, .. } => { + if has_content && content.has_text() { + buffer.push(' '); + } + if content.append_text_to(&mut buffer) { + has_content = true; + } + } ChatMessage::Assistant { content, reasoning_content, .. } => { - // Combine content and reasoning content for routing decisions - let main_content = content - .as_ref() - .map(|c| c.to_simple_string()) - .unwrap_or_default(); - let reasoning = reasoning_content.clone().unwrap_or_default(); - if main_content.is_empty() && reasoning.is_empty() { - None - } else { - Some(format!("{} {}", main_content, reasoning).trim().to_string()) + // Append main content + if let Some(c) = content { + if has_content && c.has_text() { + buffer.push(' '); + } + if c.append_text_to(&mut buffer) { + has_content = true; + } + } + // Append reasoning content + if let Some(reasoning) = reasoning_content { + if !reasoning.is_empty() { + if has_content { + buffer.push(' '); + } + buffer.push_str(reasoning); + has_content = true; + } } } - ChatMessage::Tool { content, .. } => Some(content.to_simple_string()), - ChatMessage::Function { content, .. } => Some(content.clone()), - }) - .collect::>() - .join(" ") + ChatMessage::Tool { content, .. } => { + if has_content && content.has_text() { + buffer.push(' '); + } + if content.append_text_to(&mut buffer) { + has_content = true; + } + } + ChatMessage::Function { content, .. } => { + if !content.is_empty() { + if has_content { + buffer.push(' '); + } + buffer.push_str(content); + has_content = true; + } + } + } + } + + buffer } } diff --git a/sgl-model-gateway/src/protocols/common.rs b/sgl-model-gateway/src/protocols/common.rs index ac51e85ea6a3..4e3ae494bd5e 100644 --- a/sgl-model-gateway/src/protocols/common.rs +++ b/sgl-model-gateway/src/protocols/common.rs @@ -65,15 +65,80 @@ impl StringOrArray { } } - /// Convert to a vector of strings + /// Convert to a vector of strings (clones the data) pub fn to_vec(&self) -> Vec { match self { StringOrArray::String(s) => vec![s.clone()], StringOrArray::Array(arr) => arr.clone(), } } + + /// Returns an iterator over string references without cloning. + /// Use this instead of `to_vec()` when you only need to iterate. + pub fn iter(&self) -> StringOrArrayIter<'_> { + StringOrArrayIter { + inner: self, + index: 0, + } + } + + /// Returns the first string, or None if empty + pub fn first(&self) -> Option<&str> { + match self { + StringOrArray::String(s) => { + if s.is_empty() { + None + } else { + Some(s) + } + } + StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()), + } + } } +/// Iterator over StringOrArray that yields string references without cloning +pub struct StringOrArrayIter<'a> { + inner: &'a StringOrArray, + index: usize, +} + +impl<'a> Iterator for StringOrArrayIter<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + match self.inner { + StringOrArray::String(s) => { + if self.index == 0 { + self.index = 1; + Some(s.as_str()) + } else { + None + } + } + StringOrArray::Array(arr) => { + if self.index < arr.len() { + let item = &arr[self.index]; + self.index += 1; + Some(item.as_str()) + } else { + None + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = match self.inner { + StringOrArray::String(_) => 1 - self.index, + StringOrArray::Array(arr) => arr.len() - self.index, + }; + (remaining, Some(remaining)) + } +} + +impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {} + /// Validates stop sequences (max 4, non-empty strings) /// Used by both ChatCompletionRequest and ResponsesRequest pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { diff --git a/sgl-model-gateway/src/routers/header_utils.rs b/sgl-model-gateway/src/routers/header_utils.rs index f705a3fc97ed..b7ca861f6a49 100644 --- a/sgl-model-gateway/src/routers/header_utils.rs +++ b/sgl-model-gateway/src/routers/header_utils.rs @@ -26,8 +26,8 @@ pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap { for (name, value) in reqwest_headers.iter() { // Skip hop-by-hop headers that shouldn't be forwarded - let name_str = name.as_str().to_lowercase(); - if should_forward_header(&name_str) { + // Use eq_ignore_ascii_case to avoid string allocation + if should_forward_header_no_alloc(name.as_str()) { // The original name and value are already valid, so we can just clone them headers.insert(name.clone(), value.clone()); } @@ -36,22 +36,20 @@ pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap { headers } -/// Determine if a header should be forwarded from backend to client -fn should_forward_header(name: &str) -> bool { +/// Determine if a header should be forwarded without allocating (case-insensitive) +fn should_forward_header_no_alloc(name: &str) -> bool { // List of headers that should NOT be forwarded (hop-by-hop headers) - !matches!( - name, - "connection" | - "keep-alive" | - "proxy-authenticate" | - "proxy-authorization" | - "te" | - "trailers" | - "transfer-encoding" | - "upgrade" | - "content-encoding" | // Let axum/hyper handle encoding - "host" // Should not forward the backend's host header - ) + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation + !(name.eq_ignore_ascii_case("connection") + || name.eq_ignore_ascii_case("keep-alive") + || name.eq_ignore_ascii_case("proxy-authenticate") + || name.eq_ignore_ascii_case("proxy-authorization") + || name.eq_ignore_ascii_case("te") + || name.eq_ignore_ascii_case("trailers") + || name.eq_ignore_ascii_case("transfer-encoding") + || name.eq_ignore_ascii_case("upgrade") + || name.eq_ignore_ascii_case("content-encoding") + || name.eq_ignore_ascii_case("host")) } /// Apply headers to a reqwest request builder, filtering out headers that shouldn't be forwarded @@ -70,24 +68,27 @@ pub fn apply_request_headers( } // Forward other headers, filtering out problematic ones + // Use eq_ignore_ascii_case to avoid to_lowercase() allocation per header for (key, value) in headers.iter() { - let key_str = key.as_str().to_lowercase(); + let key_str = key.as_str(); // Skip headers that: // - Are set automatically by reqwest (content-type, content-length for POST/PUT) // - We already handled (authorization) // - Are hop-by-hop headers (connection, transfer-encoding) // - Should not be forwarded (host) - let should_skip = key_str == "authorization" || // Already handled above - key_str == "host" || - key_str == "connection" || - key_str == "transfer-encoding" || - key_str == "keep-alive" || - key_str == "te" || - key_str == "trailers" || - key_str == "accept-encoding" || - key_str == "upgrade" || - (skip_content_headers && (key_str == "content-type" || key_str == "content-length")); + let should_skip = key_str.eq_ignore_ascii_case("authorization") // Already handled above + || key_str.eq_ignore_ascii_case("host") + || key_str.eq_ignore_ascii_case("connection") + || key_str.eq_ignore_ascii_case("transfer-encoding") + || key_str.eq_ignore_ascii_case("keep-alive") + || key_str.eq_ignore_ascii_case("te") + || key_str.eq_ignore_ascii_case("trailers") + || key_str.eq_ignore_ascii_case("accept-encoding") + || key_str.eq_ignore_ascii_case("upgrade") + || (skip_content_headers + && (key_str.eq_ignore_ascii_case("content-type") + || key_str.eq_ignore_ascii_case("content-length"))); if !should_skip { request_builder = request_builder.header(key.clone(), value.clone()); diff --git a/sgl-model-gateway/src/routers/http/router.rs b/sgl-model-gateway/src/routers/http/router.rs index f98905220128..b7b6ee8c7696 100644 --- a/sgl-model-gateway/src/routers/http/router.rs +++ b/sgl-model-gateway/src/routers/http/router.rs @@ -11,6 +11,7 @@ use axum::{ Json, }; use futures_util::StreamExt; +use memchr::memmem; use reqwest::Client; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error}; @@ -91,8 +92,10 @@ impl Router { Ok(worker_url) => { let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint)); for (name, value) in headers { - let name_lc = name.to_lowercase(); - if name_lc != "content-type" && name_lc != "content-length" { + // Use eq_ignore_ascii_case to avoid string allocation + if !name.eq_ignore_ascii_case("content-type") + && !name.eq_ignore_ascii_case("content-length") + { request_builder = request_builder.header(name, value); } } @@ -300,6 +303,18 @@ impl Router { return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); } + // Pre-filter headers once before the loop to avoid repeated lowercasing + let filtered_headers: Vec<_> = headers + .map(|hdrs| { + hdrs.iter() + .filter(|(name, _)| { + !name.as_str().eq_ignore_ascii_case("content-type") + && !name.as_str().eq_ignore_ascii_case("content-length") + }) + .collect() + }) + .unwrap_or_default(); + let mut last_response: Option = None; for worker in workers { let worker_url = worker.url(); @@ -323,13 +338,9 @@ impl Router { request_builder.header("Authorization", format!("Bearer {}", api_key)); } - if let Some(hdrs) = headers { - for (name, value) in hdrs { - let name_lc = name.as_str().to_lowercase(); - if name_lc != "content-type" && name_lc != "content-length" { - request_builder = request_builder.header(name, value); - } - } + // Apply pre-filtered headers + for (name, value) in &filtered_headers { + request_builder = request_builder.header(*name, *value); } match request_builder.send().await { @@ -417,11 +428,9 @@ impl Router { is_stream: bool, load_incremented: bool, // Whether load was incremented for this request ) -> Response { - // Get the worker's API key if available - let api_key = self - .worker_registry - .get_by_url(worker_url) - .and_then(|w| w.api_key().clone()); + // Get the worker once and reuse for API key and load tracking + let worker = self.worker_registry.get_by_url(worker_url); + let api_key = worker.as_ref().and_then(|w| w.api_key().clone()); let mut request_builder = if self.dp_aware { let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { @@ -452,10 +461,13 @@ impl Router { String::from("data_parallel_rank"), serde_json::json!(dp_rank), ); - debug!( - "Modified request body: {}", - serde_json::to_string(&json_val).unwrap_or(String::from("ERR")) - ); + // Only serialize if debug logging is enabled to avoid CPU overhead + if tracing::enabled!(tracing::Level::DEBUG) { + debug!( + "Modified request body: {}", + serde_json::to_string(&json_val).unwrap_or_else(|_| String::from("ERR")) + ); + } } else { return ( StatusCode::BAD_REQUEST, @@ -497,9 +509,9 @@ impl Router { // Decrement load on error if it was incremented if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } @@ -528,9 +540,9 @@ impl Router { Err(e) => { // IMPORTANT: Decrement load on error before returning if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } @@ -541,17 +553,18 @@ impl Router { // Decrement load counter for non-streaming requests if it was incremented if load_incremented { - if let Some(worker) = self.worker_registry.get_by_url(worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(worker_url, worker.load()); + if let Some(ref w) = worker { + w.decrement_load(); + RouterMetrics::set_running_requests(worker_url, w.load()); } } response } else if load_incremented { // For streaming with load tracking, we need to manually decrement when done - let registry = Arc::clone(&self.worker_registry); - let worker_url = worker_url.to_string(); + // Clone the worker Arc for the async block instead of looking it up again + let stream_worker = worker.clone(); + let worker_url_owned = worker_url.to_string(); // Preserve headers for streaming response let mut response_headers = header_utils::preserve_response_headers(res.headers()); @@ -568,15 +581,14 @@ impl Router { while let Some(chunk) = stream.next().await { match chunk { Ok(bytes) => { - // Check for stream end marker - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - if let Some(worker) = registry.get_by_url(&worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(&worker_url, worker.load()); + // Check for stream end marker using memmem for efficiency + if memmem::find(&bytes, b"data: [DONE]").is_some() { + if let Some(ref w) = stream_worker { + w.decrement_load(); + RouterMetrics::set_running_requests( + &worker_url_owned, + w.load(), + ); decremented = true; } } @@ -591,9 +603,9 @@ impl Router { } } if !decremented { - if let Some(worker) = registry.get_by_url(&worker_url) { - worker.decrement_load(); - RouterMetrics::set_running_requests(&worker_url, worker.load()); + if let Some(ref w) = stream_worker { + w.decrement_load(); + RouterMetrics::set_running_requests(&worker_url_owned, w.load()); } } }); diff --git a/sgl-model-gateway/src/routers/router_manager.rs b/sgl-model-gateway/src/routers/router_manager.rs index 9ac494ff9b81..ceafc5625d30 100644 --- a/sgl-model-gateway/src/routers/router_manager.rs +++ b/sgl-model-gateway/src/routers/router_manager.rs @@ -269,13 +269,13 @@ impl RouterManager { let mut best_router = None; let mut best_score = 0.0; - let num_regular_workers = self - .worker_registry - .get_all() + // Cache worker list to avoid duplicate get_all() calls + let all_workers = self.worker_registry.get_all(); + let num_regular_workers = all_workers .iter() .filter(|w| matches!(w.worker_type(), WorkerType::Regular)) .count(); - let num_pd_workers = self.worker_registry.get_all().len() - num_regular_workers; + let num_pd_workers = all_workers.len() - num_regular_workers; for router in candidate_routers { let mut score = 1.0;