Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,27 +536,22 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_ascend_fuseep():
return NpuFuseEPMoE

# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if get_moe_runner_backend().is_flashinfer_trtllm():
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
# FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
# If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
if quant_config is None:
return FusedMoE
try:
# Check the quantization argument directly
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFP4MoE,
)

return FlashInferFP4MoE
except:
pass
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

return FlashInferFP4MoE
elif (
quant_config is None
or quant_config.get_name() == "fp8"
or quant_config.get_name() == "modelopt_fp8"
):
# FlashInferFusedMoE support bf16 and fp8
return FlashInferFusedMoE

if get_moe_runner_backend().is_flashinfer_trtllm() and quant_config is not None:
# FIXME: FlashInferFusedMoE only supports fp8 quant now
return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
return FusedMoE
71 changes: 56 additions & 15 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(
self.use_presharded_weights = use_presharded_weights

self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm()

self.quant_config = quant_config
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
Expand Down Expand Up @@ -236,7 +237,9 @@ def __init__(
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method is None:
self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)
self.quant_method = UnquantizedFusedMoEMethod(
self.use_triton_kernels, self.use_flashinfer_trtllm_moe
)

self.quant_method.create_weights(
layer=self,
Expand Down Expand Up @@ -640,9 +643,10 @@ def _weight_loader_impl(
raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")

# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if get_moe_runner_backend().is_flashinfer_trtllm() and (
if self.use_flashinfer_trtllm_moe and (
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
or isinstance(self.quant_method, Fp8MoEMethod)
or isinstance(self.quant_method, UnquantizedFusedMoEMethod)
):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

Expand Down Expand Up @@ -1036,29 +1040,66 @@ def __init__(self, *args, **kwargs):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert (
self.moe_runner_config.activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
), "Only silu is supported for flashinfer trtllm moe"
assert self.quant_method is not None
assert (
topk_output.topk_config.renormalize
), "Renormalize is required for flashinfer blockscale fp8 moe"
), "Renormalize is required for flashinfer trtllm moe"
assert (
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
), "Fused shared experts are not supported for flashinfer trtllm moe"
assert (
self.moe_runner_config.is_gated
), "Only gated MoEs are supported for flashinfer blockscale fp8 moe"
), "Only gated MoEs are supported for flashinfer trtllm moe"

assert TopKOutputChecker.format_is_bypassed(topk_output)

# Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
dispatch_output=StandardDispatchOutput(
hidden_states=hidden_states,
hidden_states_scale=None,
topk_output=topk_output,
),
)
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
correction_bias = topk_config.correction_bias

if isinstance(self.quant_method, UnquantizedFusedMoEMethod):
# lazy import
try:
from flashinfer.fused_moe import trtllm_bf16_moe
except ImportError as e:
raise ImportError(
"Can't import trtllm_bf16_moe from flashinfer. "
"Please check flashinfer version to use bf16 with flashinfer_trtllm backend."
) from e

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
# TODO: Now trtllm_bf16_moe doesn't support inplace output,
# we can move this out when it support that.
final_hidden_states = trtllm_bf16_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hidden_states,
gemm1_weights=self.w13_weight,
gemm2_weights=self.w2_weight,
num_experts=self.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routing_method_type=self.routing_method_type,
)

else:

# FP8 Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
dispatch_output=StandardDispatchOutput(
hidden_states=hidden_states,
hidden_states_scale=None,
topk_output=topk_output,
),
)

