Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,9 @@ def __init__(
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts

enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False

self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
self.enable_flashinfer_cutlass_moe = (
get_moe_runner_backend().is_flashinfer_cutlass()
)
self.moe_ep_size = get_moe_expert_parallel_world_size()
self.moe_ep_rank = get_moe_expert_parallel_rank()
self.moe_tp_size = get_moe_tensor_parallel_world_size()
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def is_auto(self) -> bool:
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
MOE_QUANTIZATION: Optional[str] = None


def initialize_moe_config(server_args: ServerArgs):
Expand All @@ -136,6 +137,7 @@ def initialize_moe_config(server_args: ServerArgs):
global IS_SBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
global MOE_QUANTIZATION

MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
Expand All @@ -152,6 +154,7 @@ def initialize_moe_config(server_args: ServerArgs):
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
)
MOE_QUANTIZATION = server_args.quantization


def get_moe_a2a_backend() -> MoeA2ABackend:
Expand Down Expand Up @@ -231,6 +234,7 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather():
not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
and get_moe_runner_backend().is_flashinfer_cutlass()
and is_dp_attention_enabled()
and MOE_QUANTIZATION == "modelopt_fp4"
and get_moe_expert_parallel_world_size() == get_attention_dp_size()
)

Expand Down
35 changes: 34 additions & 1 deletion python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

from sglang.srt.custom_op import CustomOp
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
MoeRunnerConfig,
get_moe_runner_backend,
)
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
Expand All @@ -20,6 +25,7 @@
get_bool_env_var,
is_cpu,
is_hip,
next_power_of_2,
set_weight_attrs,
use_intel_amx_backend,
)
Expand All @@ -41,6 +47,11 @@
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight

# FlashInfer is an optional dependency: import its CUTLASS fused-MoE kernel if
# present, otherwise bind the name to None so call sites can feature-gate on it.
try:
    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
    flashinfer_cutlass_fused_moe = None


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
Expand Down Expand Up @@ -137,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

def __init__(self, use_triton_kernels: bool = False):
    """Initialize the unquantized fused-MoE method.

    Args:
        use_triton_kernels: route the forward pass through Triton kernels.
    """
    super().__init__()
    # Resolve the configured MoE runner backend once at construction time.
    self.use_flashinfer_cutlass = get_moe_runner_backend().is_flashinfer_cutlass()
    self.use_triton_kernels = use_triton_kernels
    # Bias weights are disabled by default.
    self.with_bias = False

Expand Down Expand Up @@ -228,6 +240,11 @@ def create_moe_runner(
)
self.runner = MoeRunner(backend, moe_runner_config)

@property
def load_up_proj_weight_first(self) -> bool:
    """Whether the up projection should be loaded ahead of the gate in W13.

    The FlashInfer CUTLASS kernel assumes an [Up, Gate] layout for the fused
    W13 weight, so this returns True exactly when that backend is in use.
    """
    # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
    return self.use_flashinfer_cutlass

