diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 4cb4acf2ace6..8b7350c98df2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -160,13 +160,9 @@ def __init__( self.num_experts = num_experts self.num_fused_shared_experts = num_fused_shared_experts - enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass() - - if enable_flashinfer_cutlass_moe and quant_config is None: - logger.warning("Disable flashinfer MoE when quantization config is None.") - enable_flashinfer_cutlass_moe = False - - self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe + self.enable_flashinfer_cutlass_moe = ( + get_moe_runner_backend().is_flashinfer_cutlass() + ) self.moe_ep_size = get_moe_expert_parallel_world_size() self.moe_ep_rank = get_moe_expert_parallel_rank() self.moe_tp_size = get_moe_tensor_parallel_world_size() diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index a7e9ce26fd47..28805c070cc7 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -124,6 +124,7 @@ def is_auto(self) -> bool: TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None DEEPEP_CONFIG: Optional[str] = None DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None +MOE_QUANTIZATION: Optional[str] = None def initialize_moe_config(server_args: ServerArgs): @@ -136,6 +137,7 @@ def initialize_moe_config(server_args: ServerArgs): global IS_SBO_ENABLED global TBO_TOKEN_DISTRIBUTION_THRESHOLD global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER + global MOE_QUANTIZATION MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend) MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend) @@ -152,6 +154,7 @@ def initialize_moe_config(server_args: ServerArgs): DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = ( 
server_args.disable_flashinfer_cutlass_moe_fp4_allgather ) + MOE_QUANTIZATION = server_args.quantization def get_moe_a2a_backend() -> MoeA2ABackend: @@ -231,6 +234,7 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather(): not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER and get_moe_runner_backend().is_flashinfer_cutlass() and is_dp_attention_enabled() + and MOE_QUANTIZATION == "modelopt_fp4" and get_moe_expert_parallel_world_size() == get_attention_dp_size() ) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index c0f35905a975..67c65d5f3664 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -8,7 +8,12 @@ from sglang.srt.custom_op import CustomOp from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, + get_moe_runner_backend, +) from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -20,6 +25,7 @@ get_bool_env_var, is_cpu, is_hip, + next_power_of_2, set_weight_attrs, use_intel_amx_backend, ) @@ -41,6 +47,11 @@ from aiter.fused_moe import fused_moe from aiter.ops.shuffle import shuffle_weight +try: + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe +except ImportError: + flashinfer_cutlass_fused_moe = None + class UnquantizedEmbeddingMethod(QuantizeMethodBase): """Unquantized method for embeddings.""" @@ -137,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, use_triton_kernels: bool = False): super().__init__() + self.use_flashinfer_cutlass = get_moe_runner_backend().is_flashinfer_cutlass() self.use_triton_kernels = use_triton_kernels self.with_bias = False @@ -228,6 +240,11 @@ 
def create_moe_runner( ) self.runner = MoeRunner(backend, moe_runner_config) + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + return self.use_flashinfer_cutlass + def apply( self, layer: torch.nn.Module, @@ -263,6 +280,22 @@ def forward_cuda( w2_bias=getattr(layer, "w2_weight_bias", None), ) return self.runner.run(dispatch_output, quant_info) + elif self.use_flashinfer_cutlass: + output = flashinfer_cutlass_fused_moe( + input=x, + token_selected_experts=topk_output.topk_ids, + token_final_scales=topk_output.topk_weights, + fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + output_dtype=x.dtype, + quant_scales=None, + ep_size=layer.moe_ep_size, + ep_rank=layer.moe_ep_rank, + tp_size=layer.moe_tp_size, + tp_rank=layer.moe_tp_rank, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + )[0] + return StandardCombineInput(hidden_states=output) else: if _use_aiter: assert not moe_runner_config.no_combine, "unsupported" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 792201428cc3..80d02c248678 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1202,6 +1202,7 @@ def _handle_model_specific_adjustments(self): ) self.disable_overlap_schedule = True if is_sm100_supported(): + self.attention_backend = "triton" quantization_config = getattr(hf_config, "quantization_config", None) quant_method = ( quantization_config.get("quant_method") @@ -1468,8 +1469,8 @@ def _handle_data_parallelism(self): def _handle_moe_kernel_config(self): if self.moe_runner_backend == "flashinfer_cutlass": assert ( - self.quantization == "modelopt_fp4" - ), "modelopt_fp4 quantization is required for Flashinfer Cutlass MOE" + self.quantization == "modelopt_fp4" or self.quantization is None + ), "modelopt_fp4 quantization or bf16 is required for Flashinfer Cutlass MOE" assert self.ep_size in [ 1, self.tp_size, diff --git 
# SPDX-License-Identifier: Apache-2.0
"""Accuracy test for the unquantized (bf16) FlashInfer CUTLASS fused-MoE kernel.

Contents of ``python/sglang/test/test_cutlass_w16a16_moe.py``: the kernel
output is compared against a straightforward PyTorch reference.
"""
import pytest
import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

# Skip the module (instead of failing at collection) when FlashInfer is absent.
flashinfer_cutlass_fused_moe = pytest.importorskip(
    "flashinfer.fused_moe",
    reason="FlashInfer is required for the CUTLASS fused-MoE test",
).cutlass_fused_moe

# Every tensor below is allocated on "cuda", so the test cannot run without it.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA is required for the CUTLASS fused-MoE test",
)

