diff --git a/sgl-model-gateway/src/core/circuit_breaker.rs b/sgl-model-gateway/src/core/circuit_breaker.rs index 3a91f32fc8f8..f629de17a82d 100644 --- a/sgl-model-gateway/src/core/circuit_breaker.rs +++ b/sgl-model-gateway/src/core/circuit_breaker.rs @@ -88,9 +88,7 @@ impl CircuitBreaker { /// Check if a request can be executed pub fn can_execute(&self) -> bool { - self.check_and_update_state(); - - let state = *self.state.read().unwrap(); + let state = self.state(); match state { CircuitState::Closed => true, CircuitState::Open => false, @@ -100,8 +98,21 @@ impl CircuitBreaker { /// Get the current state pub fn state(&self) -> CircuitState { - self.check_and_update_state(); - *self.state.read().unwrap() + self.check_and_update_state_returning() + } + + /// Check and update state, returning the current state to avoid double lock + fn check_and_update_state_returning(&self) -> CircuitState { + let current_state = *self.state.read().unwrap(); + + if current_state == CircuitState::Open { + let last_change = *self.last_state_change.read().unwrap(); + if last_change.elapsed() >= self.config.timeout_duration { + self.transition_to(CircuitState::HalfOpen); + return CircuitState::HalfOpen; + } + } + current_state } /// Record the outcome of a request @@ -160,18 +171,6 @@ impl CircuitBreaker { } } - /// Check and update state based on timeout - fn check_and_update_state(&self) { - let current_state = *self.state.read().unwrap(); - - if current_state == CircuitState::Open { - let last_change = *self.last_state_change.read().unwrap(); - if last_change.elapsed() >= self.config.timeout_duration { - self.transition_to(CircuitState::HalfOpen); - } - } - } - /// Transition to a new state fn transition_to(&self, new_state: CircuitState) { let mut state = self.state.write().unwrap(); diff --git a/sgl-model-gateway/src/core/worker.rs b/sgl-model-gateway/src/core/worker.rs index 77c260056fe9..006b124851f1 100644 --- a/sgl-model-gateway/src/core/worker.rs +++ b/sgl-model-gateway/src/core/worker.rs @@ -22,9 +22,18 @@ use crate::{ routers::grpc::client::GrpcClient, }; +/// Default worker priority (mid-range on 0-100 scale) +pub const DEFAULT_WORKER_PRIORITY: u32 = 50; + +/// Default worker cost factor (baseline cost) +pub const DEFAULT_WORKER_COST: f32 = 1.0; + +/// Default HTTP client timeout for worker requests (in seconds) +pub const DEFAULT_WORKER_HTTP_TIMEOUT_SECS: u64 = 30; + static WORKER_CLIENT: LazyLock = LazyLock::new(|| { reqwest::Client::builder() - .timeout(Duration::from_secs(30)) + .timeout(Duration::from_secs(DEFAULT_WORKER_HTTP_TIMEOUT_SECS)) .build() .expect("Failed to create worker HTTP client") }); @@ -37,10 +46,12 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Get the worker's API key fn api_key(&self) -> &Option; /// Get the worker's type (Regular, Prefill, or Decode) - fn worker_type(&self) -> WorkerType; + /// Returns a reference to avoid cloning on every access + fn worker_type(&self) -> &WorkerType; /// Get the worker's connection mode (HTTP or gRPC) - fn connection_mode(&self) -> ConnectionMode; + /// Returns a reference to avoid cloning on every access + fn connection_mode(&self) -> &ConnectionMode; /// Get the bootstrap hostname for PD mode /// Returns cached hostname parsed from URL at construction time @@ -64,6 +75,18 @@ pub trait Worker: Send + Sync + fmt::Debug { async fn check_health_async(&self) -> WorkerResult<()>; /// Synchronous health check wrapper (for compatibility) + /// + /// # Deprecation Notice + /// This method creates a new Tokio runtime for each call, which is expensive. + /// Prefer using `check_health_async()` within an async context instead. + /// + /// # Performance Warning + /// Creating a runtime per call has significant overhead. Only use this + /// method when you cannot use the async version. + #[deprecated( + since = "0.4.6", + note = "Use check_health_async() instead. This method creates a new Tokio runtime per call." + )] fn check_health(&self) -> WorkerResult<()> { tokio::runtime::Builder::new_current_thread() .enable_all() @@ -191,16 +214,16 @@ pub trait Worker: Send + Sync + fmt::Debug { .labels .get("priority") .and_then(|s| s.parse().ok()) - .unwrap_or(50) // Default priority is 50 (mid-range) + .unwrap_or(DEFAULT_WORKER_PRIORITY) } - /// Get the cost factor of this worker (1.0 = baseline) + /// Get the cost factor of this worker (baseline = 1.0) fn cost(&self) -> f32 { self.metadata() .labels .get("cost") .and_then(|s| s.parse().ok()) - .unwrap_or(1.0) + .unwrap_or(DEFAULT_WORKER_COST) } /// Get tokenizer path for a specific model. @@ -550,12 +573,12 @@ impl Worker for BasicWorker { &self.metadata.api_key } - fn worker_type(&self) -> WorkerType { - self.metadata.worker_type.clone() + fn worker_type(&self) -> &WorkerType { + &self.metadata.worker_type } - fn connection_mode(&self) -> ConnectionMode { - self.metadata.connection_mode.clone() + fn connection_mode(&self) -> &ConnectionMode { + &self.metadata.connection_mode } fn is_healthy(&self) -> bool { @@ -832,11 +855,11 @@ impl Worker for DPAwareWorker { self.base_worker.api_key() } - fn worker_type(&self) -> WorkerType { + fn worker_type(&self) -> &WorkerType { self.base_worker.worker_type() } - fn connection_mode(&self) -> ConnectionMode { + fn connection_mode(&self) -> &ConnectionMode { self.base_worker.connection_mode() } @@ -1085,24 +1108,24 @@ impl HealthChecker { /// Helper to convert Worker trait object to WorkerInfo struct pub fn worker_to_info(worker: &Arc) -> WorkerInfo { - // Cache values that are used multiple times to avoid redundant clones/allocations + // Cache references that are used multiple times to avoid redundant method calls let worker_type = worker.worker_type(); let connection_mode = worker.connection_mode(); let url = worker.url(); let model_id = worker.model_id(); - let worker_type_str = match &worker_type { + let worker_type_str = match worker_type { WorkerType::Regular => "regular", WorkerType::Prefill { .. } => "prefill", WorkerType::Decode => "decode", }; - let bootstrap_port = match &worker_type { + let bootstrap_port = match worker_type { WorkerType::Prefill { bootstrap_port } => *bootstrap_port, _ => None, }; - let runtime_type = match &connection_mode { + let runtime_type = match connection_mode { ConnectionMode::Grpc { .. } => Some(worker.metadata().runtime_type.to_string()), ConnectionMode::Http => None, }; @@ -1219,7 +1242,7 @@ mod tests { .worker_type(WorkerType::Regular) .build(); assert_eq!(worker.url(), "http://test:8080"); - assert_eq!(worker.worker_type(), WorkerType::Regular); + assert_eq!(worker.worker_type(), &WorkerType::Regular); assert!(worker.is_healthy()); assert_eq!(worker.load(), 0); assert_eq!(worker.processed_requests(), 0); @@ -1276,7 +1299,7 @@ mod tests { let regular = BasicWorkerBuilder::new("http://test:8080") .worker_type(WorkerType::Regular) .build(); - assert_eq!(regular.worker_type(), WorkerType::Regular); + assert_eq!(regular.worker_type(), &WorkerType::Regular); let prefill = BasicWorkerBuilder::new("http://test:8080") .worker_type(WorkerType::Prefill { @@ -1285,7 +1308,7 @@ mod tests { .build(); assert_eq!( prefill.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: Some(9090) } ); @@ -1293,7 +1316,7 @@ mod tests { let decode = BasicWorkerBuilder::new("http://test:8080") .worker_type(WorkerType::Decode) .build(); - assert_eq!(decode.worker_type(), WorkerType::Decode); + assert_eq!(decode.worker_type(), &WorkerType::Decode); } #[test] @@ -1444,7 +1467,7 @@ mod tests { .build(), ); assert_eq!(worker.url(), "http://regular:8080"); - assert_eq!(worker.worker_type(), WorkerType::Regular); + assert_eq!(worker.worker_type(), &WorkerType::Regular); } #[test] @@ -1459,7 +1482,7 @@ mod tests { assert_eq!(worker1.url(), "http://prefill:8080"); assert_eq!( worker1.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: Some(9090) } ); @@ -1473,7 +1496,7 @@ mod tests { ); assert_eq!( worker2.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: None } ); @@ -1487,7 +1510,7 @@ mod tests { .build(), ); assert_eq!(worker.url(), "http://decode:8080"); - assert_eq!(worker.worker_type(), WorkerType::Decode); + assert_eq!(worker.worker_type(), &WorkerType::Decode); } #[test] @@ -1573,7 +1596,7 @@ mod tests { assert_eq!(workers.len(), 2); assert_eq!(workers[0].url(), "http://w1:8080"); assert_eq!(workers[1].url(), "http://w2:8080"); - assert_eq!(workers[0].worker_type(), WorkerType::Regular); + assert_eq!(workers[0].worker_type(), &WorkerType::Regular); } #[test] @@ -1595,14 +1618,15 @@ mod tests { assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]); } - #[test] - fn test_check_health_sync_wrapper() { + #[tokio::test] + async fn test_check_health_async() { use crate::core::BasicWorkerBuilder; let worker = BasicWorkerBuilder::new("http://test:8080") .worker_type(WorkerType::Regular) .build(); - let result = worker.check_health(); + // Health check should fail since there's no actual server + let result = worker.check_health_async().await; assert!(result.is_err()); } @@ -1640,7 +1664,7 @@ mod tests { assert!(dp_worker.is_dp_aware()); assert_eq!(dp_worker.dp_rank(), Some(2)); assert_eq!(dp_worker.dp_size(), Some(4)); - assert_eq!(dp_worker.worker_type(), WorkerType::Regular); + assert_eq!(dp_worker.worker_type(), &WorkerType::Regular); } #[test] @@ -1655,7 +1679,7 @@ mod tests { assert!(dp_worker.is_dp_aware()); assert_eq!( dp_worker.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: Some(9090) } ); @@ -1669,7 +1693,7 @@ mod tests { assert_eq!(dp_worker.url(), "http://worker1:8080@0"); assert!(dp_worker.is_dp_aware()); - assert_eq!(dp_worker.worker_type(), WorkerType::Decode); + assert_eq!(dp_worker.worker_type(), &WorkerType::Decode); } #[tokio::test] @@ -1760,7 +1784,7 @@ mod tests { assert!(worker.is_dp_aware()); assert_eq!(worker.dp_rank(), Some(1)); assert_eq!(worker.dp_size(), Some(4)); - assert_eq!(worker.worker_type(), WorkerType::Regular); + assert_eq!(worker.worker_type(), &WorkerType::Regular); } #[tokio::test] @@ -1779,7 +1803,7 @@ mod tests { assert!(worker.is_dp_aware()); assert_eq!( worker.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: Some(8090) } ); @@ -1919,22 +1943,22 @@ mod tests { assert!(workers[4].is_dp_aware()); assert!(workers[5].is_dp_aware()); - assert_eq!(workers[0].worker_type(), WorkerType::Regular); + assert_eq!(workers[0].worker_type(), &WorkerType::Regular); assert_eq!( workers[1].worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: Some(9090) } ); - assert_eq!(workers[2].worker_type(), WorkerType::Decode); - assert_eq!(workers[3].worker_type(), WorkerType::Regular); + assert_eq!(workers[2].worker_type(), &WorkerType::Decode); + assert_eq!(workers[3].worker_type(), &WorkerType::Regular); assert_eq!( workers[4].worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: None } ); - assert_eq!(workers[5].worker_type(), WorkerType::Decode); + assert_eq!(workers[5].worker_type(), &WorkerType::Decode); } // === Phase 1.3: WorkerMetadata model methods tests === diff --git a/sgl-model-gateway/src/core/worker_builder.rs b/sgl-model-gateway/src/core/worker_builder.rs index a650e2f95b63..63ff5b0993cb 100644 --- a/sgl-model-gateway/src/core/worker_builder.rs +++ b/sgl-model-gateway/src/core/worker_builder.rs @@ -352,8 +352,8 @@ mod tests { let worker = BasicWorkerBuilder::new("http://localhost:8080").build(); assert_eq!(worker.url(), "http://localhost:8080"); - assert_eq!(worker.worker_type(), WorkerType::Regular); - assert_eq!(worker.connection_mode(), ConnectionMode::Http); + assert_eq!(worker.worker_type(), &WorkerType::Regular); + assert_eq!(worker.connection_mode(), &ConnectionMode::Http); assert!(worker.is_healthy()); } @@ -364,8 +364,8 @@ mod tests { .build(); assert_eq!(worker.url(), "http://localhost:8080"); - assert_eq!(worker.worker_type(), WorkerType::Decode); - assert_eq!(worker.connection_mode(), ConnectionMode::Http); + assert_eq!(worker.worker_type(), &WorkerType::Decode); + assert_eq!(worker.connection_mode(), &ConnectionMode::Http); assert!(worker.is_healthy()); } @@ -403,13 +403,13 @@ mod tests { assert_eq!(worker.url(), "http://localhost:8080"); assert_eq!( worker.worker_type(), - WorkerType::Prefill { + &WorkerType::Prefill { bootstrap_port: None } ); assert_eq!( worker.connection_mode(), - ConnectionMode::Grpc { port: Some(50051) } + &ConnectionMode::Grpc { port: Some(50051) } ); assert_eq!(worker.metadata().labels, labels); assert_eq!( @@ -459,7 +459,7 @@ mod tests { assert_eq!(worker.url(), "http://localhost:8080@2"); assert_eq!(worker.dp_rank(), Some(2)); assert_eq!(worker.dp_size(), Some(8)); - assert_eq!(worker.worker_type(), WorkerType::Regular); + assert_eq!(worker.worker_type(), &WorkerType::Regular); } #[test] @@ -522,10 +522,10 @@ mod tests { assert_eq!(worker.url(), "grpc://cluster.local@1"); assert_eq!(worker.dp_rank(), Some(1)); assert_eq!(worker.dp_size(), Some(4)); - assert_eq!(worker.worker_type(), WorkerType::Decode); + assert_eq!(worker.worker_type(), &WorkerType::Decode); assert_eq!( worker.connection_mode(), - ConnectionMode::Grpc { port: Some(50051) } + &ConnectionMode::Grpc { port: Some(50051) } ); assert_eq!( worker.metadata().labels.get("transport"), diff --git a/sgl-model-gateway/src/core/worker_registry.rs b/sgl-model-gateway/src/core/worker_registry.rs index 6a3d9a7fec52..be79597719f5 100644 --- a/sgl-model-gateway/src/core/worker_registry.rs +++ b/sgl-model-gateway/src/core/worker_registry.rs @@ -36,6 +36,7 @@ impl Default for WorkerId { } } +/// Model index type for O(1) lookups (stores Arc directly) type ModelIndex = Arc>>>>>; /// Worker registry with model-based indexing @@ -44,10 +45,8 @@ pub struct WorkerRegistry { /// All workers indexed by ID workers: Arc>>, - /// Workers indexed by model ID (stores WorkerId for reference) - model_workers: Arc>>, - - /// Optimized model index for O(1) lookups (stores Arc directly) + /// Model index for O(1) lookups (stores Arc directly) + /// This replaces the previous dual-index approach for better memory efficiency model_index: ModelIndex, /// Workers indexed by worker type @@ -55,6 +54,7 @@ pub struct WorkerRegistry { /// Workers indexed by connection mode connection_workers: Arc>>, + /// URL to worker ID mapping url_to_id: Arc>, } @@ -64,7 +64,6 @@ impl WorkerRegistry { pub fn new() -> Self { Self { workers: Arc::new(DashMap::new()), - model_workers: Arc::new(DashMap::new()), model_index: Arc::new(DashMap::new()), type_workers: Arc::new(DashMap::new()), connection_workers: Arc::new(DashMap::new()), @@ -88,14 +87,8 @@ impl WorkerRegistry { self.url_to_id .insert(worker.url().to_string(), worker_id.clone()); - // Update model index (both ID-based and optimized) + // Update model index for O(1) lookups let model_id = worker.model_id().to_string(); - self.model_workers - .entry(model_id.clone()) - .or_default() - .push(worker_id.clone()); - - // Update optimized model index for O(1) lookups self.model_index .entry(model_id) .or_insert_with(|| Arc::new(RwLock::new(Vec::new()))) @@ -103,15 +96,15 @@ impl WorkerRegistry { .expect("RwLock for model_index is poisoned") .push(worker.clone()); - // Update type index + // Update type index (clone needed for DashMap key ownership) self.type_workers - .entry(worker.worker_type()) + .entry(worker.worker_type().clone()) .or_default() .push(worker_id.clone()); - // Update connection mode index + // Update connection mode index (clone needed for DashMap key ownership) self.connection_workers - .entry(worker.connection_mode()) + .entry(worker.connection_mode().clone()) .or_default() .push(worker_id.clone()); @@ -124,12 +117,7 @@ impl WorkerRegistry { // Remove from URL mapping self.url_to_id.remove(worker.url()); - // Remove from model index (both ID-based and optimized) - if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) { - model_workers.retain(|id| id != worker_id); - } - - // Remove from optimized model index + // Remove from model index if let Some(model_index_entry) = self.model_index.get(worker.model_id()) { let worker_url = worker.url(); model_index_entry @@ -139,13 +127,13 @@ impl WorkerRegistry { } // Remove from type index - if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) { + if let Some(mut type_workers) = self.type_workers.get_mut(worker.worker_type()) { type_workers.retain(|id| id != worker_id); } // Remove from connection mode index if let Some(mut conn_workers) = - self.connection_workers.get_mut(&worker.connection_mode()) + self.connection_workers.get_mut(worker.connection_mode()) { conn_workers.retain(|id| id != worker_id); } @@ -178,17 +166,9 @@ impl WorkerRegistry { self.url_to_id.get(url).and_then(|id| self.get(&id)) } - /// Get all workers for a model + /// Get all workers for a model (O(1) optimized) + /// Uses the pre-indexed model_index for fast lookups pub fn get_by_model(&self, model_id: &str) -> Vec> { - self.model_workers - .get(model_id) - .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect()) - .unwrap_or_default() - } - - /// Get all workers for a model (O(1) optimized version) - /// This method uses the pre-indexed model_index for fast lookups - pub fn get_by_model_fast(&self, model_id: &str) -> Vec> { self.model_index .get(model_id) .map(|workers| { @@ -200,6 +180,12 @@ impl WorkerRegistry { .unwrap_or_default() } + /// Alias for get_by_model for backwards compatibility + #[inline] + pub fn get_by_model_fast(&self, model_id: &str) -> Vec> { + self.get_by_model(model_id) + } + /// Get all workers by worker type pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec> { self.type_workers @@ -273,9 +259,15 @@ impl WorkerRegistry { /// Get all model IDs with workers pub fn get_models(&self) -> Vec { - self.model_workers + self.model_index .iter() - .filter(|entry| !entry.value().is_empty()) + .filter(|entry| { + entry + .value() + .read() + .map(|workers| !workers.is_empty()) + .unwrap_or(false) + }) .map(|entry| entry.key().clone()) .collect() } @@ -310,7 +302,7 @@ impl WorkerRegistry { .filter(|w| { // Check worker_type if specified if let Some(ref wtype) = worker_type { - if w.worker_type() != *wtype { + if *w.worker_type() != *wtype { return false; } } @@ -344,9 +336,15 @@ impl WorkerRegistry { let total_workers = self.workers.len(); // Count models directly instead of allocating Vec via get_models() let total_models = self - .model_workers + .model_index .iter() - .filter(|entry| !entry.value().is_empty()) + .filter(|entry| { + entry + .value() + .read() + .map(|workers| !workers.is_empty()) + .unwrap_or(false) + }) .count(); let mut healthy_count = 0; @@ -416,10 +414,18 @@ impl WorkerRegistry { .map(|entry| entry.value().clone()) .collect(); - // Perform health checks - for worker in &workers { - let _ = worker.check_health_async().await; // Use async version directly - } + // Perform health checks in parallel for better performance + // This is especially important when there are many workers + let health_futures: Vec<_> = workers + .iter() + .map(|worker| { + let worker = worker.clone(); + async move { + let _ = worker.check_health_async().await; + } + }) + .collect(); + futures::future::join_all(health_futures).await; // Reset loads periodically check_count += 1; diff --git a/sgl-model-gateway/tests/test_pd_routing.rs b/sgl-model-gateway/tests/test_pd_routing.rs index f26049f70df4..f6a2e069a2a3 100644 --- a/sgl-model-gateway/tests/test_pd_routing.rs +++ b/sgl-model-gateway/tests/test_pd_routing.rs @@ -51,7 +51,7 @@ mod test_pd_routing { assert_eq!(prefill_worker.url(), "http://prefill:8080"); match prefill_worker.worker_type() { WorkerType::Prefill { bootstrap_port } => { - assert_eq!(bootstrap_port, Some(9000)); + assert_eq!(*bootstrap_port, Some(9000)); } _ => panic!("Expected Prefill worker type"), } @@ -353,7 +353,7 @@ mod test_pd_routing { let bootstrap_port = match prefill_worker.worker_type() { WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, + _ => &None, }; single_json["bootstrap_host"] = json!(prefill_worker.bootstrap_host()); @@ -697,7 +697,7 @@ mod test_pd_routing { let bootstrap_port = match prefill_worker.worker_type() { WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, + _ => &None, }; let batch_size = 16; let hostname = prefill_worker.bootstrap_host(); @@ -823,7 +823,7 @@ mod test_pd_routing { let bootstrap_port = match prefill_worker.worker_type() { WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, + _ => &None, }; let hostname = prefill_worker.bootstrap_host();