def apply(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -263,6 +280,22 @@ def forward_cuda(
w2_bias=getattr(layer, "w2_weight_bias", None),
)
return self.runner.run(dispatch_output, quant_info)
elif self.use_flashinfer_cutlass:
output = flashinfer_cutlass_fused_moe(
input=x,
token_selected_experts=topk_output.topk_ids,
token_final_scales=topk_output.topk_weights,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
output_dtype=x.dtype,
quant_scales=None,
ep_size=layer.moe_ep_size,
ep_rank=layer.moe_ep_rank,
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0]
return StandardCombineInput(hidden_states=output)
else:
if _use_aiter:
assert not moe_runner_config.no_combine, "unsupported"
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,7 @@ def _handle_model_specific_adjustments(self):
)
self.disable_overlap_schedule = True
if is_sm100_supported():
self.attention_backend = "triton"
quantization_config = getattr(hf_config, "quantization_config", None)
quant_method = (
quantization_config.get("quant_method")
Expand Down Expand Up @@ -1468,8 +1469,8 @@ def _handle_data_parallelism(self):
def _handle_moe_kernel_config(self):
if self.moe_runner_backend == "flashinfer_cutlass":
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer Cutlass MOE"
self.quantization == "modelopt_fp4" or self.quantization is None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, None here is not equivalent to bf16 quant. Check this:

if self.quantization is None:
self.quantization = quant_method

Copy link
Contributor

@samuellees samuellees Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I want to ensure that both the activation and weight types of the MoE are fp16/bf16 — do you have any suggestions here? @ch-wan

), "modelopt_fp4 quantization or bf16 is required for Flashinfer Cutlass MOE"
assert self.ep_size in [
1,
self.tp_size,
Expand Down
118 changes: 118 additions & 0 deletions python/sglang/test/test_cutlass_w16a16_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

# (m, n, k) = (num_tokens, intermediate_size, hidden_size) problem shapes
# swept by the parametrized test below.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


# Reference implementation of torch_moe for unquantized weights
def torch_moe_reference(a, w13, w2, score, topk):
    """Dense torch reference for an unquantized MoE forward pass.

    Args:
        a: activations, [num_tokens, hidden].
        w13: fused projection in FlashInfer's [up, gate] order,
            [num_experts, 2 * intermediate, hidden].
        w2: down projection, [num_experts, hidden, intermediate].
        score: raw router logits, [num_tokens, num_experts].
        topk: number of experts selected per token.

    Returns:
        Combined MoE output, [num_tokens, hidden].
    """
    num_tokens, hidden = a.shape

    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

    # w13 arrives as [up, gate]; SiluAndMul gates on the first half, so swap
    # the two halves into [gate, up] before the reference matmul.
    dim = -2
    size = w13.size(dim)
    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
    first_half, second_half = w13.split(size // 2, dim=dim)
    w13 = torch.cat([second_half, first_half], dim=dim).contiguous()

    # Softmax routing, then flatten the per-token top-k selection.
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    sel_weights, sel_ids = torch.topk(probs, topk)
    flat_weights = sel_weights.reshape(-1)
    flat_ids = sel_ids.reshape(-1)

    # Replicate each token once per selected expert.
    replicated = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(
        num_tokens * topk,
        w2.shape[1],
        dtype=replicated.dtype,
        device=replicated.device,
    )

    act = SiluAndMul()
    for expert in range(w13.shape[0]):
        routed = flat_ids == expert
        if routed.any():
            gated = act(replicated[routed] @ w13[expert].transpose(0, 1))
            out[routed] = gated @ w2[expert].transpose(0, 1)

    # Scale each expert output by its routing weight and sum per token.
    scales = flat_weights.view(num_tokens, -1, 1).to(out.dtype)
    return (out.view(num_tokens, -1, w2.shape[1]) * scales).sum(dim=1)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@torch.inference_mode()
def test_flashinfer_bf16_cutlass_moe(m: int, n: int, k: int, e: int, topk: int):
    """
    Test the bf16 cutlass moe API.

    Args:
        m: number of tokens
        n: intermediate size
        k: hidden size
        e: number of experts
        topk: top-k experts per token
    """
    torch.manual_seed(7)
    dtype = torch.bfloat16

    # Random unquantized activations, scaled down to keep bf16 values tame.
    hidden_states = torch.randn((m, k), device="cuda", dtype=dtype) / 10

    # Fused gate_up projection [num_experts, 2*intermediate, hidden];
    # FlashInfer CUTLASS expects the [up, gate] layout.
    w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

    # Down projection [num_experts, hidden, intermediate].
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Raw router logits.
    logits = torch.randn((m, e), device="cuda", dtype=dtype)

    # Top-k routing via the project helper.
    routing = select_experts(
        hidden_states=hidden_states,
        router_logits=logits,
        topk_config=TopKConfig(top_k=topk, renormalize=False),
    )
    topk_weights, topk_ids, _ = routing

    # Kernel under test: unquantized FlashInfer CUTLASS fused MoE.
    actual = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=topk_ids,
        token_final_scales=topk_weights,
        fc1_expert_weights=w13,
        fc2_expert_weights=w2,
        output_dtype=dtype,
        quant_scales=None,
    )[0]

    # Dense torch reference for comparison.
    expected = torch_moe_reference(hidden_states, w13, w2, logits, topk)

    torch.testing.assert_close(expected, actual, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    # Smoke-run one representative configuration when invoked directly;
    # pytest normally drives the full parameter sweep.
    test_flashinfer_bf16_cutlass_moe(m=224, n=1024, k=1024, e=8, topk=2)
Loading