# (m, n, k) = (number of tokens, intermediate size, hidden size)
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


def torch_moe_reference(a, w13, w2, score, topk):
    """Reference MoE forward pass for unquantized weights, in plain PyTorch.

    Args:
        a: activations, 2-D ``(tokens, hidden)``.
        w13: fused first-layer expert weights; the two halves along dim -2
            are swapped before use (see below).
        w2: second-layer expert weights, indexed by expert on dim 0.
        score: router logits, one row per token.
        topk: number of experts selected per token.

    Returns:
        Tensor of shape ``(tokens, w2.shape[1])``: the sum over the selected
        experts of each expert's output scaled by its softmax routing weight.
    """
    B, D = a.shape

    # NOTE(review): presumably the sglang ops used below read the global
    # server args, hence the dummy instance -- confirm.
    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

    # The kernel under test is fed w13 in the [up, gate] layout (see the test
    # body); swap the two halves so this reference consumes the other order.
    dim = -2
    size = w13.size(dim)
    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
    half = size // 2
    w1, w3 = w13.split(half, dim=dim)
    w13 = torch.cat([w3, w1], dim=dim).contiguous()

    # Replicate every token `topk` times: row b * topk + j is token b routed
    # to its j-th selected expert.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # One activation module is enough; it is reused for every expert.
    act = SiluAndMul()
    for expert_id in range(w13.shape[0]):
        mask = topk_ids == expert_id
        if mask.any():
            hidden = act(a[mask] @ w13[expert_id].transpose(0, 1))
            out[mask] = hidden @ w2[expert_id].transpose(0, 1)

    weighted = out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(
        out.dtype
    )
    return weighted.sum(dim=1)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@torch.inference_mode()
def test_flashinfer_bf16_cutlass_moe(m: int, n: int, k: int, e: int, topk: int):
    """
    Test the bf16 cutlass moe API.

    Args:
        m: number of tokens
        n: intermediate size
        k: hidden size
        e: number of experts
        topk: top-k experts per token
    """
    torch.manual_seed(7)

    dtype = torch.bfloat16

    # Create unquantized weights
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

    # w13: fused gate_up projection [num_experts, 2*intermediate, hidden]
    # FlashInfer CUTLASS expects [up, gate] layout
    w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

    # w2: down projection [num_experts, hidden, intermediate]
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Generate router scores
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    # Get topk routing
    topk_output = select_experts(
        hidden_states=a,
        router_logits=score,
        topk_config=TopKConfig(top_k=topk, renormalize=False),
    )
    topk_weights, topk_ids, _ = topk_output

    # Test: Call FlashInfer CUTLASS fused_moe (unquantized version)
    test_output = flashinfer_cutlass_fused_moe(
        input=a,
        token_selected_experts=topk_ids,
        token_final_scales=topk_weights,
        fc1_expert_weights=w13,
        fc2_expert_weights=w2,
        output_dtype=dtype,
        quant_scales=None,
    )[0]

    # Reference: Torch implementation
    torch_output = torch_moe_reference(a, w13, w2, score, topk)

    # Compare outputs
    torch.testing.assert_close(torch_output, test_output, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    # Run a simple test case
    test_flashinfer_bf16_cutlass_moe(224, 1024, 1024, 8, 2)