Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ class BypassedTopKOutput(NamedTuple):
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.BYPASSED

def to_standard(self, layer_id: Optional[int] = None) -> "StandardTopKOutput":
"""Materialize routing tensors. Used by MoE kernels that need explicit
topk_ids / topk_weights rather than doing routing internally."""
return select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
topk_config=self.topk_config,
layer_id=layer_id,
num_token_non_padded=self.num_token_non_padded,
expert_location_dispatch_info=self.expert_location_dispatch_info,
)


# -------------------------------- TopK ---------------------------------------

Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ def get_quant_method(
return Mxfp4MarlinMoEMethod(fp8_method, prefix=prefix)

if self.is_fp4_experts and get_moe_runner_backend().is_flashinfer_mxfp4():
# SM100 (Blackwell) -> trtllm-gen path.
# SM90 (Hopper) -> cutlass mixed-input path (FlashInfer #3084).
if is_sm90_supported() and not is_sm100_supported():
from sglang.srt.layers.quantization.mxfp4_flashinfer_cutlass_moe import (
Mxfp4FlashinferCutlassMoEMethod,
)

return Mxfp4FlashinferCutlassMoEMethod(fp8_method, prefix=prefix)

from sglang.srt.layers.quantization.mxfp4_flashinfer_trtllm_moe import (
Mxfp4FlashinferTrtllmMoEMethod,
)
Expand Down
270 changes: 269 additions & 1 deletion python/sglang/srt/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@

from __future__ import annotations

import os
from dataclasses import replace
from typing import TYPE_CHECKING, List, Optional

import torch
from torch.nn.parameter import Parameter

# Silence the TRT-LLM cutlass autotune trace embedded inside FlashInfer's
# cutlass_fused_moe. Its C++ logger reads TLLM_LOG_LEVEL on first kernel launch;
# setdefault preserves any explicit user override.
os.environ.setdefault("TLLM_LOG_LEVEL", "INFO")
Comment thread
yuan-luo marked this conversation as resolved.

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
Expand Down Expand Up @@ -62,7 +68,27 @@
nvfp4_block_scale_interleave,
trtllm_fp4_block_scale_moe,
)
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
from flashinfer.fused_moe.core import (
ActivationType,
get_w2_permute_indices_with_cache,
)

# SM90 mixed-input helpers landed in FlashInfer #3084 (post-0.6.10). Older
# versions don't ship them; gate at import so unrelated code paths still load.
try:
from flashinfer.fused_moe import (
interleave_moe_scales_for_sm90_mixed_gemm,
interleave_moe_weights_for_sm90_mixed_gemm,
)

_FI_HAS_SM90_CUTLASS_MXFP4 = True
except ImportError:
interleave_moe_scales_for_sm90_mixed_gemm = None
interleave_moe_weights_for_sm90_mixed_gemm = None
_FI_HAS_SM90_CUTLASS_MXFP4 = False
else:
_FI_HAS_SM90_CUTLASS_MXFP4 = False

