diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 6e8074ab99ea..1bef438c3493 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -336,6 +336,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--elastic-ep-backend` | Specify the collective communication backend for elastic EP. Currently supports 'mooncake'. | `none` | `none`, `mooncake` | | `--enable-elastic-expert-backup` | Enable elastic EP backend to backup expert weights in DRAM feature. Currently supports 'mooncake'.| `False` | bool flag (set to enable) | | `--mooncake-ib-device` | The InfiniBand devices for Mooncake Backend transfer, accepts multiple comma-separated devices (e.g., --mooncake-ib-device mlx5_0,mlx5_1). Default is None, which triggers automatic device detection when Mooncake Backend is enabled. | `None` | Type: str | +| `--enable-deepep-waterfill` | Enable DeepEP Waterfill: dispatch the shared expert as the 9th routed expert to the least-loaded EP rank. Automatically sets `--moe-a2a-backend deepep`, implicitly enables shared-expert fusion, and supports `--deepep-mode auto`, `normal`, or `low_latency`. Use `auto` or `low_latency` for production decode so CUDA graph remains enabled. Supported on DeepSeek-V3/R1 with EP >= 2. By default, Waterfill uses the static local-batch path; set `SGLANG_DISABLE_STATIC_WATERFILL=1` to force dynamic Waterfill with runtime EP all-reduce. | `False` | bool flag (set to enable) | | `--elastic-ep-rejoin` | Indicates that this process is a relaunched elastic EP rank that should rejoin an existing process group during rank recovery. | `False` | bool flag (set to enable) | ## Mamba Cache diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 97e8040e8d35..6f8dd37f512c 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -412,6 +412,9 @@ class Envs: SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(128) SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS = EnvInt(32) SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO = EnvBool(False) + # Force dynamic DeepEP Waterfill with runtime EP all-reduce instead of the + # default static local-batch path. + SGLANG_DISABLE_STATIC_WATERFILL = EnvBool(False) # NIXL-EP SGLANG_NIXL_EP_BF16_DISPATCH = EnvBool(False) diff --git a/python/sglang/srt/layers/moe/deepep_waterfill.py b/python/sglang/srt/layers/moe/deepep_waterfill.py new file mode 100644 index 000000000000..caa8a912ce79 --- /dev/null +++ b/python/sglang/srt/layers/moe/deepep_waterfill.py @@ -0,0 +1,584 @@ +# Copyright 2023-2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""DeepEP Waterfill: shared expert as 9th routed expert, dispatched to least-loaded rank.""" + +from typing import NamedTuple, Optional, Tuple + +import torch +import triton +import triton.language as tl +from torch import Tensor + +from sglang.srt.environ import envs +from sglang.srt.layers.moe.topk import StandardTopKOutput + +LOCAL_SHARED_MARKER = -1 # Invalid expert ID; DeepEP ignores expert_id < 0. +_LOCAL_PREF_NUMER = 11 # local-rank preference = 11/10 +_LOCAL_PREF_DENOM = 10 + + +class WaterfillDispatchPlan(NamedTuple): + """Inputs needed by the fused DeepEP Waterfill expansion path.""" + + # Effective rank load consumed by the fused kernel. + rank_load: Tensor + allow_all_ranks: bool + target_total: int + + +def _empty_expanded(topk_ids: Tensor, topk_weights: Tensor): + """Return empty expanded tensors for zero-token batches.""" + topk, d = topk_ids.shape[1], topk_ids.device + return ( + torch.empty(0, topk + 1, dtype=topk_ids.dtype, device=d), + torch.empty(0, topk + 1, dtype=topk_weights.dtype, device=d), + ) + + +@triton.jit +def _count_routed_per_rank_kernel( + topk_ids_ptr, # [num_tokens, topk] + counts_ptr, # [world_size] output (atomic add) + num_tokens, + topk: tl.constexpr, + experts_per_rank, + world_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Count routed tokens per rank using block-level histogram.""" + pid = tl.program_id(0) + token_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = token_idx < num_tokens + + for r in range(world_size): + rank_count = tl.zeros([BLOCK_SIZE], dtype=tl.int64) + + for k in range(topk): + expert_id = tl.load( + topk_ids_ptr + token_idx * topk + k, mask=mask, other=-1 + ).to(tl.int64) + valid = expert_id >= 0 + target_rank = expert_id // experts_per_rank + target_rank = tl.minimum(tl.maximum(target_rank, 0), world_size - 1) + rank_count += tl.where( + mask & valid & (target_rank == r), + tl.full([BLOCK_SIZE], 1, dtype=tl.int64), + tl.zeros([BLOCK_SIZE], dtype=tl.int64), + ) + + block_total = tl.sum(rank_count) + if block_total > 0: + tl.atomic_add(counts_ptr + r, block_total) + + +@triton.jit +def _waterfill_expand_kernel( + topk_ids_ptr, + topk_weights_ptr, + rank_load_ptr, + expanded_ids_ptr, + expanded_weights_ptr, + num_tokens, + topk: tl.constexpr, + old_experts_per_rank, + new_experts_per_rank, + world_size: tl.constexpr, + source_rank, + shared_weight, + local_marker, + local_pref_numer, + local_pref_denom, + precomputed_target_total, + ALLOW_ALL_RANKS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused waterfill + expand. ID remap: old_id -> old_id + old_id // old_epr.""" + pid = tl.program_id(0) + token_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = token_idx < num_tokens + + r_idx = tl.arange(0, world_size) + rank_load_vec = tl.load(rank_load_ptr + r_idx, mask=r_idx < world_size, other=0).to( + tl.int64 + ) + total_effective_k = tl.sum(rank_load_vec) + total_tokens_global_k = total_effective_k // topk + derived_target_total = ( + total_effective_k + total_tokens_global_k + world_size - 1 + ) // world_size + target_total = tl.where( + precomputed_target_total > 0, + precomputed_target_total, + derived_target_total, + ) + + # Step 1: Select destination rank for shared expert (waterfill sampling). + source_count = tl.load(rank_load_ptr + source_rank) + best_count = tl.where(mask, source_count, 2**30) + best_rank = tl.full([BLOCK_SIZE], source_rank, dtype=tl.int64) + has_valid = tl.zeros([BLOCK_SIZE], dtype=tl.int1) + src_rank_i32 = tl.full([BLOCK_SIZE], source_rank, dtype=tl.int32) + + if ALLOW_ALL_RANKS: + candidate_mask = tl.full([BLOCK_SIZE], (1 << world_size) - 1, dtype=tl.int32) + for r in range(world_size): + target_count = tl.load(rank_load_ptr + r).to(tl.int64) + better = ( + target_count * local_pref_numer < best_count * local_pref_denom + ) & mask + best_count = tl.where(better, target_count, best_count) + best_rank = tl.where( + better, tl.full([BLOCK_SIZE], r, dtype=tl.int64), best_rank + ) + else: + candidate_mask = (tl.full([BLOCK_SIZE], 1, dtype=tl.int32) << src_rank_i32).to( + tl.int32 + ) + + for k in range(topk): + expert_id = tl.load( + topk_ids_ptr + token_idx * topk + k, mask=mask, other=-1 + ).to(tl.int64) + valid = expert_id >= 0 + has_valid = has_valid | valid + + if not ALLOW_ALL_RANKS: + target_rank = expert_id // old_experts_per_rank + target_rank = tl.minimum(tl.maximum(target_rank, 0), world_size - 1) + target_rank_i32 = target_rank.to(tl.int32) + shift_amt = tl.where(valid, target_rank_i32, 0) + bit = tl.full([BLOCK_SIZE], 1, dtype=tl.int32) << shift_amt + candidate_mask = tl.where( + valid & mask, candidate_mask | bit, candidate_mask + ) + + target_count = tl.load( + rank_load_ptr + target_rank, mask=mask & valid, other=2**30 + ) + + better = ( + (target_count * local_pref_numer < best_count * local_pref_denom) + & valid + & mask + ) + best_count = tl.where(better, target_count, best_count) + best_rank = tl.where(better, target_rank, best_rank) + + total_w = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + for r in range(world_size): + present = ((candidate_mask >> r) & 1) == 1 + rank_load_r = tl.load(rank_load_ptr + r).to(tl.int64) + w = tl.where(target_total > rank_load_r, target_total - rank_load_r, 0).to( + tl.int32 + ) + w_vec = tl.full([BLOCK_SIZE], w, dtype=tl.int32) + w_vec = tl.where( + src_rank_i32 == r, + w_vec, + (w_vec * local_pref_denom) // local_pref_numer, + ) + total_w += tl.where(present, w_vec, 0) + + token_seed = token_idx.to(tl.uint32) ^ ( + src_rank_i32.to(tl.uint32) * tl.full([BLOCK_SIZE], 0x9E3779B9, dtype=tl.uint32) + ) + token_seed = token_seed * tl.full([BLOCK_SIZE], 1664525, dtype=tl.uint32) + tl.full( + [BLOCK_SIZE], 1013904223, dtype=tl.uint32 + ) + u = tl.where(total_w > 0, token_seed % total_w.to(tl.uint32), 0).to(tl.int32) + + chosen = src_rank_i32 + cum = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + for r in range(world_size): + present = ((candidate_mask >> r) & 1) == 1 + rank_load_r = tl.load(rank_load_ptr + r).to(tl.int64) + w = tl.where(target_total > rank_load_r, target_total - rank_load_r, 0).to( + tl.int32 + ) + w_vec = tl.full([BLOCK_SIZE], w, dtype=tl.int32) + w_vec = tl.where( + src_rank_i32 == r, + w_vec, + (w_vec * local_pref_denom) // local_pref_numer, + ) + w_vec = tl.where(present, w_vec, 0) + pick = (total_w > 0) & present & (u >= cum) & (u < (cum + w_vec)) + chosen = tl.where(pick, r, chosen) + cum += w_vec + + best_rank = tl.where(total_w > 0, chosen.to(tl.int64), best_rank) + + # Step 2: Compute shared expert ID and local mask. + is_local = best_rank == source_rank + local_shared_id = source_rank * new_experts_per_rank + old_experts_per_rank + remote_shared_id = best_rank * new_experts_per_rank + old_experts_per_rank + shared_expert_id = tl.where( + is_local, + tl.full([BLOCK_SIZE], local_shared_id, dtype=tl.int64), + remote_shared_id, + ).to(tl.int64) + shared_expert_id = tl.where( + has_valid, + shared_expert_id, + tl.full([BLOCK_SIZE], local_marker, dtype=tl.int64), + ) + + # Step 3: Copy and remap topk_ids, copy weights. + for k in range(topk): + old_id = tl.load(topk_ids_ptr + token_idx * topk + k, mask=mask, other=-1).to( + tl.int64 + ) + valid_id = old_id >= 0 + new_id = tl.where(valid_id, old_id + (old_id // old_experts_per_rank), old_id) + tl.store(expanded_ids_ptr + token_idx * (topk + 1) + k, new_id, mask=mask) + + for k in range(topk): + val = tl.load(topk_weights_ptr + token_idx * topk + k, mask=mask, other=0.0) + expert_id = tl.load( + topk_ids_ptr + token_idx * topk + k, mask=mask, other=-1 + ).to(tl.int64) + val = tl.where(expert_id >= 0, val, 0.0) + tl.store(expanded_weights_ptr + token_idx * (topk + 1) + k, val, mask=mask) + + # Step 4: Write shared expert column. + tl.store( + expanded_ids_ptr + token_idx * (topk + 1) + topk, + shared_expert_id, + mask=mask, + ) + tl.store( + expanded_weights_ptr + token_idx * (topk + 1) + topk, + tl.where(has_valid, shared_weight, 0.0), + mask=mask, + ) + + +def materialize_waterfill_dispatch_fused( + topk_ids: Tensor, + topk_weights: Tensor, + rank_load: Tensor, + num_routed_experts: int, + world_size: int, + source_rank: int, + shared_weight: float, + allow_all_ranks: bool = False, + target_total: int = 0, +) -> Tuple[Tensor, Tensor]: + """Run fused Waterfill rank selection and DeepEP TopK expansion. + + The Triton kernel intentionally selects each token's shared-expert rank and + writes the expanded DeepEP TopK layout in one pass. + """ + num_tokens = topk_ids.shape[0] + topk = topk_ids.shape[1] + old_experts_per_rank = num_routed_experts // world_size + new_experts_per_rank = old_experts_per_rank + 1 + device = topk_ids.device + + if num_tokens == 0: + return _empty_expanded(topk_ids, topk_weights) + + expanded_topk_ids = torch.empty( + num_tokens, topk + 1, dtype=topk_ids.dtype, device=device + ) + expanded_topk_weights = torch.empty( + num_tokens, topk + 1, dtype=topk_weights.dtype, device=device + ) + BLOCK_SIZE = 256 + grid = ((num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE,) + _waterfill_expand_kernel[grid]( + topk_ids, + topk_weights, + rank_load, + expanded_topk_ids, + expanded_topk_weights, + num_tokens, + topk, + old_experts_per_rank, + new_experts_per_rank, + world_size, + source_rank, + shared_weight, + LOCAL_SHARED_MARKER, + _LOCAL_PREF_NUMER, + _LOCAL_PREF_DENOM, + target_total, + allow_all_ranks, + BLOCK_SIZE, + ) + + return expanded_topk_ids, expanded_topk_weights + + +@torch.compile(dynamic=True) +def expand_topk_with_shared_expert( + topk_ids: Tensor, + topk_weights: Tensor, + num_routed_experts: int, + world_size: int, + source_rank: int, + shared_weight: float, +) -> Tuple[Tensor, Tensor]: + """Expand topk [N, 8] → [N, 9] with ID remap; shared expert always local.""" + num_tokens = topk_ids.shape[0] + topk = topk_ids.shape[1] + device = topk_ids.device + old_epr = num_routed_experts // world_size + new_epr = old_epr + 1 + has_valid = (topk_ids >= 0).any(dim=1) + valid_mask = topk_ids >= 0 + old_ranks = torch.where(valid_mask, topk_ids // old_epr, torch.zeros_like(topk_ids)) + expanded_topk_ids = torch.empty( + num_tokens, topk + 1, dtype=topk_ids.dtype, device=device + ) + expanded_topk_ids[:, :topk] = torch.where( + valid_mask, topk_ids + old_ranks, topk_ids + ) + + shared_id = source_rank * new_epr + old_epr + expanded_topk_ids[:, topk] = torch.where(has_valid, shared_id, LOCAL_SHARED_MARKER) + expanded_topk_weights = torch.empty( + num_tokens, topk + 1, dtype=topk_weights.dtype, device=device + ) + expanded_topk_weights[:, :topk] = torch.where(valid_mask, topk_weights, 0.0) + expanded_topk_weights[:, topk] = torch.where(has_valid, shared_weight, 0.0).to( + topk_weights.dtype + ) + return expanded_topk_ids, expanded_topk_weights + + +class DeepEPWaterfillBalancer: + """Waterfill load balancer: shared expert fused as real routed expert (topk 8→9).""" + + MIN_BATCH_FOR_BALANCE = 64 + + def __init__( + self, + num_routed_experts: int, + world_size: int, + rank: int, + layer_id: int, + routed_scaling_factor: float = 1.0, + ): + self.num_routed_experts = num_routed_experts + self.world_size = world_size + self.rank = rank + self.layer_id = layer_id + self.old_experts_per_rank = num_routed_experts // world_size + self.shared_weight = ( + 1.0 / routed_scaling_factor if routed_scaling_factor != 0 else 1.0 + ) + self._counts_buf: Optional[Tensor] = None + self.use_static_waterfill = not envs.SGLANG_DISABLE_STATIC_WATERFILL.get() + + def count_local_routed(self, topk_ids: Tensor) -> Tensor: + """Count routed tokens per rank via Triton kernel (uses original expert IDs).""" + if self._counts_buf is None: + self._counts_buf = torch.zeros( + self.world_size, dtype=torch.int64, device=topk_ids.device + ) + buf = self._counts_buf + buf.zero_() + num_tokens = topk_ids.shape[0] + if num_tokens == 0: + return buf + topk = topk_ids.shape[1] + BLOCK_SIZE = 256 + grid = ((num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE,) + _count_routed_per_rank_kernel[grid]( + topk_ids, + buf, + num_tokens, + topk, + self.old_experts_per_rank, + self.world_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return buf + + def _is_low_batch(self, num_tokens: int) -> bool: + """Return whether waterfill should skip balancing for small batches.""" + return num_tokens < self.MIN_BATCH_FOR_BALANCE + + def _can_skip_dispatch_plan_for_low_batch(self, num_tokens: int) -> bool: + """Return whether static mode can skip dispatch-plan setup entirely.""" + return self.use_static_waterfill and self._is_low_batch(num_tokens) + + def _build_static_dispatch_plan( + self, routed_counts: Tensor + ) -> WaterfillDispatchPlan: + """Build static-mode Waterfill inputs from current local routed counts.""" + return WaterfillDispatchPlan( + rank_load=routed_counts, + allow_all_ranks=True, + target_total=0, + ) + + def _build_dynamic_dispatch_plan( + self, + routed_counts: Tensor, + local_tokens_per_rank: Optional[Tensor], + topk: int, + ) -> WaterfillDispatchPlan: + """Build dynamic waterfill inputs from globally reduced routed counts.""" + # Dynamic Waterfill balances against effective rank load: globally + # reduced routed counts plus each rank's active token count. + rank_load = ( + routed_counts + local_tokens_per_rank + if local_tokens_per_rank is not None + else routed_counts + ) + total_routed_t = routed_counts.sum() + total_tokens_global_t = total_routed_t // topk + total_effective_t = rank_load.sum() + max_effective_t = rank_load.max() + target_total = int( + (total_effective_t + total_tokens_global_t + self.world_size - 1) + // self.world_size + ) + allow_all_ranks = bool(max_effective_t <= target_total) + return WaterfillDispatchPlan( + rank_load=rank_load, + allow_all_ranks=allow_all_ranks, + target_total=target_total, + ) + + @staticmethod + def _all_reduce_dynamic_rank_load( + local_routed_counts: Tensor, num_tokens: int + ) -> Tuple[Tensor, Tensor]: + """Aggregate dynamic load with SGLang EP communication.""" + from sglang.srt.distributed import get_moe_ep_group + from sglang.srt.distributed.communication_op import ( + moe_expert_parallel_all_reduce, + ) + + group = get_moe_ep_group() + world = group.world_size + buf = torch.zeros( + world * 2, dtype=torch.int64, device=local_routed_counts.device + ) + buf[:world] = local_routed_counts + rank = group.rank_in_group + buf[world + rank : world + rank + 1].fill_(num_tokens) + buf = moe_expert_parallel_all_reduce(buf) + return buf[:world], buf[world:] + + def _build_dispatch_plan( + self, topk_ids: Tensor, num_tokens: int + ) -> Optional[WaterfillDispatchPlan]: + """Prepare dispatch state for the waterfill selection boundary.""" + local_routed_counts = self.count_local_routed(topk_ids) + if self.use_static_waterfill: + return self._build_static_dispatch_plan(local_routed_counts) + + global_routed_counts, local_tokens_per_rank = ( + DeepEPWaterfillBalancer._all_reduce_dynamic_rank_load( + local_routed_counts, num_tokens + ) + ) + if self._is_low_batch(num_tokens): + return None + return self._build_dynamic_dispatch_plan( + global_routed_counts, + local_tokens_per_rank=local_tokens_per_rank, + topk=topk_ids.shape[1], + ) + + def _materialize_dispatch( + self, + topk_ids: Tensor, + topk_weights: Tensor, + dispatch_plan: WaterfillDispatchPlan, + ) -> Tuple[Tensor, Tensor]: + """Expand TopK using local expansion or fused Waterfill.""" + num_tokens = topk_ids.shape[0] + if num_tokens == 0: + return _empty_expanded(topk_ids, topk_weights) + + if self._is_low_batch(num_tokens): + return expand_topk_with_shared_expert( + topk_ids, + topk_weights, + self.num_routed_experts, + self.world_size, + self.rank, + self.shared_weight, + ) + + return materialize_waterfill_dispatch_fused( + topk_ids, + topk_weights, + dispatch_plan.rank_load, + self.num_routed_experts, + self.world_size, + self.rank, + self.shared_weight, + allow_all_ranks=dispatch_plan.allow_all_ranks, + target_total=dispatch_plan.target_total, + ) + + @staticmethod + def _with_expanded_topk( + topk_output: StandardTopKOutput, + expanded_ids: Tensor, + expanded_weights: Tensor, + ) -> StandardTopKOutput: + """Wrap expanded tensors back into SGLang's StandardTopKOutput.""" + return StandardTopKOutput( + topk_weights=expanded_weights, + topk_ids=expanded_ids, + router_logits=topk_output.router_logits, + ) + + def _expand_local_shared( + self, topk_output: StandardTopKOutput + ) -> StandardTopKOutput: + expanded_ids, expanded_weights = expand_topk_with_shared_expert( + topk_output.topk_ids, + topk_output.topk_weights, + self.num_routed_experts, + self.world_size, + self.rank, + self.shared_weight, + ) + return self._with_expanded_topk(topk_output, expanded_ids, expanded_weights) + + def expand_topk( + self, topk_output: StandardTopKOutput, num_tokens: int + ) -> StandardTopKOutput: + """Expand topk [N, 8] -> [N, 9] with waterfill-assigned shared expert.""" + if self._can_skip_dispatch_plan_for_low_batch(num_tokens): + # Static mode can use local expansion without communication for small + # decode-sized batches. Dynamic mode still all-reduces before local + # expansion so all ranks participate consistently. + return self._expand_local_shared(topk_output) + + dispatch_plan = self._build_dispatch_plan(topk_output.topk_ids, num_tokens) + if dispatch_plan is None: + if num_tokens == 0: + expanded_ids, expanded_weights = _empty_expanded( + topk_output.topk_ids, topk_output.topk_weights + ) + return self._with_expanded_topk( + topk_output, expanded_ids, expanded_weights + ) + else: + return self._expand_local_shared(topk_output) + expanded_ids, expanded_weights = self._materialize_dispatch( + topk_output.topk_ids, + topk_output.topk_weights, + dispatch_plan, + ) + return self._with_expanded_topk(topk_output, expanded_ids, expanded_weights) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b5663e44be18..202eb46d4f3d 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -284,6 +284,25 @@ def __init__( assert num_expert_group is not None and topk_group is not None self.layer_id = layer_id + if num_fused_shared_experts > 0: + from sglang.srt.server_args import get_global_server_args + + try: + self.enable_deepep_waterfill = ( + get_global_server_args().enable_deepep_waterfill + ) + except ValueError: + self.enable_deepep_waterfill = False + else: + self.enable_deepep_waterfill = False + + self.deepep_waterfill_balancer = None + if self.enable_deepep_waterfill: + # TODO(ch-wan): Refactor shared-expert fusion and routed TopK fusion. + top_k -= num_fused_shared_experts + num_fused_shared_experts = 0 + output_format = TopKOutputFormat.STANDARD + # flashinfer_mxfp4 backend only: True -> STANDARD (Mxfp4FlashinferTrtllmMoEMethod # consumes), False -> BYPASSED (flashinfer's own mxfp4 kernel). No-op otherwise. self.is_fp4_experts = is_fp4_experts @@ -303,6 +322,18 @@ def __init__( scoring_func=scoring_func, ) + def _apply_deepep_waterfill( + self, topk_output: TopKOutput, num_tokens: int + ) -> TopKOutput: + if self.enable_deepep_waterfill and self.deepep_waterfill_balancer is None: + raise RuntimeError( + "DeepEP waterfill TopK must be prepared by ModelRunner before forward." + ) + if self.deepep_waterfill_balancer is None: + return topk_output + assert TopKOutputChecker.format_is_standard(topk_output) + return self.deepep_waterfill_balancer.expand_topk(topk_output, num_tokens) + def forward_native( self, hidden_states: torch.Tensor, @@ -312,7 +343,7 @@ def forward_native( expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ) -> TopKOutput: self.topk_config.torch_native = True - return select_experts( + topk_output = select_experts( hidden_states=hidden_states, layer_id=self.layer_id, router_logits=router_logits, @@ -320,6 +351,7 @@ def forward_native( num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) + return self._apply_deepep_waterfill(topk_output, hidden_states.shape[0]) def forward_cuda( self, @@ -369,7 +401,7 @@ def forward_cuda( num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) - return topk_output + return self._apply_deepep_waterfill(topk_output, hidden_states.shape[0]) def forward_cpu( self, @@ -379,7 +411,7 @@ def forward_cpu( num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ) -> TopKOutput: - return select_experts( + topk_output = select_experts( hidden_states=hidden_states, layer_id=self.layer_id, router_logits=router_logits, @@ -387,6 +419,7 @@ def forward_cpu( num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) + return self._apply_deepep_waterfill(topk_output, hidden_states.shape[0]) def forward_npu( self, @@ -417,7 +450,18 @@ def empty_topk_output(self, device: torch.device) -> TopKOutput: topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) # FIXME: router_logits should be of size (0, num_experts) router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) - return StandardTopKOutput(topk_weights, topk_ids, router_logits) + topk_output = StandardTopKOutput(topk_weights, topk_ids, router_logits) + if self.topk_config.num_fused_shared_experts > 0 and is_deepep_class_backend(): + n = self.topk_config.num_fused_shared_experts + topk_output = topk_output._replace( + topk_ids=topk_output.topk_ids.new_empty( + (0, topk_output.topk_ids.shape[-1] + n) + ), + topk_weights=topk_output.topk_weights.new_empty( + (0, topk_output.topk_weights.shape[-1] + n) + ), + ) + return self._apply_deepep_waterfill(topk_output, 0) # ------------------------------- TopK implementation ------------------------------------- diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7537be81cd2d..ac0f5aa198b3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -122,6 +122,7 @@ set_is_extend_in_batch, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.pooler import EmbeddingPoolerOutput from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype from sglang.srt.layers.sampler import create_sampler @@ -650,6 +651,7 @@ def initialize(self, pre_model_load_memory: float): # Load the model self.sampler = create_sampler() self.load_model() + self._prepare_moe_topk() # Load the expert backup client self.expert_backup_client = ( @@ -1582,6 +1584,49 @@ def load_model(self): f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." ) from None + def _prepare_moe_topk(self): + balancer_cls = None + num_prepared = 0 + num_routed_experts = None + for module in self.model.modules(): + if not isinstance(module, TopK): + continue + if ( + not module.enable_deepep_waterfill + or module.deepep_waterfill_balancer is not None + ): + continue + if num_routed_experts is None: + num_routed_experts = getattr( + self.model_config.hf_config, "n_routed_experts", None + ) + if num_routed_experts is None: + raise ValueError( + "DeepEP waterfill requires model config n_routed_experts." + ) + if balancer_cls is None: + from sglang.srt.layers.moe.deepep_waterfill import ( + DeepEPWaterfillBalancer, + ) + + balancer_cls = DeepEPWaterfillBalancer + module.deepep_waterfill_balancer = balancer_cls( + num_routed_experts=num_routed_experts, + world_size=self.moe_ep_size, + rank=self.moe_ep_rank, + layer_id=module.layer_id, + routed_scaling_factor=( + module.topk_config.routed_scaling_factor + if module.topk_config.routed_scaling_factor is not None + else 1.0 + ), + ) + num_prepared += 1 + if num_prepared: + log_info_on_rank0( + logger, f"Prepared {num_prepared} DeepEP waterfill TopK modules." + ) + def update_expert_location( self, new_expert_location_metadata: ExpertLocationMetadata, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2c8637f0ec48..fdc852d08cba 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -953,16 +953,6 @@ def forward_deepep( ) else: topk_output = self.topk.empty_topk_output(hidden_states.device) - if is_deepep_class_backend() and self.num_fused_shared_experts > 0: - n = self.num_fused_shared_experts - topk_output = topk_output._replace( - topk_ids=topk_output.topk_ids.new_empty( - (0, topk_output.topk_ids.shape[-1] + n) - ), - topk_weights=topk_output.topk_weights.new_empty( - (0, topk_output.topk_weights.shape[-1] + n) - ), - ) if sbo_overlap_dispatch_flag: shared_output = None @@ -2352,19 +2342,10 @@ def determine_num_fused_shared_experts( if server_args.disable_shared_experts_fusion: return - # DeepEP + enforce: the only path that enables fusion under DeepEP. - if is_deepep_class_backend() and server_args.enforce_shared_experts_fusion: - log_info_on_rank0( - logger, - "DeepEP shared expert fusion: fusing shared expert into MoE kernel " - "at home EP rank local slot (--enforce-shared-experts-fusion).", - ) - self.num_fused_shared_experts = self.config.n_shared_experts - return - - # Check all conditions that disable fusion. disable_reason = None - if is_sbo_enabled() or is_tbo_enabled(): + if server_args.enforce_shared_experts_fusion: + pass + elif is_sbo_enabled() or is_tbo_enabled(): disable_reason = "SBO/TBO enabled: incompatible with fusing shared expert into MoE kernel." elif is_deepep_class_backend(): disable_reason = "DeepEP: fusion off by default (use --enforce-shared-experts-fusion to enable)." diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 941ca78a1fb4..67ec4841c24d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -622,6 +622,7 @@ class ServerArgs: elastic_ep_backend: Literal[None, "mooncake", "nixl"] = None enable_elastic_expert_backup: bool = False mooncake_ib_device: Optional[str] = None + enable_deepep_waterfill: bool = False elastic_ep_rejoin: bool = False # Mamba cache @@ -3128,6 +3129,13 @@ def _handle_moe_kernel_config(self): ) def _handle_a2a_moe(self): + if self.enable_deepep_waterfill and self.moe_a2a_backend != "deepep": + logger.warning( + "moe_a2a_backend is overridden to 'deepep' because DeepEP " + "Waterfill requires the DeepEP backend." + ) + self.moe_a2a_backend = "deepep" + if self.moe_a2a_backend == "deepep": if self.deepep_mode == "normal": logger.warning("Cuda graph is disabled because deepep_mode=`normal`") @@ -3136,6 +3144,16 @@ def _handle_a2a_moe(self): logger.warning( f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) + if self.enable_deepep_waterfill: + if self.disable_shared_experts_fusion: + logger.warning( + "disable_shared_experts_fusion is overridden to False because DeepEP Waterfill requires shared expert fusion." + ) + self.disable_shared_experts_fusion = False + self.enforce_shared_experts_fusion = True + logger.info( + "DeepEP Waterfill is enabled. Shared expert will be dispatched through DeepEP for load balancing." + ) if self.moe_a2a_backend == "mooncake": self.ep_size = self.tp_size @@ -5979,6 +5997,18 @@ def add_cli_args(parser: argparse.ArgumentParser): "(e.g., --mooncake-ib-device mlx5_0,mlx5_1). " "Default is None, which triggers automatic device detection when Mooncake Backend is enabled.", ) + parser.add_argument( + "--enable-deepep-waterfill", + action="store_true", + default=ServerArgs.enable_deepep_waterfill, + help="Enable DeepEP Waterfill: dispatch the shared expert as the 9th " + "routed expert to the least-loaded EP rank. Automatically sets " + "--moe-a2a-backend deepep, implicitly enables shared-expert fusion, " + "and supports --deepep-mode auto, normal, or low_latency. Use auto " + "or low_latency for production decode so CUDA graph remains enabled. " + "Supported on DeepSeek-V3/R1 " + "with EP >= 2.", + ) parser.add_argument( "--elastic-ep-rejoin", action="store_true", @@ -6526,7 +6556,12 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--disable-shared-experts-fusion", action="store_true", - help="Disable shared experts fusion optimization for deepseek v3/r1.", + help=( + "Disable the built-in shared experts fusion optimization for DeepSeek V3/R1. " + "Note: DeepEP Waterfill (--enable-deepep-waterfill) still routes shared expert " + "through DeepEP as an extra MoE slot, so shared expert is not separated from the " + "MoE path when Waterfill is enabled." + ), ) parser.add_argument( "--enforce-shared-experts-fusion", diff --git a/test/registered/unit/server_args/test_server_args.py b/test/registered/unit/server_args/test_server_args.py index a24fd5425aa5..939a91b49caf 100644 --- a/test/registered/unit/server_args/test_server_args.py +++ b/test/registered/unit/server_args/test_server_args.py @@ -515,6 +515,47 @@ def test_external_corpus_max_tokens_must_be_positive(self): self.assertIn("external-corpus-max-tokens", str(context.exception)) +class TestDeepEPWaterfillArgs(CustomTestCase): + def test_waterfill_enforces_shared_experts_fusion(self): + server_args = ServerArgs( + model_path="dummy", + moe_a2a_backend="deepep", + enable_deepep_waterfill=True, + disable_shared_experts_fusion=True, + ) + # dummy-model path short-circuits __post_init__; invoke the handler directly. + server_args._handle_a2a_moe() + + self.assertFalse(server_args.disable_shared_experts_fusion) + self.assertTrue(server_args.enforce_shared_experts_fusion) + + def test_waterfill_overrides_moe_a2a_backend_to_deepep(self): + server_args = ServerArgs( + model_path="dummy", + moe_a2a_backend="none", + enable_deepep_waterfill=True, + ) + # dummy-model path short-circuits __post_init__; invoke the handler directly. + server_args._handle_a2a_moe() + + self.assertEqual(server_args.moe_a2a_backend, "deepep") + self.assertTrue(server_args.enforce_shared_experts_fusion) + + def test_waterfill_supports_deepep_low_latency_mode(self): + server_args = ServerArgs( + model_path="dummy", + moe_a2a_backend="deepep", + enable_deepep_waterfill=True, + deepep_mode="low_latency", + ) + # dummy-model path short-circuits __post_init__; invoke the handler directly. + server_args._handle_a2a_moe() + + self.assertEqual(server_args.deepep_mode, "low_latency") + self.assertFalse(server_args.disable_cuda_graph) + self.assertTrue(server_args.enforce_shared_experts_fusion) + + class TestPrefillOnlyDisableKvCache(unittest.TestCase): """Validation for --prefill-only-disable-kv-cache.