Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
79ba699
Add flashinfer CuteDSL masked grouped gemm support
wenscarl Aug 23, 2025
64f72e7
zero to empty init.
wenscarl Aug 24, 2025
4e03150
Address comment
wenscarl Aug 25, 2025
0a1d699
Upd
wenscarl Aug 27, 2025
73aa90a
Add masked_m
wenscarl Aug 27, 2025
b09d92d
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Aug 31, 2025
6b96c98
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 1, 2025
73b9605
Update python/sglang/srt/layers/quantization/modelopt_quant.py
fzyzcjy Sep 1, 2025
f7fc26d
fix error
fzyzcjy Sep 1, 2025
ec2c719
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 1, 2025
dfb3ac3
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 2, 2025
f9bb5bc
Merge branch 'sgl-project:main' into flashinfer_cutedsl_grp_gemm
wenscarl Sep 3, 2025
7773823
Skip fusing scaling factor into router weights for cutedsl backend
wenscarl Sep 4, 2025
1497ccc
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 4, 2025
a88062e
Make unittest rigorous
wenscarl Sep 4, 2025
7c7a6dc
Fix lint
wenscarl Sep 4, 2025
21ff185
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 4, 2025
5950cd6
Merge branch 'main-upstream' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 5, 2025
6bac5df
Add comment
wenscarl Sep 5, 2025
4d8812f
Enable fused scaling factor
wenscarl Sep 5, 2025
4cac99f
Address comments
wenscarl Sep 8, 2025
f387bf0
Add e2e test
wenscarl Sep 9, 2025
ee04919
Merge remote-tracking branch 'origin/main' into flashinfer_cutedsl_gr…
wenscarl Sep 9, 2025
cc2a57f
add e2e test to test_suite
wenscarl Sep 9, 2025
023e004
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 10, 2025
73a207c
Fix lint
wenscarl Sep 10, 2025
a2407b5
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 10, 2025
2cd04a1
Remove CI test temporarily
wenscarl Sep 10, 2025
54b2657
Merge remote-tracking branch 'origin/main' into flashinfer_cutedsl_gr…
wenscarl Sep 10, 2025
6276e25
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 10, 2025
c4552f1
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 11, 2025
2422f5d
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 11, 2025
cd64a56
Merge branch 'main' into flashinfer_cutedsl_grp_gemm
fzyzcjy Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ SGLang supports various environment variables that can be used to configure its
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |

## DeepEP Configuration

| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_DEEPEP_BF16_DISPATCH` | Use BF16 instead of FP8 for DeepEP dispatch | `"false"` |

## Memory Management

| Environment Variable | Description | Default Value |
Expand Down
18 changes: 18 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ def moe_impl(self, dispatch_output: DispatchOutput):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
Expand Down Expand Up @@ -638,6 +640,22 @@ def forward_deepgemm_contiguous(

return gather_out

def forward_flashinfer_cutedsl(
    self,
    dispatch_output: DeepEPLLOutput,
):
    """Run the expert computation through the quant method's CuteDSL path.

    Args:
        dispatch_output: Five-field dispatch result. Only the first field
            (the hidden states) and the fourth field (``masked_m``) are
            consumed here; the remaining three are ignored.

    Returns:
        Whatever ``self.quant_method.apply_without_routing_weights``
        returns for this layer.

    Raises:
        AssertionError: If no quant method is attached, or the configured
            activation is not ``"silu"``.
    """
    # Unpacking (rather than indexing) also enforces the five-field layout.
    hidden_states, _field1, _field2, masked_m, _field4 = dispatch_output

    assert self.quant_method is not None
    assert self.moe_runner_config.activation == "silu"

    return self.quant_method.apply_without_routing_weights(
        layer=self,
        x=hidden_states,
        masked_m=masked_m,
        moe_runner_config=self.moe_runner_config,
    )

def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
Expand Down
156 changes: 156 additions & 0 deletions python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Any, Dict, Optional

import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from sgl_kernel.gemm import (
scaled_fp4_grouped_quant,
silu_and_mul_scaled_fp4_grouped_quant,
)


def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the dtype name string corresponding to ``input.dtype``.

    Args:
        input: Tensor whose dtype is inspected; its data is not read.

    Returns:
        ``"bfloat16"``, ``"float16"`` or ``"float32"``.

    Raises:
        ValueError: If the tensor's dtype is none of the three above.
    """
    dtype_names = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    name = dtype_names.get(input.dtype)
    if name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return name


def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor,
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alpha,
w2: torch.Tensor,
a2_global_scale: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.

Args:
hidden_states (torch.Tensor): [num_experts, m, k], bf16
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): Masked dimension indices

Notes:
- Assumes max(masked_m) <= m.
"""

# === Assertions on dtypes ===
assert (
input_global_scale.dtype == torch.float32
), f"input_global_scale must be float32, got {input_global_scale.dtype}"
assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
assert (
w1_blockscale.dtype == torch.float8_e4m3fn
), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
assert (
w1_alpha.dtype == torch.float32
), f"w1_alpha must be float32, got {w1_alpha.dtype}"
assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
assert (
a2_global_scale.dtype == torch.float32
), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
assert (
w2_blockscale.dtype == torch.float8_e4m3fn
), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
assert (
w2_alpha.dtype == torch.float32
), f"w2_alpha must be float32, got {w2_alpha.dtype}"

# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
num_experts, m, k = hidden_states.shape

assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert (
w1.shape[-1] * 2 == k
), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
assert w2.shape[-2:] == (
k,
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"

assert input_global_scale.shape == (
num_experts,
), f"input_global_scale must be (l,), got {input_global_scale.shape}"
assert w1_alpha.shape == (
num_experts,
), f"w1_alpha must be (l,), got {w1_alpha.shape}"
assert a2_global_scale.shape == (
num_experts,
), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
assert w2_alpha.shape == (
num_experts,
), f"w2_alpha must be (l,), got {w2_alpha.shape}"

aq, aq_sf = scaled_fp4_grouped_quant(
hidden_states,
input_global_scale,
masked_m,
)
gateup_output = torch.empty(
(num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
)
gateup_output = gateup_output.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
assert aq.dtype == torch.uint8
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"

c_dtype = get_cute_dtype(hidden_states)

# Gemm1

grouped_gemm_nt_masked(
(aq, aq_sf),
(w1.permute(1, 2, 0), w1_blockscale),
gateup_output,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w1_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w1_alpha),
) # in logical [m, n, l]

# SILU and quantization
diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
gateup_output.permute(2, 0, 1),
a2_global_scale,
masked_m,
)

# Gemm2
out = torch.empty_like(hidden_states)
out = out.permute(1, 2, 0) # requirement of kernel
grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
return out.permute(2, 0, 1)
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ def dispatch_a(
hidden_states, masked_m, event, hook = self._dispatch_core(
hidden_states,
topk_idx,
use_fp8=True,
# TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
)
return (
hidden_states,
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"

def is_auto(self):
return self == MoeRunnerBackend.AUTO
Expand All @@ -65,6 +66,9 @@ def is_flashinfer_trtllm(self):
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS

def is_flashinfer_cutedsl(self):
    """Return True when this backend is ``FLASHINFER_CUTEDSL``."""
    return MoeRunnerBackend.FLASHINFER_CUTEDSL == self

def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4

Expand Down
42 changes: 41 additions & 1 deletion python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,13 @@ def enable_flashinfer_cutlass_moe(self) -> bool:
"""Access the global enable_flashinfer_cutlass_moe setting."""
return get_moe_runner_backend().is_flashinfer_cutlass()

@property
def enable_flashinfer_cutedsl_moe(self) -> bool:
    """Access the global enable_flashinfer_cutedsl_moe setting.

    Returns:
        True when the globally configured MoE runner backend reports
        ``is_flashinfer_cutedsl()``.
    """
    # The docstring must be the first statement of the function; when it
    # followed this import it was a discarded string expression and the
    # property had no ``__doc__``.
    from sglang.srt.layers.moe import get_moe_runner_backend

    return get_moe_runner_backend().is_flashinfer_cutedsl()

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -1398,5 +1405,38 @@ def apply(
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
# Scale by routed_scaling_factor is fused into select_experts.

return StandardCombineInput(hidden_states=output)

def apply_without_routing_weights(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    masked_m: torch.Tensor,
    moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
    """Run the FlashInfer CuteDSL masked MoE kernel for ``layer``.

    No router weights are passed to the kernel by this method.

    Args:
        layer: Module providing the weights and scales read below
            (``w13_*``, ``w2_*``, ``g1_alphas``, ``g2_alphas``).
        x: Hidden states handed to the kernel as ``hidden_states``.
        masked_m: Forwarded unchanged to the kernel.
        moe_runner_config: Must specify ``"silu"`` activation and must not
            request ``apply_router_weight_on_input``.

    Returns:
        The output of ``flashinfer_cutedsl_moe_masked``.

    Raises:
        AssertionError: If any of the configuration checks below fails.
    """
    activation = moe_runner_config.activation
    assert activation == "silu", "Only SiLU activation is supported."

    assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"

    weight_on_input = moe_runner_config.apply_router_weight_on_input
    assert (
        not weight_on_input
    ), "apply_router_weight_on_input is not supported for Flashinfer"

    # Imported only after the checks pass, as in the original ordering.
    from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
        flashinfer_cutedsl_moe_masked,
    )

    kernel_inputs = dict(
        hidden_states=x,
        input_global_scale=layer.w13_input_scale_quant,
        w1=layer.w13_weight,
        w1_blockscale=layer.w13_blockscale_swizzled,
        w1_alpha=layer.g1_alphas,
        w2=layer.w2_weight,
        a2_global_scale=layer.w2_input_scale_quant,
        w2_blockscale=layer.w2_blockscale_swizzled,
        w2_alpha=layer.g2_alphas,
        masked_m=masked_m,
    )
    return flashinfer_cutedsl_moe_masked(**kernel_inputs)
8 changes: 6 additions & 2 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,10 +673,14 @@ def forward_deepep(

if shared_output is not None:
x = shared_output
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
if self.experts.should_fuse_routed_scaling_factor_in_topk():
x.add_(final_hidden_states)
else:
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
final_hidden_states = x
else:
final_hidden_states *= self.routed_scaling_factor
if not self.experts.should_fuse_routed_scaling_factor_in_topk():
final_hidden_states *= self.routed_scaling_factor

return final_hidden_states

Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_cutedsl_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
Expand All @@ -420,6 +421,11 @@ def __post_init__(self):
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutedsl_moe:
self.moe_runner_backend = "flashinfer_cutedsl"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
Expand Down Expand Up @@ -1622,6 +1628,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
Expand Down Expand Up @@ -2204,6 +2211,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-cutedsl-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
Expand Down
Loading
Loading