_flashinfer_mxfp4_permute_indices_cache: dict[torch.Size, torch.Tensor] = {}
_flashinfer_mxfp4_permute_indices_device_cache: dict[
Expand Down Expand Up @@ -318,6 +344,28 @@ def __init__(
self.flashinfer_mxfp4_moe_precision = (
get_global_server_args().flashinfer_mxfp4_moe_precision
)
# When `flashinfer_mxfp4` is enabled, dispatch to one of two FlashInfer
# entry points depending on the GPU:
# - SM100 (Blackwell) -> trtllm_fp4_block_scale_moe (existing)
# - SM90 (Hopper) -> cutlass_fused_moe(use_w4_group_scaling=True)
# (FlashInfer PR #3084, post-0.6.10)
self._fi_kernel: Optional[str] = None
if self.use_flashinfer:
if is_sm100_supported():
self._fi_kernel = "trtllm_sm100"
elif is_sm90_supported():
if not _FI_HAS_SM90_CUTLASS_MXFP4:
raise RuntimeError(
"moe_runner_backend=flashinfer_mxfp4 on SM90 requires the "
"interleave_moe_{weights,scales}_for_sm90_mixed_gemm helpers "
"from FlashInfer PR #3084 (>= 0.6.11). Upgrade flashinfer-python "
"or pick a different backend (e.g. marlin / triton_kernel)."
)
self._fi_kernel = "cutlass_sm90"
else:
raise NotImplementedError(
"moe_runner_backend=flashinfer_mxfp4 requires SM90 or SM100."
)

def create_weights(
self,
Expand Down Expand Up @@ -349,6 +397,26 @@ def create_weights(
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, triton_kernels_padding_alignment
)
elif self._fi_kernel == "cutlass_sm90":
# cutlass mixed-input GEMM contraction dim K must be % 128 == 0
# (interleave factor for MXFP4 group_size=32 is 4). The kernel
# also expects ``fc1_expert_weights`` in halved ``[up; gate]``
# layout, which means the padding boundary must fall on the
# gate / up split.
#
# The mxfp4 weight loader (FusedMoE.weight_loader fast path) does
# a NAIVE copy of HF's ``[2*intermediate_size, hidden_packed]``
# tensor into the buffer's ``[:dim1, :dim2]`` slice. Padding the
# buffer here would push the gate/up boundary, so HF's "up"
# rows would land in the buffer's "gate" half and vice versa.
# Marlin sidesteps this by not padding; we do the same and
# rebuild a properly-padded buffer in
# ``_process_weights_for_sm90_cutlass`` after the load completes.
self._padded_intermediate = round_up(intermediate_size_per_partition, 128)
self._padded_hidden = round_up(hidden_size, 128)
# create_weights below uses the *unpadded* sizes so the loader's
# naive-copy fast path is correct.
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
elif _use_aiter:

intermediate_size_per_partition_after_pad = round_up(
Expand Down Expand Up @@ -438,6 +506,9 @@ def create_weights(
set_weight_attrs(w2_weight_bias, extra_weight_attrs)

def process_weights_after_loading(self, layer):
if self._fi_kernel == "cutlass_sm90":
self._process_weights_for_sm90_cutlass(layer)
return
if self.use_flashinfer:
# TODO: these values are hardcoded for now, we need to get them from the model
layer.gemm1_alpha = Parameter(
Expand Down Expand Up @@ -736,6 +807,133 @@ def swap_every_two_rows(x, axis=-1):
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
torch.cuda.empty_cache()

def _process_weights_for_sm90_cutlass(self, layer):
"""De-interleave + pad + halving-swap + byte-interleave MXFP4 weights
for FlashInfer's SM90 ``cutlass_fused_moe(use_w4_group_scaling=True)``
path (PR #3084).

The cutlass kernel needs (a) K (contraction dim) % 128 == 0, and (b)
``fc1_expert_weights`` in halved ``[up; gate]`` order -- the
``compute_with_experts`` reference in FlashInfer's
``test_trtllm_cutlass_fused_moe.py`` splits
``w3, w1 = chunk(W, 2, dim=0)`` and uses w3 as up, w1 as gate.

GPT-OSS's HF layout is *interleaved* ``[g_0, u_0, g_1, u_1, ..., g_{N-1}, u_{N-1}]``
(each pair occupies two adjacent rows). The mxfp4 weight loader does
a naive copy, so our unpadded buffer is interleaved post-load. We
de-interleave (even rows -> gate, odd rows -> up), pad each half from
N_un to N_pad, concatenate as halved ``[up; gate]``, and then run
FlashInfer's byte / scale interleave helpers.
"""
sf_block_size = 32 # MXFP4 group size

# Sizes from the unpadded loaded buffers.
N_un = layer.w13_weight.shape[1] // 2 # intermediate (unpadded)
K_un = (
layer.w13_weight.shape[2] * 2
) # hidden (unpadded, *2 because packed 4-bit)
N_pad = self._padded_intermediate
K_pad = self._padded_hidden
# Use the local expert count (matches the existing buffer allocation in
# create_weights) so the SM90 cutlass path remains correct under
# Expert Parallelism. `self.num_experts` is the *global* count.
E = layer.num_local_experts
device = layer.w13_weight.device
bias_dtype = layer.w13_weight_bias.dtype

# ---- De-interleave + pad w13 weight/scale/bias to halved [up; gate]
# Even rows of HF = gate, odd rows = up. After splitting we pad each
# half along its row dim (N) from N_un to N_pad with zeros, and along
# its last dim (K) from K_un (or K_un / sf_block_size) to K_pad.

def _stack_up_gate_w13(unpadded_w13, last_pad, last_un):
# unpadded_w13: [E, 2*N_un, last_un]
# Returns: [E, 2*N_pad, last_pad] in [up_padded; gate_padded] order.
gate_rows = unpadded_w13[:, 0::2, :] # [E, N_un, last_un]
up_rows = unpadded_w13[:, 1::2, :] # [E, N_un, last_un]
out = torch.zeros(
E, 2 * N_pad, last_pad, dtype=unpadded_w13.dtype, device=device
)
# First half: up (with row + col padding zeros).
out[:, :N_un, :last_un] = up_rows
# Second half: gate.
out[:, N_pad : N_pad + N_un, :last_un] = gate_rows
return out

w13_padded = _stack_up_gate_w13(
layer.w13_weight.data.view(torch.uint8), K_pad // 2, K_un // 2
)
w13_scale_padded = _stack_up_gate_w13(
layer.w13_weight_scale.data,
K_pad // sf_block_size,
K_un // sf_block_size,
)
# Bias: same de-interleave on dim=-1.
w13_bias_gate = layer.w13_weight_bias.data[:, 0::2] # [E, N_un]
w13_bias_up = layer.w13_weight_bias.data[:, 1::2] # [E, N_un]
w13_bias_padded = torch.zeros(E, 2 * N_pad, dtype=bias_dtype, device=device)
w13_bias_padded[:, :N_un] = w13_bias_up
w13_bias_padded[:, N_pad : N_pad + N_un] = w13_bias_gate

def _pad_w2_3d(unpadded, last_pad, last_un):
out = torch.zeros(E, K_pad, last_pad, dtype=unpadded.dtype, device=device)
out[:, :K_un, :last_un] = unpadded[:, :K_un, :]
return out

# ---- w2 (no halving, just pad to [E, K_pad, N_pad/2]) ----------------
w2_padded = _pad_w2_3d(
layer.w2_weight.data.view(torch.uint8), N_pad // 2, N_un // 2
)
w2_scale_padded = _pad_w2_3d(
layer.w2_weight_scale.data,
N_pad // sf_block_size,
N_un // sf_block_size,
)
w2_bias_padded = torch.zeros(E, K_pad, dtype=bias_dtype, device=device)
w2_bias_padded[:, :K_un] = layer.w2_weight_bias.data

# ---- Per-expert SwiGLU scalars (GPT-OSS defaults) ------------------
layer.swiglu_alpha = Parameter(
torch.full((E,), 1.702, dtype=torch.float32, device=device),
requires_grad=False,
)
layer.swiglu_beta = Parameter(
torch.full((E,), 1.0, dtype=torch.float32, device=device),
requires_grad=False,
)
layer.swiglu_limit = Parameter(
torch.full((E,), 7.0, dtype=torch.float32, device=device),
requires_grad=False,
)
Comment thread
yuan-luo marked this conversation as resolved.

# ---- FlashInfer SM90 byte / scale interleave -----------------------
# The padded buffers above are contiguous by construction (allocated
# via torch.zeros + slice assignment), so we feed them straight in.
layer.w13_weight = Parameter(
interleave_moe_weights_for_sm90_mixed_gemm(w13_padded, "fp4"),
requires_grad=False,
)
layer.w2_weight = Parameter(
interleave_moe_weights_for_sm90_mixed_gemm(w2_padded, "fp4"),
requires_grad=False,
)
layer.w13_weight_scale = Parameter(
interleave_moe_scales_for_sm90_mixed_gemm(
w13_scale_padded, group_size=sf_block_size
),
requires_grad=False,
)
layer.w2_weight_scale = Parameter(
interleave_moe_scales_for_sm90_mixed_gemm(
w2_scale_padded, group_size=sf_block_size
),
requires_grad=False,
)
layer.w13_weight_bias = Parameter(w13_bias_padded, requires_grad=False)
layer.w2_weight_bias = Parameter(w2_bias_padded, requires_grad=False)

torch.cuda.empty_cache()

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
Expand All @@ -761,6 +959,74 @@ def create_moe_runner(
# TODO(cwan): refactor other backends
pass

def _apply_sm90_cutlass(self, layer, x, topk_output):
"""SM90 (Hopper) MXFP4 x BF16 MoE via FlashInfer's cutlass mixed-input
path (PR #3084). The fused kernel does GEMM1 + SwiGLU + GEMM2 in one
call; weights/scales were pre-interleaved at load time."""
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
from sglang.srt.layers.moe.topk import TopKOutputChecker

# Under ``--moe-runner-backend flashinfer_mxfp4`` the SGLang TopK layer
# emits BypassedTopKOutput by default (the SM100 trtllm-gen kernel does
# routing internally). The cutlass kernel needs explicit topk_ids /
# topk_weights, so materialize them here when bypassed.
if TopKOutputChecker.format_is_bypassed(topk_output):
Comment thread
yuan-luo marked this conversation as resolved.
topk_output = topk_output.to_standard()
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

# Pad input hidden dim to the (already-padded) loaded weight width.
origin_hidden = x.shape[-1]
padded_hidden = self._padded_hidden
if padded_hidden != origin_hidden:
x = torch.nn.functional.pad(
x,
(0, padded_hidden - origin_hidden),
mode="constant",
value=0.0,
)

output_dtype = torch.bfloat16
Comment thread
yuan-luo marked this conversation as resolved.
# Output is allocated at padded width (kernel writes padded_hidden
# columns), then trimmed back to origin_hidden before returning.
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
out_padded = torch.empty(
x.shape[0], padded_hidden, dtype=output_dtype, device=x.device
)

flashinfer_cutlass_fused_moe(
input=x,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
fc1_expert_weights=layer.w13_weight, # uint8 [E, 2*N, K/2] interleaved
fc2_expert_weights=layer.w2_weight, # uint8 [E, K, N/2] interleaved
output_dtype=output_dtype,
quant_scales=[
layer.w13_weight_scale.view(torch.int32),
layer.w2_weight_scale.view(torch.int32),
],
fc1_expert_biases=layer.w13_weight_bias, # bf16 [E, 2*N]
fc2_expert_biases=layer.w2_weight_bias, # bf16 [E, K]
swiglu_alpha=layer.swiglu_alpha,
Comment thread
yuan-luo marked this conversation as resolved.
swiglu_beta=layer.swiglu_beta,
swiglu_limit=layer.swiglu_limit,
tp_size=layer.moe_tp_size,
Comment thread
yuan-luo marked this conversation as resolved.
tp_rank=layer.moe_tp_rank,
ep_size=layer.moe_ep_size,
ep_rank=layer.moe_ep_rank,
use_w4_group_scaling=True,
activation_type=ActivationType.Swiglu,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
output=out_padded,
)

if padded_hidden != origin_hidden:
out = out_padded[:, :origin_hidden].contiguous()
else:
out = out_padded
return StandardCombineInput(hidden_states=out)

def apply(
self,
layer: torch.nn.Module,
Expand All @@ -773,6 +1039,8 @@ def apply(
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output

if self._fi_kernel == "cutlass_sm90":
return self._apply_sm90_cutlass(layer, x, topk_output)
if self.use_flashinfer:
# When bf16 mode is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
Expand Down
Loading
Loading