# NOTE for symmetric memory tagging:
# We do not create the context in this function.
Expand Down
71 changes: 70 additions & 1 deletion python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,15 @@ def apply(
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def __init__(self, use_triton_kernels: bool = False):
def __init__(
self, use_triton_kernels: bool = False, use_flashinfer_trtllm_moe: bool = False
):
super().__init__()
self.use_flashinfer_cutlass = get_moe_runner_backend().is_flashinfer_cutlass()
self.use_triton_kernels = use_triton_kernels
self.with_bias = False
self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
self._cache_permute_indices = dict({})

def create_weights(
self,
Expand Down Expand Up @@ -227,6 +231,71 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

# Reorder rows of W1 for fused gated activation
if self.use_flashinfer_trtllm_moe:
from flashinfer.fused_moe.core import (
_maybe_get_cached_w3_w1_permute_indices,
convert_to_block_layout,
get_w2_permute_indices_with_cache,
)

# w1 and w3 have been swapped, so we don't need do that here
epilogue_tile_m = 128
block_k = 128
old_shape_w13 = layer.w13_weight.data[0].shape
old_shape_w2 = layer.w2_weight.data[0].shape
new_shape_w13 = None
new_shape_w2 = None
for i in range(layer.num_local_experts):
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
layer.w13_weight.data[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights1 = (
layer.w13_weight.data[i]
.clone()
.view(torch.uint8)[permute_indices.to(layer.w13_weight.data.device)]
.contiguous()
)

permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
layer.w2_weight.data[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights2 = (
layer.w2_weight.data[i]
.clone()
.view(torch.uint8)[permute_indices.to(layer.w2_weight.data.device)]
.contiguous()
)

tmp_weights1 = convert_to_block_layout(
tmp_weights1.view(torch.uint8), block_k
)
tmp_weights2 = convert_to_block_layout(
tmp_weights2.view(torch.uint8), block_k
)

new_shape_w13 = tmp_weights1.view(torch.bfloat16).shape
new_shape_w2 = tmp_weights2.view(torch.bfloat16).shape
layer.w13_weight.data[i] = (
tmp_weights1.view(torch.bfloat16)
.contiguous()
.reshape(old_shape_w13)
)
layer.w2_weight.data[i] = (
tmp_weights2.view(torch.bfloat16).contiguous().reshape(old_shape_w2)
)

layer.w13_weight.data = layer.w13_weight.data.reshape(
layer.num_local_experts, *new_shape_w13
)
layer.w2_weight.data = layer.w2_weight.data.reshape(
layer.num_local_experts, *new_shape_w2
)

return

def create_moe_runner(
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,13 @@ def __init__(
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor,
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
# and requires the output format to be standard. We use quant_config to determine the output format.
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
# and requires the output format to be standard (except trtllm). We use quant_config to determine the output format.
output_format=(
TopKOutputFormat.STANDARD
if (quant_config is None)
and (not get_moe_runner_backend().is_flashinfer_trtllm())
else None
),
)

self.shared_experts_is_int8 = False
Expand Down
11 changes: 6 additions & 5 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,11 +1591,12 @@ def _handle_moe_kernel_config(self):
], "The expert parallel size must be 1 or the same as the tensor parallel size"

if self.moe_runner_backend == "flashinfer_trtllm":
assert (
self.quantization == "modelopt_fp4"
or self.quantization == "modelopt_fp8"
or self.quantization == "fp8"
), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
assert self.quantization in [
"modelopt_fp4",
"fp8",
"modelopt_fp8",
None,
], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', or bfloat16 (None)."
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
Expand Down
48 changes: 47 additions & 1 deletion test/nightly/test_flashinfer_trtllm_gen_moe_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
register_cuda_ci(est_time=300, suite="nightly-4-gpu-b200", nightly=True)


class TestFlashinferTrtllmGenMoeBackend(CustomTestCase):
class TestFlashinferTrtllmGenMoeBackendFP8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
Expand Down Expand Up @@ -60,5 +60,51 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.93)


class TestFlashinferTrtllmGenMoeBackendBF16(CustomTestCase):
    """Nightly GSM8K accuracy check for the FlashInfer TRT-LLM MoE backend
    running an unquantized (bf16) model.

    Mirrors the FP8 variant above but launches the non-FP8 checkpoint so the
    bf16 path of the flashinfer_trtllm runner backend is exercised.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a local server once for the whole test class; individual
        # tests talk to it over HTTP.
        cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST

        # (flag, value) pairs flattened into the CLI argument list.
        server_flags = [
            ("--attention-backend", "triton"),
            ("--moe-runner-backend", "flashinfer_trtllm"),
            ("--cuda-graph-max-bs", "512"),
            ("--tp-size", "4"),
            ("--ep-size", "4"),
            ("--mem-fraction-static", "0.7"),
            ("--mamba-ssm-dtype", "bfloat16"),
        ]
        extra_args = [token for pair in server_flags for token in pair]

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=extra_args,
        )

    @classmethod
    def tearDownClass(cls):
        # Terminate the server (and any children it spawned).
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run the GSM8K eval against the launched server and require >0.93 accuracy."""
        port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.93)


if __name__ == "__main__":
    # Allow running this nightly suite directly as a script
    # (e.g. `python test_flashinfer_trtllm_gen_moe_backend.py`).
    unittest.main()
Loading