feat: [2/2][DeepEP] Add waterfill load balancing for shared expert dispatch #19290

Merged: ch-wan merged 118 commits into sgl-project:main from xutizhou:feat/deepep-waterfill-eplb-balance on May 14, 2026
@xutizhou xutizhou commented Feb 25, 2026

Motivation

In DeepSeek V3/R1 with expert parallelism (EP), each rank processes a subset of routed experts plus the full shared expert. Since every token must visit the shared expert, it creates a fixed compute load on all ranks regardless of routed expert distribution. When routed expert load is imbalanced across ranks (common with skewed token distributions), the shared expert amplifies the bottleneck on already-overloaded ranks.
Waterfill addresses this by treating the shared expert as an additional routed expert (the "9th expert" per rank) and dispatching it to the least-loaded rank via DeepEP, effectively filling the load gap like water filling a container. This converts the shared expert from a fixed per-rank cost into a dynamic load-balancing lever.

Co-author: @AichenF

Thanks to @ch-wan for the discussion and suggestions.

This is the second half of the waterfill feature split (pending on #20089):

Modifications

New file: python/sglang/srt/layers/moe/deepep_waterfill.py (514 lines)

  • Triton kernel _count_routed_per_rank_kernel: counts routed tokens per EP rank from topk_ids.
  • Triton kernel _waterfill_kernel: assigns each token's shared expert to the least-loaded rank using a waterfill algorithm, with local-rank preference (1.1x bias).
  • DeepEPWaterfillBalancer class: orchestrates the waterfill dispatch with two modes:
    • Static mode: Uses precomputed rank_load from EPLB expert distribution data (no runtime all_reduce).
    • Dynamic mode: Falls back to runtime all_reduce of per-rank routed counts when EPLB data is unavailable.
  • expand_topk(): Expands topk output by one column (shared expert slot) with computed expert IDs and scaling factors.
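The counting and assignment steps listed above can be sketched in plain Python. This is a minimal illustration, not the Triton kernels from the PR: the function names, the candidate set (ranks the token already routes to plus the local rank), and the way the 1.1x local preference is applied are all assumptions made for the example.

```python
from typing import List, Tuple

LOCAL_BIAS = 1.1  # local-rank preference factor from the PR; how it is applied below is an assumption


def count_routed_per_rank(
    topk_ids: List[List[int]], experts_per_rank: int, ep_size: int
) -> List[int]:
    """Count how many (token, expert) routes land on each EP rank."""
    counts = [0] * ep_size
    for row in topk_ids:
        for expert_id in row:
            if expert_id >= 0:  # negative IDs are treated as "no selection"
                counts[expert_id // experts_per_rank] += 1
    return counts


def waterfill_assign(
    topk_ids: List[List[int]],
    rank_load: List[int],
    experts_per_rank: int,
    local_rank: int,
) -> Tuple[List[int], List[int]]:
    """Pick a shared-expert rank per token while tracking a running load."""
    load = list(rank_load)
    dest = []
    for row in topk_ids:
        # Candidates: ranks this token already visits, plus the source rank.
        candidates = {e // experts_per_rank for e in row if e >= 0} | {local_rank}
        best = min(candidates, key=lambda r: (load[r], r))
        # Assumed bias rule: stay local unless local load exceeds 1.1x the best load.
        if load[local_rank] <= LOCAL_BIAS * load[best]:
            best = local_rank
        load[best] += 1  # "pour" this token's shared-expert work into the chosen rank
        dest.append(best)
    return dest, load
```

With `ep_size=4`, `experts_per_rank=2`, and `topk_ids=[[0, 2], [1, 3], [0, 1], [2, 6]]`, the routed counts are `[4, 3, 0, 1]`; feeding those counts back in as `rank_load` with `local_rank=0` sends the first token's shared expert to rank 1 and the last token's to rank 3.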

Modified files:

  • models/deepseek_v2.py (+112/-21): Creates DeepEPWaterfillBalancer in MoE init; calls expand_topk() after TopK selection; skips separate shared expert forward when waterfill is enabled (shared expert is now part of MoE dispatch/compute/combine).
  • layers/moe/fused_moe_triton/layer.py (+40/-2): Adjusts weight loading to map checkpoint expert IDs correctly when waterfill expands the expert count by ep_size (one shared slot per rank).
  • eplb/expert_location.py (+33/-1): Adds rank_load field to ExpertLocationMetadata and _compute_rank_load() to derive per-rank load from logical expert counts + physical-to-logical mapping.
  • server_args.py (+29/-1): Adds --enable-deepep-waterfill flag with validation.
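As a rough illustration of how a static `rank_load` could be derived from logical expert counts and a physical-to-logical mapping, consider the sketch below. It is not the PR's `_compute_rank_load()`: the function name, the contiguous per-rank layout of physical experts, and the even split of traffic across replicas are assumptions.

```python
from typing import List


def compute_rank_load(
    logical_count: List[float], physical_to_logical: List[int], ep_size: int
) -> List[float]:
    """Derive per-rank load from per-logical-expert token counts.

    Assumptions (not taken from the PR): physical experts are laid out
    contiguously per rank, and a logical expert with several replicas
    splits its traffic evenly across them.
    """
    num_physical = len(physical_to_logical)
    assert num_physical % ep_size == 0, "physical experts must divide evenly"
    per_rank = num_physical // ep_size

    # Number of physical replicas backing each logical expert.
    replicas = [0] * len(logical_count)
    for logical_id in physical_to_logical:
        replicas[logical_id] += 1

    rank_load = [0.0] * ep_size
    for physical_id, logical_id in enumerate(physical_to_logical):
        share = logical_count[logical_id] / replicas[logical_id]
        rank_load[physical_id // per_rank] += share
    return rank_load
```

For example, `compute_rank_load([100.0, 20.0, 60.0], [0, 1, 0, 2], ep_size=2)` splits logical expert 0 across both ranks and yields `[70.0, 110.0]`. Because the result depends only on precomputed statistics, it can be built once at load time, which is what lets static mode skip the runtime all_reduce.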

Accuracy Tests

MMLU accuracy on DeepSeek-V3 (FP8, 2-node 16×H20, EP16, DP16):

| Configuration | MMLU Accuracy |
|---|---|
| Baseline (no waterfill) | 91.55% |
| Waterfill + static EPLB | 91.80% |

No accuracy degradation observed.

Benchmarking and Profiling

Throughput benchmark on DeepSeek-V3 (FP8, 2-node 16×H20, EP16, DP16; MMLU prompts with input_len ≈ 600, output_len = 1; 256 concurrent requests over multiple rounds, 1000 prompts per round):

| Configuration | Output Throughput (tok/s) | Gain |
|---|---|---|
| Baseline + static EPLB (no waterfill) | ~29,700 | |
| Waterfill + static EPLB | ~30,900 | +4% |

Waterfill

python3 -m sglang.launch_server \
    --model-path /lustre/raplab/client/xutingz/workspace/model/DeepSeek-V3 \
    --tp 16 --dp-size 16 --nnodes 2 --node-rank {0|1} \
    --dist-init-addr 10.6.131.5:{PORT} \
    --host 0.0.0.0 --port 30000 --trust-remote-code \
    --moe-a2a-backend deepep --deepep-mode normal \
    --enable-dp-attention --mem-fraction-static 0.75 --max-running-requests 2048 \
    --watchdog-timeout 1800 --disable-radix-cache --disable-cuda-graph \
    --chunked-prefill-size -1 \
    --enable-deepep-waterfill \
    --init-expert-location /lustre/raplab/client/xutingz/workspace/bench/waterfill/mmlu_expert_dist/ep16_mmlu_logical_count.pt

Baseline

python3 -m sglang.launch_server \
    --model-path /lustre/raplab/client/xutingz/workspace/model/DeepSeek-V3 \
    --tp 16 --dp-size 16 --nnodes 2 --node-rank {0|1} \
    --dist-init-addr 10.6.131.5:{PORT} \
    --host 0.0.0.0 --port 30000 --trust-remote-code \
    --moe-a2a-backend deepep --deepep-mode normal \
    --enable-dp-attention --mem-fraction-static 0.75 --max-running-requests 2048 \
    --watchdog-timeout 1800 --disable-radix-cache --disable-cuda-graph \
    --chunked-prefill-size -1 \
    --init-expert-location /lustre/raplab/client/xutingz/workspace/bench/waterfill/mmlu_expert_dist/ep16_mmlu_logical_count.pt

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

This commit implements waterfill load balancing for shared expert using DeepEP
dispatch mechanism. The key idea is to treat shared expert as a virtual 9th
expert and dispatch it through DeepEP along with routed experts.

Design principles:
1. Each token's shared expert can be sent to:
   - One of the ranks it already routes to (no extra communication)
   - Or stay at source rank for local computation
2. Waterfill algorithm selects the lowest-loaded rank from candidates
3. Shared expert weight = 1.0 / routed_scaling_factor (for correct combine)

New files:
- python/sglang/srt/layers/moe/deepep_waterfill.py: Waterfill algorithm and helpers

Modified files:
- python/sglang/srt/server_args.py: Add --enable-deepep-waterfill flag
- python/sglang/srt/models/deepseek_v2.py: Add forward_deepep_waterfill method

Usage:
  python -m sglang.launch_server --model-path <model> --tp 8 --ep 8 \
    --moe-a2a-backend deepep --enable-deepep-waterfill

This script runs benchmarks to compare:
- Experiment 1: DeepEP baseline (no waterfill)
- Experiment 2: DeepEP + Waterfill
- Experiment 3: DeepEP + Waterfill with debug logging

Usage:
  bash test/run_deepep_waterfill_benchmark.sh

Bugs fixed:
1. Used wrong rank function (get_tensor_model_parallel_rank -> get_moe_expert_parallel_rank)
2. expand_topk_for_shared_expert didn't use shared_destination parameter
3. Simplified implementation: all shared experts computed locally
4. Added alt_stream optimization for parallel shared expert computation
5. Added debug logging for load distribution analysis

This is a simplified implementation where shared experts are computed locally
on the source rank in parallel with DeepEP dispatch/combine. True cross-rank
waterfill (dispatching shared expert to already-routed ranks) requires
DeepEP protocol modifications and is left as future work.

Current flow:
1. Router + topk computation
2. Shared expert on alt_stream (parallel)
3. DeepEP dispatch for routed experts
4. MoE computation
5. DeepEP combine
6. Add shared expert result
…expert

Key Design:
1. Shared expert treated as virtual 9th routed expert
2. Virtual expert ID = target_rank * experts_per_rank (routes to correct rank)
3. Waterfill only assigns to ranks token already routes to (no extra comm)
4. Receiver identifies shared tokens via virtual ID and computes separately
5. Shared weight = 1/routed_scaling_factor for correct final scaling

Flow:
1. Router + topk(8)
2. AllReduce global routed counts
3. Waterfill assigns shared destination
4. Expand topk to 9 columns
5. DeepEP dispatch with topk=9
6. Receiver: MoE(8 cols) + shared expert + merge
7. DeepEP combine with topk=9
8. Apply routed_scaling_factor
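Steps 3 and 4 of the flow above (assign a shared destination, then expand topk to 9 columns) can be sketched as follows. This is a simplified list-based illustration rather than the PR's tensor code; the function name is hypothetical, and identifying the shared slot by its column position is an assumption.

```python
from typing import List, Tuple


def expand_topk_with_shared(
    topk_ids: List[List[int]],
    topk_weights: List[List[float]],
    shared_dest: List[int],
    experts_per_rank: int,
    routed_scaling_factor: float,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Append one shared-expert column to every token's topk row."""
    # Weight 1/rsf so that the later multiplication by rsf cancels out.
    shared_weight = 1.0 / routed_scaling_factor
    new_ids, new_weights = [], []
    for ids, weights, rank in zip(topk_ids, topk_weights, shared_dest):
        # Virtual ID as described in the commit: target_rank * experts_per_rank.
        virtual_id = rank * experts_per_rank
        new_ids.append(list(ids) + [virtual_id])
        new_weights.append(list(weights) + [shared_weight])
    return new_ids, new_weights
```

With `experts_per_rank=16` and `routed_scaling_factor=2.5`, a token routed to experts `[3, 17]` whose shared expert is assigned to rank 1 ends up with IDs `[3, 17, 16]` and a shared weight of `0.4`. The virtual ID falls in rank 1's expert range, so the dispatch sends the token there without a separate routing path.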

Improvements:
1. LOCAL_SHARED_MARKER (-1): tokens compute shared expert locally
2. MIN_BATCH_FOR_BALANCE (64): small batches compute all shared locally
3. alt_stream optimization: local shared expert parallel with dispatch
4. Separate handling of local vs remote shared expert computation

Flow:
1. Router + topk(8)
2. AllReduce global routed counts
3. Waterfill assigns destination (local or remote rank)
4. Expand topk to 9 cols (LOCAL_SHARED_MARKER or virtual ID)
5. Local shared expert on alt_stream (parallel)
6. DeepEP dispatch with topk=9
7. Receiver: MoE(8 cols) + remote shared expert
8. DeepEP combine
9. Add local shared expert output
10. Apply routed_scaling_factor

Fix: routed_scaling_factor (rsf) should only be applied to routed experts, not shared experts.

Before (wrong order):
  combined += local_shared * (1/rsf)
  combined *= rsf  # rsf affects local_shared!

After (correct order):
  combined *= rsf  # only affects routed
  combined += local_shared  # not affected by rsf

Weight handling:
- Local shared: weight = 1.0 (added AFTER rsf multiplication)
- Remote shared: weight = 1/rsf (added BEFORE combine, rsf cancels out)

Final result: routed * rsf + shared
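The weight handling above can be checked with a small scalar example. The sketch below uses hypothetical helper names and plain floats in place of tensors; it only illustrates that the local path (weight 1.0, added after scaling) and the remote path (weight 1/rsf, added before scaling) both arrive at `routed * rsf + shared`.

```python
import math


def final_output_local(routed_sum: float, shared_out: float, rsf: float) -> float:
    """Local shared expert: scale the routed result first, then add shared with weight 1.0."""
    return routed_sum * rsf + shared_out * 1.0


def final_output_remote(routed_sum: float, shared_out: float, rsf: float) -> float:
    """Remote shared expert: it joins the combine with weight 1/rsf, then all is scaled by rsf."""
    combined = routed_sum + shared_out * (1.0 / rsf)
    return combined * rsf


routed, shared, rsf = 0.8, 0.5, 2.5
local_result = final_output_local(routed, shared, rsf)
remote_result = final_output_remote(routed, shared, rsf)
# Both paths agree up to floating-point rounding.
assert math.isclose(local_result, remote_result)
```

In low-precision dtypes the two paths can round differently, which is one reason to keep the local contribution out of the rsf multiplication entirely.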

If a remote rank would receive fewer than MIN_TOKENS_PER_RANK (16) tokens
for shared expert computation, redirect those tokens to local computation.

This avoids sending only a few tokens to a remote rank, which would have
more overhead than computing locally.

Thresholds:
- MIN_BATCH_FOR_BALANCE = 64: small batches compute all shared locally
- MIN_TOKENS_PER_RANK = 16: sparse destinations redirected to local
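A minimal sketch of the sparse-destination redirect described above is shown here. The function name is hypothetical, destinations are plain Python lists, and counting is done from the point of view of a single source rank, which is an assumption about where the threshold is evaluated.

```python
from collections import Counter
from typing import List

LOCAL_SHARED_MARKER = -1   # token computes its shared expert on the source rank
MIN_TOKENS_PER_RANK = 16   # threshold quoted in the commit message


def redirect_sparse_destinations(
    shared_dest: List[int], min_tokens: int = MIN_TOKENS_PER_RANK
) -> List[int]:
    """Send tokens back to local compute when a remote rank would get too few."""
    remote_counts = Counter(d for d in shared_dest if d != LOCAL_SHARED_MARKER)
    sparse = {rank for rank, n in remote_counts.items() if n < min_tokens}
    return [LOCAL_SHARED_MARKER if d in sparse else d for d in shared_dest]
```

For instance, if 20 tokens target rank 1 and 3 tokens target rank 2, the 3 tokens for rank 2 are rewritten to `LOCAL_SHARED_MARKER` while the 20 tokens for rank 1 are left alone.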

Tests cover:
1. count_routed_per_rank_pytorch - token counting per rank
2. assign_shared_destination_pytorch - waterfill assignment
3. assign_shared_destination - source rank preference
4. expand_topk_with_shared_expert - topk expansion to 9 cols
5. identify_shared_expert_tokens - receiver side identification
6. compute_local_shared_expert - local computation
7. DeepEPWaterfillBalancer - small batch optimization
8. DeepEPWaterfillBalancer - sparse destination redirect
9. End-to-end scenario
10. shared_weight calculation

All 10 tests pass on CPU.

Additional tests added:
- Empty batch handling
- Single token handling
- All tokens route to same rank
- Waterfill load balancing effectiveness
- MIN_TOKENS_PER_RANK threshold
- identify_shared_expert_tokens with all local markers
- identify_shared_expert_tokens mixed scenarios
- compute_local_shared_expert with no local tokens
- Virtual ID to rank mapping
- Weight preservation in topk expansion
- Routed count accuracy
- Consistency across repeated calls

Total: 22 tests, all passing

Changes:
1. Increase MIN_TOKENS_PER_RANK from 16 to 128 (tile size)
2. Redirect local shared tokens to remote if count < 128

Before: Multiple ranks received <128 shared tokens (wasted tiles)
After: All ranks receive 0 or >=128 shared tokens (no waste)

Load balance improvement: 15-39% reduction in imbalance ratio
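The "wasted tiles" argument can be made concrete with a little arithmetic. The sketch below assumes the grouped GEMM pads each rank's shared-expert batch up to a whole number of 128-row tiles; the helper names are hypothetical.

```python
TILE_M = 128  # tile size quoted in the commit message


def padded_rows(num_tokens: int, tile_m: int = TILE_M) -> int:
    """Rows actually computed when the batch is padded to whole tiles."""
    return -(-num_tokens // tile_m) * tile_m  # ceiling division, then scale back up


def tile_utilization(num_tokens: int, tile_m: int = TILE_M) -> float:
    """Fraction of computed rows that carry real tokens."""
    if num_tokens == 0:
        return 1.0  # nothing is launched, so nothing is wasted
    return num_tokens / padded_rows(num_tokens, tile_m)
```

Under this assumption a rank that receives 16 shared tokens still pays for 128 rows, a utilization of 0.125, whereas a rank that receives exactly 128 tokens uses its tile fully. That is the motivation for keeping every rank at either 0 or at least 128 shared tokens.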

DeepEP Normal mode and Low Latency mode handle topk_weights differently:
- Normal mode: run_moe_core applies weights, combine does NOT
- Low Latency mode: run_moe_core does NOT apply weights, combine DOES

Fixed remote shared expert weight application:
- Normal mode: Apply weight (1/rsf) before combine
- Low Latency mode: Let combine handle weight multiplication

Also verified:
- DeepGEMM tile size (BLOCK_M) = 128 (confirms MIN_TOKENS_PER_RANK = 128)
- DeepEP topk_ids=-1 means no selection (confirms LOCAL_SHARED_MARKER = -1)
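The mode-dependent weight handling described above can be sketched as a small helper. The function name and the string mode labels are hypothetical, and lists of floats stand in for the shared expert's output tensor.

```python
from typing import List, Tuple


def prepare_remote_shared_output(
    shared_out: List[float], rsf: float, mode: str
) -> Tuple[List[float], float]:
    """Return (values handed to combine, weight that combine still has to apply)."""
    weight = 1.0 / rsf
    if mode == "normal":
        # Normal mode: combine does not apply weights, so scale before combine.
        return [x * weight for x in shared_out], 1.0
    if mode == "low_latency":
        # Low Latency mode: combine multiplies by topk_weights itself.
        return list(shared_out), weight
    raise ValueError(f"unknown DeepEP mode: {mode}")
```

Either way the shared output is multiplied by `1/rsf` exactly once; applying it in both places, or in neither, would mis-scale the shared expert after the final `rsf` multiplication.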

Added test_deepep_waterfill_comprehensive.py with 15 test cases:
- count_routed_per_rank accuracy
- assign_shared_destination correctness
- expand_topk_with_shared_expert
- identify_shared_expert_tokens
- Virtual ID to rank mapping
- MIN_BATCH_FOR_BALANCE optimization
- MIN_TOKENS_PER_RANK redirect
- Shared weight calculation (1/rsf)
- Empty batch handling
- compute_local_shared_expert
- Weights preservation
- Waterfill load balancing effectiveness
- Invalid expert ID handling
- Large batch performance

All 15 tests pass.

Key optimizations:
1. assign_shared_destination: Replace for-loop with scatter-based vectorized ops
   - Old: O(topk) loop iterations, each with indexing
   - New: Single scatter operation for all topk values
   - Speedup: 2.6x - 4.2x depending on batch size

2. expand_topk_with_shared_expert: Pre-allocate output tensors
   - Avoid torch.cat overhead by pre-allocating and copying
   - Reduce memory allocation operations

3. prepare_dispatch: Vectorized sparse rank redirect
   - Replace for-loop with lookup table approach

Benchmark results (CPU):
- batch=128:  0.11ms -> 0.03ms (4.21x faster)
- batch=4096: 0.70ms -> 0.18ms (3.79x faster)
- batch=8192: 1.09ms -> 0.31ms (3.47x faster)

- Add assign_shared_destination_triton() kernel for GPU
- Auto-select Triton on GPU, fallback to PyTorch on CPU
- Update benchmark to compare PyTorch vs Triton on GPU

Triton kernel processes one token per thread block, iterating over
topk experts to find the minimum-load destination rank.

Major changes:
1. New fused kernel: _waterfill_expand_topk_fused_kernel
   - Combines waterfill assignment + topk expansion in single pass
   - Reduces kernel launches from 3 to 1
   - Eliminates intermediate tensor allocations

2. Kernel design:
   - Each thread block handles BLOCK_SIZE=256 tokens
   - Loop over topk experts to find minimum-load rank
   - Write expanded topk_ids, weights, and local_mask in-place

3. Vectorized post-processing:
   - Sparse rank redirect: use boolean indexing instead of for-loop
   - Local count redirect: single tensor operation
   - Minimal GPU-CPU synchronization

Performance (CPU):
- assign_shared_destination: 3.4-4.3x speedup vs loop version
- prepare_dispatch: 0.28ms for 4096 tokens

GPU benefits (when Triton available):
- Single kernel launch vs multiple PyTorch ops
- No intermediate memory allocation
- Better memory coalescing

- Map checkpoint routed expert IDs (0..255) using the original experts_per_rank when waterfill expands num_experts by ep_size

- Add unit test to prevent EP-rank mis-mapping regression
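The mis-mapping this fix guards against can be illustrated with a short sketch. The function name is hypothetical, and the expanded layout (each rank's shared slot placed after its routed experts) is an assumption made only for the example; the point taken from the commit is that the owning rank must be derived from the original experts_per_rank.

```python
from typing import Tuple


def remap_checkpoint_expert_id(
    ckpt_expert_id: int, num_routed_experts: int, ep_size: int
) -> Tuple[int, int]:
    """Return (owning EP rank, slot ID in the expanded expert table)."""
    old_per_rank = num_routed_experts // ep_size  # layout the checkpoint was written with
    new_per_rank = old_per_rank + 1               # assumed: one extra shared slot per rank
    rank, local_idx = divmod(ckpt_expert_id, old_per_rank)
    return rank, rank * new_per_rank + local_idx
```

With 256 routed experts and `ep_size=16`, checkpoint expert 16 belongs to rank 1 (`16 // 16`). Dividing by the expanded per-rank count instead (`16 // 17`) would place it on rank 0, which is the kind of regression the unit test is meant to catch.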
…red expert when SGLANG_DEEPEP_WATERFILL_FIXED_LOCAL is set. This change allows for fixed local computation of shared experts, improving control over load balancing behavior.
# Conflicts:
#	test/registered/unit/server_args/test_server_args.py
@ch-wan ch-wan merged commit c701a08 into sgl-project:main May 14, 2026
393 of 471 checks passed
Fridge003 pushed a commit that referenced this pull request May 14, 2026
…spatch (#19290)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
Co-authored-by: root <aichenf@nvidia.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Shunkangz pushed a commit to Shunkangz/sglang that referenced this pull request May 27, 2026
…spatch (sgl-project#19290)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
Co-authored-by: root <aichenf@nvidia.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
alphabetc1 pushed a commit to alphabetc1/sglang that referenced this pull request Jun 4, 2026
…spatch (sgl-project#19290)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
Co-authored-by: root <aichenf@nvidia.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
zijiexia added a commit to zijiexia/sglang that referenced this pull request Jun 4, 2026
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Labels: deepseek, documentation, run-ci
