Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions sgl-model-gateway/src/core/circuit_breaker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ impl CircuitBreaker {

/// Check if a request can be executed
pub fn can_execute(&self) -> bool {
self.check_and_update_state();

let state = *self.state.read().unwrap();
let state = self.state();
match state {
CircuitState::Closed => true,
CircuitState::Open => false,
Expand All @@ -100,8 +98,21 @@ impl CircuitBreaker {

/// Get the current state of the circuit breaker.
///
/// Delegates to `check_and_update_state_returning`, which applies the
/// timeout-driven `Open` -> `HalfOpen` transition and returns the state it
/// observed, so the state lock is not read a second time here.
pub fn state(&self) -> CircuitState {
    self.check_and_update_state_returning()
}

/// Refresh the breaker state and report it in a single pass.
///
/// When the breaker is `Open` and at least `config.timeout_duration` has
/// elapsed since the last state change, it is moved to `HalfOpen` and
/// `HalfOpen` is returned. In every other case the state that was read is
/// returned unchanged.
fn check_and_update_state_returning(&self) -> CircuitState {
    let observed = *self.state.read().unwrap();
    if observed != CircuitState::Open {
        return observed;
    }

    let changed_at = *self.last_state_change.read().unwrap();
    if changed_at.elapsed() < self.config.timeout_duration {
        // Still inside the timeout window: report the state as read.
        return observed;
    }

    self.transition_to(CircuitState::HalfOpen);
    CircuitState::HalfOpen
}

/// Record the outcome of a request
Expand Down Expand Up @@ -160,18 +171,6 @@ impl CircuitBreaker {
}
}

/// Move an `Open` breaker to `HalfOpen` once the configured timeout has elapsed.
///
/// Does nothing when the breaker is in any other state, or when less than
/// `config.timeout_duration` has passed since the last state change.
fn check_and_update_state(&self) {
    let observed = *self.state.read().unwrap();
    if observed != CircuitState::Open {
        return;
    }

    let changed_at = *self.last_state_change.read().unwrap();
    let timed_out = changed_at.elapsed() >= self.config.timeout_duration;
    if timed_out {
        self.transition_to(CircuitState::HalfOpen);
    }
}

/// Transition to a new state
fn transition_to(&self, new_state: CircuitState) {
let mut state = self.state.write().unwrap();
Expand Down
102 changes: 63 additions & 39 deletions sgl-model-gateway/src/core/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@ use crate::{
routers::grpc::client::GrpcClient,
};

/// Default worker priority (mid-range on 0-100 scale).
///
/// Used as the fallback when a worker's `priority` label is missing or
/// does not parse.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;

/// Default worker cost factor (baseline cost).
///
/// Used as the fallback when a worker's `cost` label is missing or does
/// not parse.
pub const DEFAULT_WORKER_COST: f32 = 1.0;

/// Default HTTP client timeout for worker requests (in seconds).
///
/// Passed to the builder of the shared `WORKER_CLIENT` as its request timeout.
pub const DEFAULT_WORKER_HTTP_TIMEOUT_SECS: u64 = 30;

/// Shared HTTP client for worker requests, built lazily on first access.
///
/// Its default request timeout is `DEFAULT_WORKER_HTTP_TIMEOUT_SECS` seconds.
/// Construction panics if the underlying client cannot be built.
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    reqwest::Client::builder()
        // Single timeout setting, taken from the named constant instead of a
        // bare literal.
        .timeout(Duration::from_secs(DEFAULT_WORKER_HTTP_TIMEOUT_SECS))
        .build()
        .expect("Failed to create worker HTTP client")
});
Expand All @@ -37,10 +46,12 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's API key
fn api_key(&self) -> &Option<String>;
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
/// Returns a reference to avoid cloning on every access
fn worker_type(&self) -> &WorkerType;

/// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode;
/// Returns a reference to avoid cloning on every access
fn connection_mode(&self) -> &ConnectionMode;

/// Get the bootstrap hostname for PD mode
/// Returns cached hostname parsed from URL at construction time
Expand All @@ -64,6 +75,18 @@ pub trait Worker: Send + Sync + fmt::Debug {
async fn check_health_async(&self) -> WorkerResult<()>;

/// Synchronous health check wrapper (for compatibility)
///
/// # Deprecation Notice
/// This method creates a new Tokio runtime for each call, which is expensive.
/// Prefer using `check_health_async()` within an async context instead.
///
/// # Performance Warning
/// Creating a runtime per call has significant overhead. Only use this
/// method when you cannot use the async version.
#[deprecated(
since = "0.4.6",
note = "Use check_health_async() instead. This method creates a new Tokio runtime per call."
)]
fn check_health(&self) -> WorkerResult<()> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand Down Expand Up @@ -191,16 +214,16 @@ pub trait Worker: Send + Sync + fmt::Debug {
.labels
.get("priority")
.and_then(|s| s.parse().ok())
.unwrap_or(50) // Default priority is 50 (mid-range)
.unwrap_or(DEFAULT_WORKER_PRIORITY)
}

/// Get the cost factor of this worker (baseline = 1.0).
///
/// Reads the `cost` label from the worker metadata and parses it as `f32`.
/// Falls back to `DEFAULT_WORKER_COST` when the label is absent or does not
/// parse.
fn cost(&self) -> f32 {
    self.metadata()
        .labels
        .get("cost")
        .and_then(|s| s.parse().ok())
        .unwrap_or(DEFAULT_WORKER_COST)
}

/// Get tokenizer path for a specific model.
Expand Down Expand Up @@ -550,12 +573,12 @@ impl Worker for BasicWorker {
&self.metadata.api_key
}

/// Worker type taken from this worker's metadata.
// NOTE(review): two signatures appear back to back here (owned `WorkerType`
// via `clone()` vs. borrowed `&WorkerType`) — confirm which one is current.
fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone()
fn worker_type(&self) -> &WorkerType {
&self.metadata.worker_type
}

/// Connection mode taken from this worker's metadata.
// NOTE(review): two signatures appear back to back here (owned `ConnectionMode`
// via `clone()` vs. borrowed `&ConnectionMode`) — confirm which one is current.
fn connection_mode(&self) -> ConnectionMode {
self.metadata.connection_mode.clone()
fn connection_mode(&self) -> &ConnectionMode {
&self.metadata.connection_mode
}

fn is_healthy(&self) -> bool {
Expand Down Expand Up @@ -832,11 +855,11 @@ impl Worker for DPAwareWorker {
self.base_worker.api_key()
}

/// Worker type, forwarded from the wrapped `base_worker`.
// NOTE(review): two signatures appear back to back here (owned vs. borrowed
// return) — confirm which one is current.
fn worker_type(&self) -> WorkerType {
fn worker_type(&self) -> &WorkerType {
self.base_worker.worker_type()
}

/// Connection mode, forwarded from the wrapped `base_worker`.
// NOTE(review): two signatures appear back to back here (owned vs. borrowed
// return) — confirm which one is current.
fn connection_mode(&self) -> ConnectionMode {
fn connection_mode(&self) -> &ConnectionMode {
self.base_worker.connection_mode()
}

Expand Down Expand Up @@ -1085,24 +1108,24 @@ impl HealthChecker {

/// Helper to convert Worker trait object to WorkerInfo struct
pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
// Cache values that are used multiple times to avoid redundant clones/allocations
// Cache references that are used multiple times to avoid redundant method calls
let worker_type = worker.worker_type();
let connection_mode = worker.connection_mode();
let url = worker.url();
let model_id = worker.model_id();

let worker_type_str = match &worker_type {
let worker_type_str = match worker_type {
WorkerType::Regular => "regular",
WorkerType::Prefill { .. } => "prefill",
WorkerType::Decode => "decode",
};

let bootstrap_port = match &worker_type {
let bootstrap_port = match worker_type {
WorkerType::Prefill { bootstrap_port } => *bootstrap_port,
_ => None,
};

let runtime_type = match &connection_mode {
let runtime_type = match connection_mode {
ConnectionMode::Grpc { .. } => Some(worker.metadata().runtime_type.to_string()),
ConnectionMode::Http => None,
};
Expand Down Expand Up @@ -1219,7 +1242,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.build();
assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
assert_eq!(worker.worker_type(), &WorkerType::Regular);
assert!(worker.is_healthy());
assert_eq!(worker.load(), 0);
assert_eq!(worker.processed_requests(), 0);
Expand Down Expand Up @@ -1276,7 +1299,7 @@ mod tests {
let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();
assert_eq!(regular.worker_type(), WorkerType::Regular);
assert_eq!(regular.worker_type(), &WorkerType::Regular);

let prefill = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Prefill {
Expand All @@ -1285,15 +1308,15 @@ mod tests {
.build();
assert_eq!(
prefill.worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: Some(9090)
}
);

let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode)
.build();
assert_eq!(decode.worker_type(), WorkerType::Decode);
assert_eq!(decode.worker_type(), &WorkerType::Decode);
}

#[test]
Expand Down Expand Up @@ -1444,7 +1467,7 @@ mod tests {
.build(),
);
assert_eq!(worker.url(), "http://regular:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
assert_eq!(worker.worker_type(), &WorkerType::Regular);
}

#[test]
Expand All @@ -1459,7 +1482,7 @@ mod tests {
assert_eq!(worker1.url(), "http://prefill:8080");
assert_eq!(
worker1.worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: Some(9090)
}
);
Expand All @@ -1473,7 +1496,7 @@ mod tests {
);
assert_eq!(
worker2.worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: None
}
);
Expand All @@ -1487,7 +1510,7 @@ mod tests {
.build(),
);
assert_eq!(worker.url(), "http://decode:8080");
assert_eq!(worker.worker_type(), WorkerType::Decode);
assert_eq!(worker.worker_type(), &WorkerType::Decode);
}

#[test]
Expand Down Expand Up @@ -1573,7 +1596,7 @@ mod tests {
assert_eq!(workers.len(), 2);
assert_eq!(workers[0].url(), "http://w1:8080");
assert_eq!(workers[1].url(), "http://w2:8080");
assert_eq!(workers[0].worker_type(), WorkerType::Regular);
assert_eq!(workers[0].worker_type(), &WorkerType::Regular);
}

#[test]
Expand All @@ -1595,14 +1618,15 @@ mod tests {
assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]);
}

// Builds a regular worker for "http://test:8080" and expects its health check
// to return an error.
// NOTE(review): two test headers appear back to back here (a sync `#[test]`
// calling `check_health()` and a `#[tokio::test]` calling
// `check_health_async()`) — confirm which one is current.
#[test]
fn test_check_health_sync_wrapper() {
#[tokio::test]
async fn test_check_health_async() {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.build();

let result = worker.check_health();
// Health check should fail since there's no actual server
let result = worker.check_health_async().await;
assert!(result.is_err());
}

Expand Down Expand Up @@ -1640,7 +1664,7 @@ mod tests {
assert!(dp_worker.is_dp_aware());
assert_eq!(dp_worker.dp_rank(), Some(2));
assert_eq!(dp_worker.dp_size(), Some(4));
assert_eq!(dp_worker.worker_type(), WorkerType::Regular);
assert_eq!(dp_worker.worker_type(), &WorkerType::Regular);
}

#[test]
Expand All @@ -1655,7 +1679,7 @@ mod tests {
assert!(dp_worker.is_dp_aware());
assert_eq!(
dp_worker.worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: Some(9090)
}
);
Expand All @@ -1669,7 +1693,7 @@ mod tests {

assert_eq!(dp_worker.url(), "http://worker1:8080@0");
assert!(dp_worker.is_dp_aware());
assert_eq!(dp_worker.worker_type(), WorkerType::Decode);
assert_eq!(dp_worker.worker_type(), &WorkerType::Decode);
}

#[tokio::test]
Expand Down Expand Up @@ -1760,7 +1784,7 @@ mod tests {
assert!(worker.is_dp_aware());
assert_eq!(worker.dp_rank(), Some(1));
assert_eq!(worker.dp_size(), Some(4));
assert_eq!(worker.worker_type(), WorkerType::Regular);
assert_eq!(worker.worker_type(), &WorkerType::Regular);
}

#[tokio::test]
Expand All @@ -1779,7 +1803,7 @@ mod tests {
assert!(worker.is_dp_aware());
assert_eq!(
worker.worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: Some(8090)
}
);
Expand Down Expand Up @@ -1919,22 +1943,22 @@ mod tests {
assert!(workers[4].is_dp_aware());
assert!(workers[5].is_dp_aware());

assert_eq!(workers[0].worker_type(), WorkerType::Regular);
assert_eq!(workers[0].worker_type(), &WorkerType::Regular);
assert_eq!(
workers[1].worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: Some(9090)
}
);
assert_eq!(workers[2].worker_type(), WorkerType::Decode);
assert_eq!(workers[3].worker_type(), WorkerType::Regular);
assert_eq!(workers[2].worker_type(), &WorkerType::Decode);
assert_eq!(workers[3].worker_type(), &WorkerType::Regular);
assert_eq!(
workers[4].worker_type(),
WorkerType::Prefill {
&WorkerType::Prefill {
bootstrap_port: None
}
);
assert_eq!(workers[5].worker_type(), WorkerType::Decode);
assert_eq!(workers[5].worker_type(), &WorkerType::Decode);
}

// === Phase 1.3: WorkerMetadata model methods tests ===
Expand Down
Loading
Loading