diff --git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index 00e879449cef..86af5511aa55 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -11,6 +11,7 @@ pub mod error; pub mod retry; pub mod token_bucket; pub mod worker; +pub mod worker_builder; pub mod worker_registry; // Re-export commonly used types at the module level @@ -23,4 +24,5 @@ pub use worker::{ start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, }; +pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder}; pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats}; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index b6fb25e75ff7..e68ef6886daf 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -586,6 +586,22 @@ impl DPAwareWorker { base_url, } } + + /// Create a new DP-aware worker with a pre-configured base worker + /// This is primarily used by the builder pattern + pub fn with_base_worker( + base_worker: BasicWorker, + base_url: String, + dp_rank: usize, + dp_size: usize, + ) -> Self { + Self { + base_worker, + dp_rank, + dp_size, + base_url, + } + } } #[async_trait] @@ -1102,7 +1118,7 @@ pub fn start_health_checker( // Periodically reset load counters to prevent drift // Only do this when we believe all workers should be idle - if check_count % LOAD_RESET_INTERVAL == 0 { + if check_count.is_multiple_of(LOAD_RESET_INTERVAL) { let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0); // Only reset if load appears to be very low (likely drift) if max_load <= 2 { diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs new file mode 100644 index 000000000000..741326a7cfcc --- /dev/null +++ b/sgl-router/src/core/worker_builder.rs @@ -0,0 +1,431 @@ +use super::circuit_breaker::CircuitBreakerConfig; +use super::worker::{BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerType}; +use crate::grpc::client::SglangSchedulerClient; +use std::collections::HashMap; + +/// Builder for creating BasicWorker instances with fluent API +pub struct BasicWorkerBuilder { + // Required fields + url: String, + + // Optional fields with defaults + worker_type: WorkerType, + connection_mode: ConnectionMode, + labels: HashMap, + health_config: HealthConfig, + circuit_breaker_config: CircuitBreakerConfig, + grpc_client: Option, +} + +impl BasicWorkerBuilder { + /// Create a new builder with only the URL (defaults to Regular worker type) + pub fn new(url: impl Into) -> Self { + Self { + url: url.into(), + worker_type: WorkerType::Regular, + connection_mode: ConnectionMode::Http, + labels: HashMap::new(), + health_config: HealthConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), + grpc_client: None, + } + } + + /// Create a new builder with URL and worker type (for backwards compatibility) + pub fn new_with_type(url: impl Into, worker_type: WorkerType) -> Self { + Self { + url: url.into(), + worker_type, + connection_mode: ConnectionMode::Http, + labels: HashMap::new(), + health_config: HealthConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), + grpc_client: None, + } + } + + /// Set the worker type (Regular, Prefill, or Decode) + pub fn worker_type(mut self, worker_type: WorkerType) -> Self { + self.worker_type = worker_type; + self + } + + /// Set the connection mode (HTTP or gRPC) + pub fn connection_mode(mut self, mode: ConnectionMode) -> Self { + self.connection_mode = mode; + self + } + + /// Set labels for worker identification + pub fn labels(mut self, labels: HashMap) -> Self { + self.labels = labels; + self + } + + /// Add a single label + pub fn label(mut self, key: impl Into, value: impl Into) -> Self { + self.labels.insert(key.into(), value.into()); + self + } + + /// Set health check configuration + pub fn health_config(mut self, config: HealthConfig) -> Self { + self.health_config = config; + self + } + + /// Set circuit breaker configuration + pub fn circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self { + self.circuit_breaker_config = config; + self + } + + /// Set gRPC client for gRPC workers + pub fn grpc_client(mut self, client: SglangSchedulerClient) -> Self { + self.grpc_client = Some(client); + self + } + + /// Build the BasicWorker instance + pub fn build(self) -> BasicWorker { + // Use the existing constructor methods for now + let mut worker = + BasicWorker::with_connection_mode(self.url, self.worker_type, self.connection_mode); + + // Apply optional configurations using existing methods + if !self.labels.is_empty() { + worker = worker.with_labels(self.labels); + } + + worker = worker.with_health_config(self.health_config); + worker = worker.with_circuit_breaker_config(self.circuit_breaker_config); + + if let Some(client) = self.grpc_client { + worker = worker.with_grpc_client(client); + } + + worker + } +} + +/// Builder for creating DPAwareWorker instances with fluent API +pub struct DPAwareWorkerBuilder { + // Required fields + base_url: String, + dp_rank: usize, + dp_size: usize, + + // Optional fields with defaults + worker_type: WorkerType, + connection_mode: ConnectionMode, + labels: HashMap, + health_config: HealthConfig, + circuit_breaker_config: CircuitBreakerConfig, + grpc_client: Option, +} + +impl DPAwareWorkerBuilder { + /// Create a new DP-aware worker builder (defaults to Regular worker type) + pub fn new(base_url: impl Into, dp_rank: usize, dp_size: usize) -> Self { + Self { + base_url: base_url.into(), + dp_rank, + dp_size, + worker_type: WorkerType::Regular, + connection_mode: ConnectionMode::Http, + labels: HashMap::new(), + health_config: HealthConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), + grpc_client: None, + } + } + + /// Create a new DP-aware worker builder with worker type (for backwards compatibility) + pub fn new_with_type( + base_url: impl Into, + dp_rank: usize, + dp_size: usize, + worker_type: WorkerType, + ) -> Self { + Self { + base_url: base_url.into(), + dp_rank, + dp_size, + worker_type, + connection_mode: ConnectionMode::Http, + labels: HashMap::new(), + health_config: HealthConfig::default(), + circuit_breaker_config: CircuitBreakerConfig::default(), + grpc_client: None, + } + } + + /// Set the worker type (Regular, Prefill, or Decode) + pub fn worker_type(mut self, worker_type: WorkerType) -> Self { + self.worker_type = worker_type; + self + } + + /// Set the connection mode (HTTP or gRPC) + pub fn connection_mode(mut self, mode: ConnectionMode) -> Self { + self.connection_mode = mode; + self + } + + /// Set labels for worker identification + pub fn labels(mut self, labels: HashMap) -> Self { + self.labels = labels; + self + } + + /// Add a single label + pub fn label(mut self, key: impl Into, value: impl Into) -> Self { + self.labels.insert(key.into(), value.into()); + self + } + + /// Set health check configuration + pub fn health_config(mut self, config: HealthConfig) -> Self { + self.health_config = config; + self + } + + /// Set circuit breaker configuration + pub fn circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self { + self.circuit_breaker_config = config; + self + } + + /// Set gRPC client for gRPC workers + pub fn grpc_client(mut self, client: SglangSchedulerClient) -> Self { + self.grpc_client = Some(client); + self + } + + /// Build the DPAwareWorker instance + pub fn build(self) -> DPAwareWorker { + // Create URL with DP rank suffix for identification + let worker_url = format!("{}@{}", self.base_url, self.dp_rank); + + // Use BasicWorkerBuilder to create a properly configured base worker + let mut builder = BasicWorkerBuilder::new(worker_url) + .worker_type(self.worker_type) + .connection_mode(self.connection_mode) + .labels(self.labels) + .health_config(self.health_config) + .circuit_breaker_config(self.circuit_breaker_config); + + // Add gRPC client if provided + if let Some(client) = self.grpc_client { + builder = builder.grpc_client(client); + } + + let base_worker = builder.build(); + + // Create the DPAwareWorker with the configured base worker + DPAwareWorker::with_base_worker(base_worker, self.base_url, self.dp_rank, self.dp_size) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::worker::Worker; + use std::time::Duration; + + #[test] + fn test_basic_worker_builder_minimal() { + // Using new API - defaults to Regular type + let worker = BasicWorkerBuilder::new("http://localhost:8080").build(); + + assert_eq!(worker.url(), "http://localhost:8080"); + assert_eq!(worker.worker_type(), WorkerType::Regular); + assert_eq!(worker.connection_mode(), ConnectionMode::Http); + assert!(worker.is_healthy()); + } + + #[test] + fn test_basic_worker_builder_with_type() { + // Test setting worker type explicitly + let worker = BasicWorkerBuilder::new("http://localhost:8080") + .worker_type(WorkerType::Decode) + .build(); + + assert_eq!(worker.url(), "http://localhost:8080"); + assert_eq!(worker.worker_type(), WorkerType::Decode); + assert_eq!(worker.connection_mode(), ConnectionMode::Http); + assert!(worker.is_healthy()); + } + + #[test] + fn test_basic_worker_builder_full() { + let mut labels = HashMap::new(); + labels.insert("env".to_string(), "prod".to_string()); + labels.insert("region".to_string(), "us-east".to_string()); + + let health_config = HealthConfig { + endpoint: "/health".to_string(), + timeout_secs: 30, + check_interval_secs: 60, + failure_threshold: 3, + success_threshold: 2, + }; + + let cb_config = CircuitBreakerConfig { + failure_threshold: 10, + success_threshold: 5, + timeout_duration: Duration::from_millis(2000), + window_duration: Duration::from_millis(30000), + }; + + let worker = BasicWorkerBuilder::new("http://localhost:8080") + .worker_type(WorkerType::Prefill { + bootstrap_port: None, + }) + .connection_mode(ConnectionMode::Grpc { port: Some(50051) }) + .labels(labels.clone()) + .health_config(health_config.clone()) + .circuit_breaker_config(cb_config) + .build(); + + assert_eq!(worker.url(), "http://localhost:8080"); + assert_eq!( + worker.worker_type(), + WorkerType::Prefill { + bootstrap_port: None + } + ); + assert_eq!( + worker.connection_mode(), + ConnectionMode::Grpc { port: Some(50051) } + ); + assert_eq!(worker.metadata().labels, labels); + // Can't directly compare HealthConfig without PartialEq, so check individual fields + assert_eq!( + worker.metadata().health_config.endpoint, + health_config.endpoint + ); + assert_eq!( + worker.metadata().health_config.timeout_secs, + health_config.timeout_secs + ); + assert_eq!( + worker.metadata().health_config.check_interval_secs, + health_config.check_interval_secs + ); + assert_eq!( + worker.metadata().health_config.failure_threshold, + health_config.failure_threshold + ); + assert_eq!( + worker.metadata().health_config.success_threshold, + health_config.success_threshold + ); + } + + #[test] + fn test_basic_worker_builder_with_single_label() { + let worker = BasicWorkerBuilder::new("http://localhost:8080") + .worker_type(WorkerType::Decode) + .label("env", "staging") + .label("version", "v1.2.3") + .build(); + + assert_eq!( + worker.metadata().labels.get("env"), + Some(&"staging".to_string()) + ); + assert_eq!( + worker.metadata().labels.get("version"), + Some(&"v1.2.3".to_string()) + ); + } + + #[test] + fn test_dp_aware_worker_builder_minimal() { + // Using new API - defaults to Regular type + let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 2, 8).build(); + + assert_eq!(worker.url(), "http://localhost:8080@2"); + assert_eq!(worker.dp_rank(), Some(2)); + assert_eq!(worker.dp_size(), Some(8)); + // Note: base_url is a private field, we can only test through the url() method + assert_eq!(worker.worker_type(), WorkerType::Regular); + } + + #[test] + fn test_dp_aware_worker_builder_full() { + let mut labels = HashMap::new(); + labels.insert("cluster".to_string(), "main".to_string()); + + let health_config = HealthConfig { + endpoint: "/status".to_string(), + timeout_secs: 20, + check_interval_secs: 45, + failure_threshold: 5, + success_threshold: 3, + }; + + let worker = DPAwareWorkerBuilder::new("http://localhost:8080", 3, 16) + .worker_type(WorkerType::Prefill { + bootstrap_port: Some(9090), + }) + .connection_mode(ConnectionMode::Http) + .labels(labels.clone()) + .health_config(health_config.clone()) + .build(); + + assert_eq!(worker.url(), "http://localhost:8080@3"); + assert_eq!(worker.dp_rank(), Some(3)); + assert_eq!(worker.dp_size(), Some(16)); + assert_eq!(worker.metadata().labels, labels); + // Can't directly compare HealthConfig without PartialEq, so check individual fields + assert_eq!( + worker.metadata().health_config.endpoint, + health_config.endpoint + ); + assert_eq!( + worker.metadata().health_config.timeout_secs, + health_config.timeout_secs + ); + assert_eq!( + worker.metadata().health_config.check_interval_secs, + health_config.check_interval_secs + ); + assert_eq!( + worker.metadata().health_config.failure_threshold, + health_config.failure_threshold + ); + assert_eq!( + worker.metadata().health_config.success_threshold, + health_config.success_threshold + ); + } + + #[test] + fn test_dp_aware_worker_with_grpc() { + // Test that DPAwareWorkerBuilder can set a gRPC client + let worker = DPAwareWorkerBuilder::new("grpc://cluster.local", 1, 4) + .worker_type(WorkerType::Decode) + .connection_mode(ConnectionMode::Grpc { port: Some(50051) }) + .label("transport", "grpc") + .build(); + + assert_eq!(worker.url(), "grpc://cluster.local@1"); + assert_eq!(worker.dp_rank(), Some(1)); + assert_eq!(worker.dp_size(), Some(4)); + assert_eq!(worker.worker_type(), WorkerType::Decode); + assert_eq!( + worker.connection_mode(), + ConnectionMode::Grpc { port: Some(50051) } + ); + assert_eq!( + worker.metadata().labels.get("transport"), + Some(&"grpc".to_string()) + ); + + // Note: We can't directly test the grpc_client as it's private, + // but the fact that the worker builds successfully with grpc connection mode + // validates that the configuration is properly passed through + } +} diff --git a/sgl-router/src/core/worker_registry.rs b/sgl-router/src/core/worker_registry.rs index 65a74cc0208e..845ebc223578 100644 --- a/sgl-router/src/core/worker_registry.rs +++ b/sgl-router/src/core/worker_registry.rs @@ -390,7 +390,7 @@ impl WorkerRegistry { // Reset loads periodically check_count += 1; - if check_count % LOAD_RESET_INTERVAL == 0 { + if check_count.is_multiple_of(LOAD_RESET_INTERVAL) { tracing::debug!("Resetting worker loads (cycle {})", check_count); for worker in &workers { worker.reset_load();