Summary
This RFC proposes extending vLLM's Expert Parallel Load Balancer (EPLB) with fault tolerance capabilities, enabling automatic detection and recovery from individual expert failures without service interruption. The system continues serving with degraded performance while automatically redistributing load and eventually restarting failed experts using elastic scaling mechanisms.
Motivation
Current EP implementations assume all experts remain healthy throughout deployment. In production environments serving models like DeepSeek-V3 with 256+ experts across multiple GPUs, individual expert failures are inevitable:
- Hardware degradation: Specific compute units on a GPU can fail while others continue working
- Memory errors: ECC errors affecting individual expert weight storage
- Numerical instabilities: Specific experts producing NaN/Inf due to data patterns
- Thermal throttling: Localized hotspots affecting expert execution
Without fault tolerance, expert failures cause:
- Request failures for tokens routed to dead experts
- Cascading load imbalances as traffic shifts unpredictably
- Service degradation requiring manual intervention
Note: Individual expert failures naturally encompass GPU-level failures. When all experts on a GPU fail, our rebalancing algorithm assigns prohibitive weights to all of them, effectively quarantining the entire GPU.
Proposed Change.
Proposed Solution
Core Approach: EPLB-Integrated Fault Tolerance
Extend EPLB with health monitoring as a first-class feature:
```python
# Extension to existing EPLBConfig
@dataclass
class EPLBConfig:
    # ... existing EPLB fields ...

    # Fault tolerance (Phase 0)
    health_check_enabled: bool = True
    health_latency_threshold: float = 3.0  # 3x median = unhealthy
    health_penalty_factor: float = 10.0  # weight multiplier for avoidance
```
- Integrated with EPLB: Not a separate system - fault tolerance IS load balancing
- Service continuity: System serves with degraded performance, never stops
- Progressive recovery: Detection → Avoidance → Rebalancing → Restart
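The Detection → Avoidance → Rebalancing → Restart progression can be sketched as a small state machine. This is an illustrative model only; the state names and the `advance` helper are hypothetical, not part of the proposed vLLM API:

```python
from enum import Enum, auto

class ExpertState(Enum):
    HEALTHY = auto()
    SUSPECT = auto()      # detection: windowed latency above threshold
    QUARANTINED = auto()  # avoidance: penalized during rebalancing
    RESTARTING = auto()   # Phase 2: elastic scale-down / scale-up

# Allowed transitions in the progressive-recovery pipeline (hypothetical).
TRANSITIONS = {
    ExpertState.HEALTHY: {ExpertState.SUSPECT},
    ExpertState.SUSPECT: {ExpertState.HEALTHY, ExpertState.QUARANTINED},
    ExpertState.QUARANTINED: {ExpertState.RESTARTING, ExpertState.HEALTHY},
    ExpertState.RESTARTING: {ExpertState.HEALTHY},
}

def advance(state, target):
    """Move an expert along the pipeline, rejecting illegal jumps."""
    if target not in TRANSITIONS[state]:
        raise ValueError(f"illegal transition {state} -> {target}")
    return target
```

Note that a SUSPECT expert can return directly to HEALTHY if its latency recovers, so transient slowdowns never reach the restart path.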
Phase 0: Health Monitoring & Detection (This RFC - Immediate)
Goal: Detect failures within seconds (~100-1000 forward passes, configurable via EPLBConfig.window_size)
- Track per-expert latency with low overhead
- Detect failures when latency exceeds 3x the median (configurable through health_latency_threshold)
- Automatically reduce traffic to unhealthy experts via EPLB rebalancing
```python
def step(self, model: MixtureOfExperts, ...):
    # Existing EPLB step logic
    ...

    # NEW: Health check
    if self.eplb_config.health_check_enabled:
        avg_latency = self.expert_latency_window.mean(dim=0)
        baseline = avg_latency.median(dim=-1, keepdim=True).values
        self.expert_health_mask = (
            avg_latency < self.eplb_config.health_latency_threshold * baseline
        )

    # Pass health mask to existing rebalance
    if self.should_rebalance():
        self.rearrange(model, health_mask=self.expert_health_mask)
```
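To make the median-based check concrete, here is the same comparison on plain Python lists instead of tensors (a sketch, with latency values invented for illustration; the default threshold of 3.0 matches `health_latency_threshold`):

```python
# Hypothetical per-expert latencies (ms) averaged over the EPLB window.
avg_latency = [1.0, 1.1, 0.9, 1.05, 9.0, 1.0]  # expert 4 is degraded

def health_mask(latencies, threshold=3.0):
    """Mark an expert unhealthy when its windowed latency exceeds
    threshold x the median across experts (mirrors the step() logic)."""
    baseline = sorted(latencies)[len(latencies) // 2]  # (upper) median
    return [lat < threshold * baseline for lat in latencies]

mask = health_mask(avg_latency)
# Expert 4 (9.0 ms vs a ~1.05 ms median) falls outside 3x and is masked.
```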
Phase 1: Automatic Load Redistribution (This RFC - Immediate)
Goal: Maintain service with degraded but stable performance
- Unhealthy experts get a 10x weight penalty during rebalancing (configurable through health_penalty_factor)
- EPLB naturally redistributes load to healthy experts
- GPU-level failures are handled when all of a GPU's experts are marked unhealthy
No new code needed - just parameter passing:
```python
def rebalance_experts(weight, ..., health_mask=None, health_penalty_factor=10.0):
    if health_mask is not None:
        # Unhealthy experts become unattractive for routing
        weight = weight * (1 + (health_penalty_factor - 1) * (1 - health_mask))
    # Existing rebalance logic handles the rest
```
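The penalty expression maps a mask value of 1 (healthy) to an unchanged weight and 0 (unhealthy) to a `health_penalty_factor` multiple. A quick numeric check of that formula, using scalar lists in place of tensors:

```python
def penalize(weights, health_mask, health_penalty_factor=10.0):
    """weight * (1 + (factor - 1) * (1 - mask)):
    healthy (mask=1) -> x1, unhealthy (mask=0) -> x factor."""
    return [w * (1 + (health_penalty_factor - 1) * (1 - m))
            for w, m in zip(weights, health_mask)]

# Expert 2 is unhealthy (mask = 0); its load weight is inflated 10x so
# the existing EPLB rebalance steers traffic away from it.
penalized = penalize([4.0, 2.0, 3.0], [1, 1, 0])
```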
Phase 2: Expert Restart via Elastic Scaling (#27908 - Leverages #20323)
Goal: Restore full capacity by restarting failed experts
When degradation persists beyond threshold:
- Trigger elastic scale-down to remove failed experts
- Immediately scale-up to spawn replacements
- EPLB rebalances to include new healthy experts
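The three Phase 2 steps could be orchestrated roughly as below. This is a sketch under stated assumptions: `scale_down`, `scale_up`, and `rebalance` are injected placeholder callables standing in for the elastic-scaling and EPLB entry points (#20323 / #27908), not vLLM's actual interface:

```python
def restart_failed_experts(unhealthy_ranks, scale_down, scale_up, rebalance,
                           degradation_steps, restart_threshold=5):
    """Trigger an elastic restart once degradation persists past a threshold.

    Returns True when a restart cycle was triggered, False otherwise."""
    if not unhealthy_ranks or degradation_steps < restart_threshold:
        return False
    scale_down(unhealthy_ranks)     # 1. remove failed experts from the EP group
    scale_up(len(unhealthy_ranks))  # 2. spawn replacements
    rebalance()                     # 3. EPLB folds the new healthy ranks back in
    return True
```

Gating on `degradation_steps` keeps transient slowdowns from triggering an expensive scale-down/scale-up cycle.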
Monitoring Integration (optional)
```python
# Existing stats extended with:
metrics["unhealthy_expert_count"] = (1 - health_mask).sum()
metrics["health_degradation_ratio"] = unhealthy_count / total_count
```
Advantages
- Simple integration into EPLB
- Never stops serving, just degrades gracefully
- Automatic recovery
- Low overhead
Disadvantages
- Detection delay: takes EPLBConfig.window_size forward passes to detect a failure
- Capacity loss: degraded performance until Phase 2 restart
Conclusion
This proposal adds critical fault tolerance to EP with minimal complexity by treating fault tolerance as an extension of load balancing. The phased approach allows immediate value (detection/avoidance) while building toward full recovery capabilities using existing elastic scaling infrastructure.
Q & A
Q: How is per-expert latency measured?
A: In Phase 0, we measure per-layer latency using CUDA events and assign it to all active experts in that layer. Since MoE experts run in fused parallel kernels (pplx, DeepEP, etc.), we cannot directly measure individual expert execution time without kernel-level instrumentation. Instead:
- Measure: Total MoE layer forward pass time (via CUDA events)
- Identify: Which experts were active (from router's topk selection)
- Record: Assign layer latency to all active experts; 0 for inactive experts
- Detect: Compare each expert's current latency to its historical mean (ignoring inactive passes)
Why this works for fault detection:
If a specific expert fails or degrades, the layer containing it slows down, so the failed expert consistently appears in slow layer executions. Over the sliding window, the failing expert accumulates a higher average latency.
Healthy experts show lower historical means (they appear in fast executions too).
This approach is hardware-agnostic and requires no kernel modifications, making it suitable for Phase 0 detection. Future phases can add finer-grained per-expert timing if needed.
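The Measure/Identify/Record/Detect scheme above can be sketched in plain Python, with CUDA-event timings replaced by given per-pass layer latencies (the data and the `attribute_latency` helper are illustrative, not vLLM code):

```python
from collections import defaultdict

def attribute_latency(passes):
    """passes: list of (layer_latency_ms, active_expert_ids).
    Assign each pass's layer latency to every active expert, then
    average per expert over the passes in which it was active
    (inactive passes are ignored, matching the Detect step)."""
    totals, counts = defaultdict(float), defaultdict(int)
    for latency, active in passes:
        for eid in active:
            totals[eid] += latency
            counts[eid] += 1
    return {eid: totals[eid] / counts[eid] for eid in totals}

# Expert 7 is failing: every pass that includes it is slow.
history = [
    (2.0, {1, 2}),
    (2.1, {2, 3}),
    (9.0, {1, 7}),  # slow pass contains expert 7
    (8.5, {3, 7}),  # slow pass contains expert 7
]
means = attribute_latency(history)
```

Healthy experts that happen to co-occur with the failing one (experts 1 and 3 here) also see inflated means, but the failing expert still ends up with the highest historical average because it appears in every slow pass and no fast ones.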
EP Kernel Fault Tolerance Requirements
Critical Dependency: For EPLB masking to work reliably, the underlying EP communication kernel must have embedded fault tolerance that prevents all_reduce operations from hanging when a rank becomes unresponsive.
Required Kernel Behavior
When a rank times out or becomes unresponsive during all_reduce operations in rearrange():
- The kernel must NOT hang indefinitely - It should detect the timeout
- The kernel should automatically exclude the failed rank - Remove it from collective communication
- The kernel should complete with remaining healthy ranks - Continue operation with reduced rank set
- vLLM should be able to call all_reduce on the problematic rank - Without blocking the entire system
This is essential because EPLB's masking happens at the orchestration layer (marking ranks as unhealthy), but the actual communication during rearrange still involves all ranks. If the EP kernel hangs waiting for an unresponsive rank, the entire masking mechanism fails.
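The required non-hanging behavior can be illustrated host-side with a timeout-guarded collective that excludes unresponsive ranks instead of blocking. This is a pure-Python simulation of the contract (threads stand in for ranks), not real NCCL or EP kernel code:

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutTimeout
import time

def guarded_all_reduce(rank_fns, timeout_s=0.1):
    """Run each rank's contribution against a deadline; ranks that miss
    it are excluded and the reduction completes over the survivors."""
    results, failed = [], []
    with ThreadPoolExecutor(max_workers=len(rank_fns)) as pool:
        futures = {rank: pool.submit(fn) for rank, fn in rank_fns.items()}
        for rank, fut in futures.items():
            try:
                results.append(fut.result(timeout=timeout_s))
            except FutTimeout:
                failed.append(rank)  # exclude this rank from the collective
    return sum(results), failed
```

A real kernel would additionally propagate the reduced rank set to subsequent collectives; here the point is only that the caller gets a result plus a list of failed ranks instead of hanging.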
Kernel Support Matrix
As of 11/12/2025
| Kernel | Official Repo Verified | In-Kernel Fault Tolerance | Behavior on Rank Failure |
|---|---|---|---|
| Mooncake EP | ✅ github.com/kvcache-ai/Mooncake | ✅ Yes - activeRanks masking | Continues with healthy ranks |
| DeepEP | ✅ github.com/deepseek-ai/DeepEP | ✅ Yes - mask_buffer_ptr masking | Continues with healthy ranks (if mask buffer provided) |
| pplx-kernels | ✅ github.com/perplexityai/pplx-kernels | ❌ No - Will hang indefinitely | Kernel hangs forever |
Implementation Challenges and Caveats
Challenge 1: Two-Layer Timeout Detection
EPLB masking requires coordinated timeout detection at two layers:
- EP Kernel Layer (GPU-side): Detects unresponsive ranks during collective operations
- Orchestration Layer (vLLM-side): Marks ranks as unhealthy and triggers rearrangement
These two layers must work together. The kernel provides the foundation (non-blocking collectives), while vLLM provides the policy (when to mask, how to rebalance).
Challenge 2: Mask Buffer Coordination (DeepEP)
For DeepEP, the mask_buffer_ptr must be:
- Allocated in GPU-accessible memory
- Shared across all collective operations (dispatch, combine, barrier)
- Synchronized between vLLM's masking state and kernel's runtime state
- Properly initialized and cleaned when ranks are added/removed
Implementation Note: We need to extend DeepEP buffer initialization to expose and manage the mask buffer from Python-side vLLM code.
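A hypothetical Python-side manager for that buffer might look like the following. The class and its methods are illustrative of the four lifecycle requirements above (allocate, share, synchronize, reinitialize), not DeepEP's actual API; a real implementation would back the list with GPU-accessible memory and hand its pointer (`mask_buffer_ptr`) to dispatch/combine/barrier:

```python
class MaskBufferManager:
    """Host-side view of a per-rank mask buffer (1 = healthy, 0 = masked)."""

    def __init__(self, num_ranks):
        # Allocated once, then shared by all collective operations.
        self.mask = [1] * num_ranks

    def mark_unhealthy(self, rank):
        # Synchronize vLLM's masking state into the kernel-visible buffer.
        self.mask[rank] = 0

    def resize(self, num_ranks):
        # Reinitialize cleanly when ranks are added or removed.
        self.mask = [1] * num_ranks

    def healthy_ranks(self):
        return [r for r, m in enumerate(self.mask) if m]
```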
Challenge 3: Asymmetric Failure Detection
Different ranks may detect the same failure at different times:
- Rank A may timeout waiting for Rank B after 100s
- Rank C may timeout waiting for Rank B after 95s
- This creates a brief period where the system has inconsistent views of which ranks are healthy
Mitigation: Use consensus-based masking, where multiple ranks must agree that a rank has failed before triggering rearrangement, or accept temporary inconsistency during the detection window.
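One way the consensus mitigation could work, sketched with hypothetical names (a rank is globally masked only once a quorum of observers reports it as timed out):

```python
def consensus_mask(reports, num_ranks, quorum):
    """reports: {observer_rank: set of ranks it saw time out}.
    A rank is masked only when at least `quorum` observers agree,
    tolerating the brief window where views are inconsistent."""
    votes = [0] * num_ranks
    for suspects in reports.values():
        for r in suspects:
            votes[r] += 1
    return [votes[r] >= quorum for r in range(num_ranks)]

# Ranks 0 and 2 both timed out waiting for rank 1; rank 3 has not
# noticed yet. With quorum=2, rank 1 is still masked despite the
# asymmetric detection times.
masked = consensus_mask({0: {1}, 2: {1}, 3: set()}, num_ranks=4, quorum=2)
```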
Challenge 4: Rearrange During Partial Failure
When calling rearrange() with some ranks masked:
- The failed rank may still be running (just slow/stuck)
- Calling all_reduce on a masked rank should either:
- Skip that rank entirely (requires kernel support), OR
- Use a separate communication path that can timeout safely
Current Approach: Rely on kernel's dynamic masking to handle this automatically.
Feedback Period.
No response
CC List.
@pavanimajety @GuanLuo @benchislett @xinli-sw
Any Other Things.
No response