diff --git a/docs/advanced_features/expert_parallelism.md b/docs/advanced_features/expert_parallelism.md index 5dc8f26043b5..fdde94f8caf7 100644 --- a/docs/advanced_features/expert_parallelism.md +++ b/docs/advanced_features/expert_parallelism.md @@ -15,6 +15,7 @@ SGLang's EP integrates diverse, highly efficient backends for different use case | **`none` (default)** | Disables all-to-all for EP. Uses All-Reduce or All-Gather for token dispatch. | Hybrid EP and TP setups. | | `deepep` | DeepEP, a communication library for efficient token shuffling in MoE models. | Large-scale EP deployments. | | `mooncake` | An extension of DeepEP for elastic inference, leveraging RDMA for high-performance data transfers. | Elastic EP serving. | +| `flashinfer` | Flashinfer implementation of all-to-all. | Large-scale EP deployments. | | `ascend_fuseep` | Ascend NPU native fused all-to-all communication. | Ascend NPU deployments. | DeepEP and Mooncake backends support two modes for token dispatch: `normal` mode (optimized for prefill workloads with high throughput) and `low_latency` mode (optimized for decode workloads with low latency and CUDA Graph compatibility). Users are recommended to set `--deepep-mode auto` to enable automatic dispatch mode switching during runtime. Setting `--deepep-mode normal` or `--deepep-mode low_latency` is useful for debugging or development purposes. 
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index e9dc3b6fac95..52d864034921 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -61,6 +61,7 @@ SGLang supports various environment variables that can be used to configure its | --- | --- | --- | | `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` | | `SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each GPU | `"128"` | +| `SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK` | The maximum number of dispatched tokens on each GPU for --moe-a2a-backend=flashinfer | `"1024"` | | `SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS` | Number of SMs used for DeepEP combine when single batch overlap is enabled | `"32"` | | `SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO` | Run shared experts on an alternate stream when single batch overlap is enabled on GB200. When not setting this flag, shared experts and down gemm will be overlapped with DeepEP combine together. 
| `"false"` | diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b5f90f2559d2..139fd2fc283a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -37,6 +37,7 @@ ) from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher +from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardDispatcher, StandardDispatchOutput, @@ -117,6 +118,14 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: hidden_size=moe_runner_config.hidden_size, params_dtype=moe_runner_config.params_dtype, ) + elif a2a_backend.is_flashinfer(): + return FlashinferDispatcher( + group=get_tp_group().device_group, + router_topk=moe_runner_config.top_k, + num_experts=moe_runner_config.num_experts, + num_local_experts=moe_runner_config.num_local_experts, + hidden_size=moe_runner_config.hidden_size, + ) else: raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}") diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 1eac345ce896..cff7bde1d4ce 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -16,6 +16,10 @@ DeepEPNormalCombineInput, DeepEPNormalDispatchOutput, ) +from sglang.srt.layers.moe.token_dispatcher.flashinfer import ( + FlashinferDispatcher, + FlashinferDispatchOutput, +) from sglang.srt.layers.moe.token_dispatcher.fuseep import NpuFuseEPDispatcher from sglang.srt.layers.moe.token_dispatcher.mooncake import ( MooncakeCombineInput, @@ -37,6 +41,8 @@ "DispatchOutput", "DispatchOutputFormat", "DispatchOutputChecker", + "FlashinferDispatchOutput", + 
"FlashinferDispatcher", "MooncakeCombineInput", "MooncakeDispatchOutput", "MooncakeEPDispatcher", diff --git a/python/sglang/srt/layers/moe/token_dispatcher/base.py b/python/sglang/srt/layers/moe/token_dispatcher/base.py index 06e2e2e5d70f..8134a4dea7c1 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/base.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/base.py @@ -25,6 +25,8 @@ DeepEPLLDispatchOutput, DeepEPNormalCombineInput, DeepEPNormalDispatchOutput, + FlashinferCombineInput, + FlashinferDispatchOutput, StandardCombineInput, StandardDispatchOutput, ) @@ -149,12 +151,19 @@ def format_is_deepep( ) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]: return dispatch_output.format.is_deepep() + @staticmethod + def format_is_flashinfer( + dispatch_output: DispatchOutput, + ) -> TypeGuard[FlashinferDispatchOutput]: + return dispatch_output.format.is_flashinfer() + class DispatchOutputFormat(Enum): STANDARD = "standard" DEEPEP_NORMAL = "deepep_normal" DEEPEP_LL = "deepep_ll" + FLASHINFER = "flashinfer" def is_standard(self) -> bool: return self == DispatchOutputFormat.STANDARD @@ -171,6 +180,9 @@ def is_deepep(self) -> bool: DispatchOutputFormat.DEEPEP_LL, ] + def is_flashinfer(self) -> bool: + return self == DispatchOutputFormat.FLASHINFER + @runtime_checkable class DispatchOutput(Protocol): @@ -213,11 +225,18 @@ def format_is_deepep( CombineInputFormat.DEEPEP_LL, ] + @staticmethod + def format_is_flashinfer( + combine_input: CombineInput, + ) -> TypeGuard[FlashinferCombineInput]: + return combine_input.format == CombineInputFormat.FLASHINFER + class CombineInputFormat(Enum): STANDARD = "standard" DEEPEP_NORMAL = "deepep_normal" DEEPEP_LL = "deepep_ll" + FLASHINFER = "flashinfer" @runtime_checkable diff --git a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py new file mode 100644 index 000000000000..72d5b2ea3754 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import logging +from typing import NamedTuple, Optional + +import torch + +from sglang.srt.environ import envs +from sglang.srt.layers.dp_attention import get_dp_global_num_tokens +from sglang.srt.layers.moe.token_dispatcher import ( + BaseDispatcher, + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, +) +from sglang.srt.layers.moe.token_dispatcher.flashinfer_utils import ( + TorchDistributedCommBackend, +) +from sglang.srt.layers.moe.topk import StandardTopKOutput, TopKOutput +from sglang.srt.layers.moe.utils import get_moe_runner_backend +from sglang.srt.server_args import get_global_server_args +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import get_int_env_var + +try: + from flashinfer import fp4_quantize, nvfp4_block_scale_interleave + from flashinfer.comm import MoeAlltoAll, moe_a2a_get_workspace_size_per_rank + from flashinfer.comm.mapping import Mapping + from flashinfer.comm.mnnvl import MnnvlConfig + + use_flashinfer = True +except ImportError: + use_flashinfer = False + +logger = logging.getLogger(__name__) + +MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get() + + +class FlashinferDispatchOutput(NamedTuple): + """Flashinfer EP dispatch output.""" + + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_output: StandardTopKOutput + # Provide an output tensor to fused_moe so it writes directly to our buffer + moe_output: Optional[torch.Tensor] = None + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.FLASHINFER + + +assert isinstance(FlashinferDispatchOutput, DispatchOutput) + + +class FlashinferCombineInput(NamedTuple): + """Flashinfer combine input.""" + + hidden_states: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.FLASHINFER + + +assert 
isinstance(FlashinferCombineInput, CombineInput) + + +class FlashinferDispatcher(BaseDispatcher): + """Main dispatcher class for Flashinfer A2A backend.""" + + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + num_experts: int = None, + num_local_experts: int = None, # Unused + hidden_size: int = None, + params_dtype: torch.dtype = None, # Unused + ): + super().__init__() + if not use_flashinfer: + raise ImportError( + "Flashinfer is not installed or does not support A2A. " + "Please install the appropriate version of Flashinfer." + ) + + self.ep_size = group.size() + self.ep_rank = group.rank() + self.router_topk = router_topk + self.hidden_size = hidden_size + self.num_experts = num_experts + self.num_local_experts = num_local_experts + + # TODO: Can other moe runners use payload_in_workspace too? + self.payload_in_workspace = get_moe_runner_backend().is_flashinfer_cutlass() + + # TODO: Can this be a server arg and shared with deepep/mooncakeep? + self.max_num_tokens = ( + get_int_env_var("SGLANG_FLASHINFER_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 1024) + * self.ep_size + ) + + # Calculate workspace size. For eagle mode, use the larger workspace size since nextn layer will be unquantized. 
+ speculative_algo = SpeculativeAlgorithm.from_string( + get_global_server_args().speculative_algorithm + ) + if MOE_NVFP4_DISPATCH and not speculative_algo.is_eagle(): + total_dispatch_payload_size_per_token = ( + hidden_size // 2 # nvfp4 hidden states + + hidden_size // 16 # fp8 scaling factors + + self.router_topk * 4 # int32 topk ids + + self.router_topk * 4 # float32 topk weights + ) + else: + total_dispatch_payload_size_per_token = ( + hidden_size * 2 # bf16 hidden states + + self.router_topk * 4 # int32 topk ids + + self.router_topk * 4 # float32 topk weights + ) + combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states + self.workspace_size = moe_a2a_get_workspace_size_per_rank( + ep_size=self.ep_size, + max_num_tokens=self.max_num_tokens, + total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token, + combine_payload_size_per_token=combine_payload_size_per_token, + ) + + self.mapping = Mapping( + rank=self.ep_rank, + tp_size=self.ep_size, + moe_ep_size=self.ep_size, + world_size=self.ep_size, + gpus_per_node=torch.cuda.device_count(), + pp_size=1, + cp_size=1, + ) + self.moe_a2a = MoeAlltoAll( + mapping=self.mapping, + max_num_tokens=self.max_num_tokens, + top_k=self.router_topk, + num_experts=self.num_experts, + workspace_size_per_rank=self.workspace_size, + mnnvl_config=MnnvlConfig(comm_backend=TorchDistributedCommBackend(group)), + ) + + # Preallocate dummy tensors (to overcome numLocalTokens > 0 restriction) + self.dummy_x = torch.empty( + (1, hidden_size), + dtype=torch.bfloat16, + device="cuda", + ) + # -1 will be ignored by flashinfer cutlass moe + self.dummy_topk_ids = torch.full( + (1, self.router_topk), -1, dtype=torch.int32, device="cuda" + ) + # Hack for dispatch with dummy token - will route the dummy token to this rank so it doesn't require any transfer.
+ self.dummy_topk_ids_current_rank = torch.full( + (1, self.router_topk), + self.ep_rank * self.num_local_experts, + dtype=torch.int32, + device="cuda", + ) + self.dummy_topk_weights = torch.zeros( + (1, self.router_topk), dtype=torch.float32, device="cuda" + ) + + def dispatch( + self, hidden_states: torch.Tensor, topk_output: TopKOutput + ) -> FlashinferDispatchOutput: + output_dtype = hidden_states.dtype + x = hidden_states + x_sf = None + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + + # Handle case where there are no tokens on this DP worker + # moe_a2a.dispatch requires at least one token + self.has_dummy_token = False + if x.shape[0] == 0: + logger.warning("No tokens on this DP worker, using dummy token") + self.has_dummy_token = True + x = self.dummy_x + topk_ids = self.dummy_topk_ids + topk_weights = self.dummy_topk_weights + + global_scale = self.quant_config.get("input_global_scale", None) + if global_scale is not None: + if x.shape[0] > 0: + x, x_sf = fp4_quantize(x, global_scale, is_sf_swizzled_layout=False) + else: + x = torch.zeros( + 0, self.hidden_size // 2, dtype=torch.uint8, device=x.device + ) + x_sf = torch.zeros( + 0, self.hidden_size // 16, dtype=torch.uint8, device=x.device + ) + + payloads = [] + payloads.append(x) + if x_sf is not None: + payloads.append(x_sf) + expert_id_payload_index = 2 + else: + expert_id_payload_index = 1 + payloads.append(topk_ids) + payloads.append(topk_weights) + + self.runtime_max_tokens_per_rank = ( + max(get_dp_global_num_tokens()) + if get_dp_global_num_tokens() is not None + else x.shape[0] + ) + recv_tensors = self.moe_a2a.dispatch( + self.dummy_topk_ids_current_rank if self.has_dummy_token else topk_ids, + payloads, + self.runtime_max_tokens_per_rank, + expert_id_payload_index=expert_id_payload_index, + ) + if x_sf is not None: + x_recv, x_sf_recv, topk_ids_recv, topk_weights_recv = recv_tensors + x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) + # TODO: fuse interleave into 
cutlass moe + x_sf = nvfp4_block_scale_interleave(x_sf) + else: + x_recv, topk_ids_recv, topk_weights_recv = recv_tensors + x = x_recv.view(-1, x_recv.shape[-1]) + topk_ids = topk_ids_recv.view(-1, topk_ids_recv.shape[-1]) + topk_weights = topk_weights_recv.view(-1, topk_weights_recv.shape[-1]) + + # Provide an output tensor to fused_moe so it writes directly to our buffer + moe_output = None + if self.payload_in_workspace: + moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace( + self.runtime_max_tokens_per_rank, self.hidden_size, output_dtype + ).view(-1, self.hidden_size) + return FlashinferDispatchOutput( + x, + x_sf, + StandardTopKOutput(topk_weights, topk_ids, topk_output.router_logits), + moe_output, + ) + + def combine(self, combine_input: FlashinferCombineInput) -> torch.Tensor: + hidden_states = combine_input.hidden_states + output_hidden_size = hidden_states.shape[-1] + hidden_states = self.moe_a2a.combine( + hidden_states.view( + self.ep_size, self.runtime_max_tokens_per_rank, output_hidden_size + ), + self.runtime_max_tokens_per_rank, + payload_in_workspace=self.payload_in_workspace, + ) + + # Remove dummy token if it was added in dispatch + if self.has_dummy_token: + hidden_states = hidden_states[1:, :] + + del self.runtime_max_tokens_per_rank + del self.has_dummy_token + return hidden_states diff --git a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer_utils.py b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer_utils.py new file mode 100644 index 000000000000..7ba3071413b8 --- /dev/null +++ b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer_utils.py @@ -0,0 +1,47 @@ +import torch.distributed as dist + +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + from flashinfer.comm.mnnvl import CommBackend +else: + + class CommBackend: + """ + Placeholder base class when flashinfer is not available + """ + + pass + + +class TorchDistributedCommBackend(CommBackend): + """ + Use torch 
distributed instead of MPI to set up flashinfer MNNVL workspaces during initialization + """ + + def __init__(self, group: dist.ProcessGroup): + self._group = group + + def Get_rank(self) -> int: + return self._group.rank() + + def Get_size(self) -> int: + return self._group.size() + + def allgather(self, data: int): + gathered = [None] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + + def bcast(self, data, root: int = 0): + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] + + def Split(self, color: int, key: int): + # No need to split, we already use the proper group + return self + + def barrier(self): + dist.barrier(group=self._group) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 2c959c799389..0a127009885a 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -125,6 +125,7 @@ def dispatch( topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv( [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens() ) + # TODO: fuse into cutlass moe x_sf = nvfp4_block_scale_interleave(x_sf) hidden_states = x diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 41c7cfdae11b..ba6ca01ff140 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -24,6 +24,7 @@ class MoeA2ABackend(Enum): DEEPEP = "deepep" MOONCAKE = "mooncake" ASCEND_FUSEEP = "ascend_fuseep" + FLASHINFER = "flashinfer" @classmethod def _missing_(cls, value): @@ -43,6 +44,9 @@ def is_deepep(self): def is_mooncake(self): return self == MoeA2ABackend.MOONCAKE + def is_flashinfer(self): + return self == MoeA2ABackend.FLASHINFER + def is_ascend_fuseep(self): return self == 
MoeA2ABackend.ASCEND_FUSEEP @@ -266,6 +270,7 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather(): """ return ( not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER + and get_moe_a2a_backend().is_none() and get_moe_runner_backend().is_flashinfer_cutlass() and is_dp_attention_enabled() and MOE_QUANTIZATION == "modelopt_fp4" diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 828a4343e770..f7fc3d026081 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -18,6 +18,7 @@ MoeRunner, MoeRunnerBackend, MoeRunnerConfig, + get_moe_a2a_backend, get_moe_runner_backend, ) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType @@ -1479,6 +1480,7 @@ def _slice_scale(w): (1 / w2_input_scale).to(torch.float32), requires_grad=False ) + # TODO: for flashinfer always do MOE_NVFP4_DISPATCH layer.dispatcher.set_quant_config( { "input_global_scale": ( @@ -1661,6 +1663,8 @@ def apply( return StandardCombineInput(hidden_states=layer.forward(x, topk_output)) if self.enable_flashinfer_cutlass_moe: + from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker + assert ( not moe_runner_config.apply_router_weight_on_input ), "apply_router_weight_on_input is not supported for Flashinfer" @@ -1670,20 +1674,23 @@ def apply( output_dtype = torch.bfloat16 - # If x_sf is not None, x is FP4 packed (half size), so we need * 2 - # If x_sf is None, x is not packed, so output_col = x.shape[1] - output_col = x.shape[1] - if x_sf is not None and layer.moe_runner_config.is_gated: - output_col *= 2 - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - symm_output = torch.empty( - x.shape[0], - output_col, - dtype=output_dtype, - device=x.device, - ) + if DispatchOutputChecker.format_is_flashinfer(dispatch_output): + symm_output = dispatch_output.moe_output + else: + # 
If x_sf is not None, x is FP4 packed (half size), so we need * 2 + # If x_sf is None, x is not packed, so output_col = x.shape[1] + output_col = x.shape[1] + if x_sf is not None and layer.moe_runner_config.is_gated: + output_col *= 2 + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty( + x.shape[0], + output_col, + dtype=output_dtype, + device=x.device, + ) output = flashinfer_cutlass_fused_moe( output=symm_output, @@ -1694,6 +1701,7 @@ def apply( fc2_expert_weights=layer.w2_weight.view(torch.long), output_dtype=output_dtype, input_sf=x_sf, + # swizzled_input_sf=not get_moe_a2a_backend().is_flashinfer(), quant_scales=[ layer.w13_input_scale_quant, layer.w13_blockscale_swizzled.view(torch.int32), @@ -1708,6 +1716,7 @@ def apply( tp_rank=layer.moe_tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), activation_type=ACT_STR_TO_TYPE_MAP[activation], + enable_alltoall=get_moe_a2a_backend().is_flashinfer(), )[0] from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 365fd6690b49..3e5576cc3b15 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -467,6 +467,7 @@ def __init__( if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() or get_moe_a2a_backend().is_ascend_fuseep() + or get_moe_a2a_backend().is_flashinfer() or should_use_flashinfer_cutlass_moe_fp4_allgather() else {} ), @@ -529,6 +530,7 @@ def __init__( get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() or get_moe_a2a_backend().is_ascend_fuseep() + or get_moe_a2a_backend().is_flashinfer() ) self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo() diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 784f9c30dd70..f48056b07e5d 100644 --- 
a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -413,6 +413,7 @@ def __init__( dict(tp_rank=0, tp_size=1) if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_flashinfer() or should_use_flashinfer_cutlass_moe_fp4_allgather() else {} ), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index aea0c65bfec3..3bc14983a362 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -181,7 +181,7 @@ "cutlass", ] -MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"] +MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep", "flashinfer"] FP8_GEMM_RUNNER_BACKEND_CHOICES = [ "auto", @@ -469,7 +469,9 @@ class ServerArgs: # Expert parallelism ep_size: int = 1 - moe_a2a_backend: Literal["none", "deepep", "mooncake", "ascend_fuseep"] = "none" + moe_a2a_backend: Literal[ + "none", "deepep", "mooncake", "ascend_fuseep", "flashinfer" + ] = "none" moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False @@ -2030,6 +2032,25 @@ def _handle_a2a_moe(self): logger.warning( f"Ascend fused EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) + if self.moe_a2a_backend == "flashinfer": + self.ep_size = self.tp_size + logger.warning( + f"Flashinfer MoE A2A is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + self.disable_shared_experts_fusion = True + logger.warning( + "Flashinfer MoE A2A is enabled. --disable-shared-experts-fusion is automatically set." 
+ ) + if self.deepep_mode != "auto": + logger.warning("--deepep-mode is ignored for Flashinfer MoE A2A") + if os.environ.get("SGLANG_MOE_NVFP4_DISPATCH") is None: + envs.SGLANG_MOE_NVFP4_DISPATCH.set(True) + logger.warning( + "SGLANG_MOE_NVFP4_DISPATCH is set to True for Flashinfer MoE A2A" + ) + assert self.moe_runner_backend in [ + "flashinfer_cutlass" + ], "Flashinfer MoE A2A is only supported with flashinfer_cutlass moe runner backend" def _handle_eplb_and_dispatch(self): if self.enable_eplb and (self.expert_distribution_recorder_mode is None): diff --git a/python/sglang/test/test_flashinfer_dispatcher.py b/python/sglang/test/test_flashinfer_dispatcher.py new file mode 100644 index 000000000000..cbb6bccdfa89 --- /dev/null +++ b/python/sglang/test/test_flashinfer_dispatcher.py @@ -0,0 +1,322 @@ +import unittest + +import torch + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.parallel_state import ( + get_tp_group, + initialize_model_parallel, +) +from sglang.srt.layers.dp_attention import set_dp_buffer_len +from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher +from sglang.srt.layers.moe.utils import initialize_moe_config +from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.test.test_utils import CustomTestCase + + +class TestFlashinferDispatcher(CustomTestCase): + + @classmethod + def setUpClass(cls): + server_args = ServerArgs(model_path="dummy") + server_args.moe_runner_backend = "flashinfer_cutlass" + server_args.moe_a2a_backend = "flashinfer" + set_global_server_args_for_scheduler(server_args) + initialize_moe_config(server_args) + + init_distributed_environment( + world_size=-1, # Auto-detect from environment + rank=-1, # Auto-detect from environment + local_rank=-1, # Auto-detect from environment + backend="nccl", + ) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + device = 
torch.device(f"cuda:{rank % torch.cuda.device_count()}") + torch.cuda.set_device(device) + initialize_model_parallel( + tensor_model_parallel_size=world_size, expert_model_parallel_size=world_size + ) + + @classmethod + def tearDownClass(cls): + # Clean up distributed environment + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + def create_dispatcher( + self, router_topk=2, num_experts=8, num_local_experts=4, hidden_size=128 + ): + """Helper to create dispatcher instance""" + return FlashinferDispatcher( + group=get_tp_group().device_group, + router_topk=router_topk, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=torch.bfloat16, + ) + + def test_dispatch_basic(self): + """Test basic dispatch functionality""" + num_tokens = 16 + hidden_size = 128 + router_topk = 1 # Single expert per token for simplicity + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + num_experts = world_size + num_local_experts = 1 # One expert per rank + + set_dp_buffer_len( + global_dp_buffer_len=num_tokens * world_size, + local_dp_buffer_len=num_tokens, + dp_max_padding=True, + global_num_tokens=None, + ) + + # Create tokens with rank number + hidden_states = torch.full( + (num_tokens, hidden_size), 100.0 + rank, dtype=torch.bfloat16, device="cuda" + ) + + # Route all tokens from rank i to expert (i+1) % world_size + target_rank = (rank + 1) % world_size + target_expert = target_rank # Since we have 1 expert per rank + + topk_ids = torch.full( + (num_tokens, router_topk), target_expert, dtype=torch.int32, device="cuda" + ) + topk_weights = torch.ones( + (num_tokens, router_topk), dtype=torch.float32, device="cuda" + ) + + from sglang.srt.layers.moe.topk import StandardTopKOutput + + topk_output = StandardTopKOutput( + topk_weights=topk_weights, topk_ids=topk_ids, router_logits=None + ) + + torch.distributed.barrier() + dispatcher = self.create_dispatcher( 
+ router_topk=router_topk, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + ) + dispatcher.set_quant_config({"input_global_scale": None}) + + dispatch_output = dispatcher.dispatch(hidden_states, topk_output) + received_hidden_states = dispatch_output.hidden_states + self.assertEqual(dispatch_output.hidden_states_scale, None) + + # Expected: we should receive tokens from rank (rank - 1) % world_size + expected_source_rank = (rank - 1 + world_size) % world_size + + # Verify we received the right number of tokens + self.assertEqual( + received_hidden_states.shape[0], + num_tokens * world_size, + f"Should receive {num_tokens * world_size} tokens", + ) + + # Verify tokens came from the expected source + self.assertTrue( + torch.all( + received_hidden_states[ + expected_source_rank + * num_tokens : (expected_source_rank + 1) + * num_tokens + ] + == 100.0 + expected_source_rank + ) + ) + self.assertTrue( + torch.all( + received_hidden_states[: expected_source_rank * num_tokens] == 0.0 + ) + ) + self.assertTrue( + torch.all( + received_hidden_states[(expected_source_rank + 1) * num_tokens :] == 0.0 + ) + ) + + def test_dispatch_with_empty_tokens(self): + """Test dispatch when there are no tokens (edge case)""" + # This tests the dummy token handling + num_tokens = 16 + hidden_size = 1 + router_topk = 1 # Single expert per token for simplicity + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + num_experts = world_size + num_local_experts = 1 # One expert per rank + + set_dp_buffer_len( + global_dp_buffer_len=num_tokens * world_size, + local_dp_buffer_len=num_tokens, + dp_max_padding=False, + global_num_tokens=[16, 0, 16, 16], + ) + + # Route all tokens from rank i to expert (i+1) % world_size + target_rank = (rank + 1) % world_size + target_expert = target_rank # Since we have 1 expert per rank + + # Create tokens with rank number, rank 1 has no tokens + if rank == 1: + hidden_states = 
torch.empty( + 0, hidden_size, dtype=torch.bfloat16, device="cuda" + ) + topk_ids = torch.empty(0, router_topk, dtype=torch.int32, device="cuda") + topk_weights = torch.empty( + 0, router_topk, dtype=torch.float32, device="cuda" + ) + else: + hidden_states = torch.full( + (num_tokens, hidden_size), + 100.0 + rank, + dtype=torch.bfloat16, + device="cuda", + ) + topk_ids = torch.full( + (num_tokens, router_topk), + target_expert, + dtype=torch.int32, + device="cuda", + ) + topk_weights = torch.ones( + (num_tokens, router_topk), dtype=torch.float32, device="cuda" + ) + + from sglang.srt.layers.moe.topk import StandardTopKOutput + + topk_output = StandardTopKOutput( + topk_weights=topk_weights, topk_ids=topk_ids, router_logits=None + ) + + dispatcher = self.create_dispatcher( + router_topk=router_topk, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + ) + dispatcher.set_quant_config({"input_global_scale": None}) + + dispatch_output = dispatcher.dispatch(hidden_states, topk_output) + received_hidden_states = dispatch_output.hidden_states + + # Expected: we should receive tokens from rank (rank - 1) % world_size + expected_source_rank = (rank - 1 + world_size) % world_size + + # Verify we received the right number of tokens + self.assertEqual( + received_hidden_states.shape[0], + num_tokens * world_size, + f"Should receive {num_tokens * world_size} tokens", + ) + + # Verify tokens came from the expected source + if rank == 2: + # Rank 2 should receive no tokens since rank 1 was empty + self.assertTrue( + torch.all(received_hidden_states == 0.0), + "Rank should receive no tokens", + ) + else: + self.assertTrue( + torch.all( + received_hidden_states[ + expected_source_rank + * num_tokens : (expected_source_rank + 1) + * num_tokens + ] + == 100.0 + expected_source_rank + ), + "Rank {rank} should receive tokens from the expected source {expected_source_rank}", + ) + self.assertTrue( + torch.all( + received_hidden_states[: 
expected_source_rank * num_tokens] == 0.0 + ), + "Rank should receive no tokens from previous ranks", + ) + self.assertTrue( + torch.all( + received_hidden_states[(expected_source_rank + 1) * num_tokens :] + == 0.0 + ), + "Rank should receive no tokens from next ranks", + ) + + def test_dispatch_with_fp4_quantization(self): + """Test dispatch with FP4 quantization enabled""" + num_tokens = 128 + hidden_size = 128 + router_topk = 1 # Single expert per token for simplicity + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + num_experts = world_size + num_local_experts = 1 # One expert per rank + + set_dp_buffer_len( + global_dp_buffer_len=num_tokens * world_size, + local_dp_buffer_len=num_tokens, + dp_max_padding=True, + global_num_tokens=None, + ) + + # Create tokens with random values + hidden_states = torch.randn( + (num_tokens, hidden_size), dtype=torch.bfloat16, device="cuda" + ) + + # Route all tokens from rank i to expert (i+1) % world_size + target_rank = (rank + 1) % world_size + target_expert = target_rank # Since we have 1 expert per rank + + topk_ids = torch.full( + (num_tokens, router_topk), target_expert, dtype=torch.int32, device="cuda" + ) + topk_weights = torch.ones( + (num_tokens, router_topk), dtype=torch.float32, device="cuda" + ) + + from sglang.srt.layers.moe.topk import StandardTopKOutput + + topk_output = StandardTopKOutput( + topk_weights=topk_weights, topk_ids=topk_ids, router_logits=None + ) + + dispatcher = self.create_dispatcher( + router_topk=router_topk, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + ) + # Set input global scale to enable FP4 quantization + input_global_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") + dispatcher.set_quant_config({"input_global_scale": input_global_scale}) + + dispatch_output = dispatcher.dispatch(hidden_states, topk_output) + + self.assertEqual( + dispatch_output.hidden_states.shape, + (num_tokens * 
world_size, hidden_size // 2), + ) + self.assertEqual(dispatch_output.hidden_states.dtype, torch.uint8) + + self.assertNotEqual(dispatch_output.hidden_states_scale, None) + self.assertEqual( + dispatch_output.hidden_states_scale.numel(), + num_tokens * world_size * (hidden_size // 16), + ) + self.assertEqual(dispatch_output.hidden_states_scale.dtype, torch.uint8) + + +if __name__ == "__main__": + """ + Usage + torchrun --nproc_per_node=4 test_flashinfer_dispatcher.py + """ + unittest.main()