Skip to content
Merged
1 change: 1 addition & 0 deletions sgl-model-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ parking_lot = "0.12.4"
rayon = "1.10"
thiserror = "2.0.12"
regex = "1.10"
memchr = "2.7" # SIMD-optimized byte pattern searching
url = "2.5.4"
validator = { version = "0.20.0", features = ["derive"] }
tokio-stream = { version = "0.1", features = ["sync"] }
Expand Down
61 changes: 33 additions & 28 deletions sgl-model-gateway/src/core/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,15 @@ impl std::str::FromStr for RuntimeType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"sglang" => Ok(RuntimeType::Sglang),
"vllm" => Ok(RuntimeType::Vllm),
"external" => Ok(RuntimeType::External),
_ => Err(format!("Unknown runtime type: {}", s)),
// Use eq_ignore_ascii_case to avoid to_lowercase() allocation
if s.eq_ignore_ascii_case("sglang") {
Ok(RuntimeType::Sglang)
} else if s.eq_ignore_ascii_case("vllm") {
Ok(RuntimeType::Vllm)
} else if s.eq_ignore_ascii_case("external") {
Ok(RuntimeType::External)
} else {
Err(format!("Unknown runtime type: {}", s))
}
}
}
Expand Down Expand Up @@ -516,22 +520,18 @@ impl fmt::Debug for BasicWorker {

impl BasicWorker {
pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") {
// Use rfind to split from the right, handling IPv6 addresses with brackets
// e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0"
if let Some(at_pos) = self.url().rfind('@') {
let base_url = &self.url()[..at_pos];
let rank_str = &self.url()[at_pos + 1..];

// Validate that the rank part is actually a number
match rank_str.parse::<usize>() {
Ok(_) => Ok(base_url),
Err(_) => {
// The '@' is not a DP rank separator, return full URL
Ok(self.url())
}
}
// Use rfind directly - no need for redundant contains() check
// rfind already returns None if '@' is not found
// e.g., "http://[::1]:8080@0" -> "http://[::1]:8080" and "0"
if let Some(at_pos) = self.url().rfind('@') {
let base_url = &self.url()[..at_pos];
let rank_str = &self.url()[at_pos + 1..];

// Validate that the rank part is actually a number
if rank_str.parse::<usize>().is_ok() {
Ok(base_url)
} else {
// The '@' is not a DP rank separator, return full URL
Ok(self.url())
}
} else {
Expand Down Expand Up @@ -1085,33 +1085,38 @@ impl HealthChecker {

/// Helper to convert Worker trait object to WorkerInfo struct
pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
let worker_type_str = match worker.worker_type() {
// Cache values that are used multiple times to avoid redundant clones/allocations
let worker_type = worker.worker_type();
let connection_mode = worker.connection_mode();
let url = worker.url();
let model_id = worker.model_id();

let worker_type_str = match &worker_type {
WorkerType::Regular => "regular",
WorkerType::Prefill { .. } => "prefill",
WorkerType::Decode => "decode",
};

let bootstrap_port = match worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
let bootstrap_port = match &worker_type {
WorkerType::Prefill { bootstrap_port } => *bootstrap_port,
_ => None,
};

let runtime_type = match worker.connection_mode() {
let runtime_type = match &connection_mode {
ConnectionMode::Grpc { .. } => Some(worker.metadata().runtime_type.to_string()),
ConnectionMode::Http => None,
};

let model_id = worker.model_id();
WorkerInfo {
id: worker.url().to_string(),
url: worker.url().to_string(),
id: url.to_string(),
url: url.to_string(),
model_id: model_id.to_string(),
priority: worker.priority(),
cost: worker.cost(),
worker_type: worker_type_str.to_string(),
is_healthy: worker.is_healthy(),
load: worker.load(),
connection_mode: format!("{:?}", worker.connection_mode()),
connection_mode: connection_mode.to_string(),
runtime_type,
tokenizer_path: worker.tokenizer_path(model_id).map(String::from),
reasoning_parser: worker.reasoning_parser(model_id).map(String::from),
Expand Down
11 changes: 9 additions & 2 deletions sgl-model-gateway/src/core/worker_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,22 @@ impl WorkerRegistry {
/// Get worker statistics
pub fn stats(&self) -> WorkerRegistryStats {
let total_workers = self.workers.len();
let total_models = self.get_models().len();
// Count models directly instead of allocating Vec via get_models()
let total_models = self
.model_workers
.iter()
.filter(|entry| !entry.value().is_empty())
.count();

let mut healthy_count = 0;
let mut total_load = 0;
let mut regular_count = 0;
let mut prefill_count = 0;
let mut decode_count = 0;

for worker in self.get_all() {
// Iterate DashMap directly to avoid cloning all workers via get_all()
for entry in self.workers.iter() {
let worker = entry.value();
if worker.is_healthy() {
healthy_count += 1;
}
Expand Down
23 changes: 16 additions & 7 deletions sgl-model-gateway/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ pub async fn auth_middleware(
Ok(next.run(request).await)
}

/// Alphabet used for the random part of a generated request ID: ASCII `A-Z`, `a-z`, `0-9`.
///
/// Stored as a byte slice so `generate_request_id` can index a random position directly
/// instead of walking a `chars()` iterator. Every byte is ASCII, so casting an element
/// to `char` yields the same character.
const REQUEST_ID_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

/// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String {
let prefix = if path.contains("/chat/completions") {
Expand All @@ -94,12 +97,12 @@ fn generate_request_id(path: &str) -> String {
};

// Generate a random string similar to OpenAI's format
let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
// Use byte array indexing (O(1)) instead of chars().nth() (O(n))
let mut rng = rand::rng();
let random_part: String = (0..24)
.map(|_| {
let idx = rng.random_range(0..chars.len());
chars.chars().nth(idx).unwrap()
let idx = rng.random_range(0..REQUEST_ID_CHARS.len());
REQUEST_ID_CHARS[idx] as char
})
.collect();

Expand Down Expand Up @@ -263,7 +266,8 @@ impl<B> OnResponse<B> for ResponseLogger {

// Record these in the span for structured logging/observability tools
span.record("status_code", status.as_u16());
span.record("latency", format!("{:?}", latency));
// Use microseconds as integer to avoid format! string allocation
span.record("latency", latency.as_micros() as u64);

// Log the response completion
let _enter = span.enter();
Expand Down Expand Up @@ -629,13 +633,18 @@ pub async fn wasm_middleware(
// Process each OnRequest module
let mut modified_body = body_bytes;

// Pre-compute strings once before the loop to avoid repeated allocations
let method_str = method.to_string();
let path_str = uri.path().to_string();
let query_str = uri.query().unwrap_or("").to_string();

for module in modules_on_request {
// Build WebAssembly request from collected data
let wasm_headers = build_wasm_headers_from_axum_headers(&headers);
let wasm_request = WasmRequest {
method: method.to_string(),
path: uri.path().to_string(),
query: uri.query().unwrap_or("").to_string(),
method: method_str.clone(),
path: path_str.clone(),
query: query_str.clone(),
headers: wasm_headers,
body: modified_body.clone(),
request_id: request_id.clone(),
Expand Down
8 changes: 4 additions & 4 deletions sgl-model-gateway/src/policies/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use dashmap::DashMap;
use rand::Rng;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use super::{get_healthy_worker_indices, BucketConfig, LoadBalancingPolicy};
Expand Down Expand Up @@ -259,14 +259,14 @@ impl LoadBalancingPolicy for BucketPolicy {
let rel_threshold = self.config.balance_rel_threshold * min_load as f32;
let is_imbalanced =
abs_diff > self.config.balance_abs_threshold && max_load as f32 > rel_threshold;
info!(
debug!(
"Current PD instance status | is_imbalanced={}",
is_imbalanced
);

let mut rng = rand::rng();
let prefill_url = if is_imbalanced {
info!("select prefill instance by Load Balance policy");
debug!("select prefill instance by Load Balance policy");
let min_url = chars_per_url_snapshot
.iter()
.min_by_key(|(_, &chars)| chars)
Expand All @@ -279,7 +279,7 @@ impl LoadBalancingPolicy for BucketPolicy {
});
min_url
} else {
info!("select prefill instance by Bucket policy");
debug!("select prefill instance by Bucket policy");
match choiced_url {
Some(url) if !url.is_empty() => url,
_ => {
Expand Down
10 changes: 6 additions & 4 deletions sgl-model-gateway/src/policies/cache_aware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
first_model
};

// Get current load statistics
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
let max_load = *loads.iter().max().unwrap_or(&0);
let min_load = *loads.iter().min().unwrap_or(&0);
// Get current load statistics - compute min/max in single pass without allocation
let (min_load, max_load) = workers.iter().fold((usize::MAX, 0usize), |(min, max), w| {
let load = w.load();
(min.min(load), max.max(load))
});
let min_load = if min_load == usize::MAX { 0 } else { min_load };

// Check if load is imbalanced
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
Expand Down
15 changes: 6 additions & 9 deletions sgl-model-gateway/src/policies/power_of_two.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use rand::Rng;
use tracing::info;
use tracing::debug;

use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, observability::metrics::RouterMetrics};
Expand Down Expand Up @@ -57,15 +57,12 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy {
return Some(healthy_indices[0]);
}

// Select two random workers
// Select two random workers - use offset to guarantee different selection in O(1)
let mut rng = rand::rng();
let idx1 = rng.random_range(0..healthy_indices.len());
let mut idx2 = rng.random_range(0..healthy_indices.len());

// Ensure we pick two different workers
while idx2 == idx1 {
idx2 = rng.random_range(0..healthy_indices.len());
}
// Pick idx2 from the remaining indices: idx1 plus an offset in 1..len (mod len) can never equal idx1
let idx2 =
(idx1 + 1 + rng.random_range(0..healthy_indices.len() - 1)) % healthy_indices.len();

let worker_idx1 = healthy_indices[idx1];
let worker_idx2 = healthy_indices[idx2];
Expand All @@ -81,7 +78,7 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy {
worker_idx2
};

info!(
debug!(
"Power-of-two selection: {}={} vs {}={} -> selected {}",
workers[worker_idx1].url(),
load1,
Expand Down
13 changes: 7 additions & 6 deletions sgl-model-gateway/src/policies/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,20 @@ impl PolicyRegistry {
power_of_two_policies.push(Arc::clone(&self.default_policy));
}

if let Some(ref policy) = *self.prefill_policy.read().unwrap() {
// Cache prefill and decode policies to avoid double-locking prefill_policy
let prefill_policy_opt = self.prefill_policy.read().unwrap().clone();
let decode_policy_opt = self.decode_policy.read().unwrap().clone();

if let Some(ref policy) = prefill_policy_opt {
if policy.name() == "power_of_two" && !Arc::ptr_eq(policy, &self.default_policy) {
power_of_two_policies.push(Arc::clone(policy));
}
}

if let Some(ref policy) = *self.decode_policy.read().unwrap() {
if let Some(ref policy) = decode_policy_opt {
if policy.name() == "power_of_two"
&& !Arc::ptr_eq(policy, &self.default_policy)
&& !self
.prefill_policy
.read()
.unwrap()
&& !prefill_policy_opt
.as_ref()
.is_some_and(|p| Arc::ptr_eq(p, policy))
{
Expand Down
Loading
Loading