Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_features/expert_parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Currently, DeepEP, Mooncake, `ascend_fuseep` and MORI only support cases where `
| `deep_gemm` | DeepGEMM backend optimized for MoE matrix multiplications, supporting contiguous layouts for prefill and masked layouts for decode; often JIT-compiled for performance. | Large-scale EP deployments with FP8 block-wise quantization. |
| `cutlass` | CUTLASS-based backend for efficient GEMMs. | NVIDIA architectures with CUTLASS support. |
| `flashinfer_trtllm` | FlashInfer integrated with TensorRT-LLM for accelerated MoE computations, supporting FP4 communication operators and high-performance GEMMs. | Blackwell with TRT-LLM. |
| `flashinfer_trtllm_routed` | FlashInfer integrated with TensorRT-LLM for accelerated routed MoE computations, consuming SGLang-computed top-k expert assignments and weights. | Blackwell with TRT-LLM. |
| `flashinfer_cutlass` | FlashInfer combined with CUTLASS for high-performance grouped GEMMs in MoE layers, handling FP4/FP8 quantization efficiently. | Blackwell with FP4/FP8 models. |
| `flashinfer_mxfp4` | FlashInfer variant optimized for MXFP4 (mixed FP4) quantization in MoE runners, focusing on memory-efficient low-precision inference. | Low-precision models with MXFP4. |
| `flashinfer_cutedsl` | FlashInfer with a custom DSL for flexible and efficient MoE kernel generation, integrated with ModelOpt FP4 quantization. | Low-precision models with NVFP4. |
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| --- | --- | --- | --- |
| `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`|
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of the FlashInfer MXFP4 MoE. | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
| `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,11 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
elif (
quant_config is None
or quant_config.get_name() == "fp8"
or quant_config.get_name() == "mxfp8"
or quant_config.get_name() == "modelopt_fp8"
or quant_config.get_name() == "compressed_tensors"
):
# FlashInferFusedMoE support bf16, fp8 and compressed_tensors
# FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors
return FusedMoE

if get_moe_runner_backend().is_flashinfer_cutlass():
Expand Down
10 changes: 8 additions & 2 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ def __init__(
self.use_presharded_weights = use_presharded_weights

self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm()
self.use_flashinfer_trtllm_moe = (
get_moe_runner_backend().is_flashinfer_trtllm()
or get_moe_runner_backend().is_flashinfer_trtllm_routed()
)

# flashinfer_trtllm kernel requires intermediate_size to be a multiple of 128
# Pad the intermediate_size_per_partition if necessary
Expand Down Expand Up @@ -302,7 +305,10 @@ def __init__(
self.quant_method, ModelOptNvFp4FusedMoEMethod
) or (
isinstance(self.quant_method, Fp8MoEMethod)
and get_moe_runner_backend().is_cutlass()
and (
get_moe_runner_backend().is_cutlass()
or get_moe_runner_backend().is_flashinfer_trtllm_routed()
)
)

self.routing_method_type = routing_method_type
Expand Down
264 changes: 218 additions & 46 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,77 @@ def align_fp8_moe_weights_for_flashinfer_trtllm(
layer.output2_scales_scalar = Parameter(output2_scales_scalar, requires_grad=False)


def align_mxfp8_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
    """Prepare MXFP8 MoE weights/scales for FlashInfer TRT-LLM kernels.

    Reorders the gated-activation rows of the w13 weight/scale tensors,
    shuffles every per-expert matrix into the kernel's epilogue-tile layout,
    and rebinds the results onto ``layer`` without changing parameter
    identities (important for CUDA graph capture reuse).
    """
    from flashinfer import (
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

    gemm1_weight = cast(torch.Tensor, layer.w13_weight).contiguous()
    gemm2_weight = cast(torch.Tensor, layer.w2_weight).contiguous()
    gemm1_scale = cast(torch.Tensor, layer.w13_weight_scale_inv).contiguous()
    gemm2_scale = cast(torch.Tensor, layer.w2_weight_scale_inv).contiguous()

    # MXFP8 scale factors are stored as raw uint8 (UE8M0) bytes.
    assert gemm1_scale.dtype == torch.uint8
    assert gemm2_scale.dtype == torch.uint8

    num_experts, two_n, _ = gemm1_weight.shape
    _, hidden_size, _ = gemm2_weight.shape
    epilogue_tile_m = 128

    gemm1_shuffled: list[torch.Tensor] = []
    gemm1_scale_shuffled: list[torch.Tensor] = []
    gemm2_shuffled: list[torch.Tensor] = []
    gemm2_scale_shuffled: list[torch.Tensor] = []
    for expert in range(num_experts):
        # w13 holds the gate/up projections interleaved for the gated
        # activation GEMM; its rows (and matching scale rows) must be
        # reordered before the shuffle.
        interleaved_w = reorder_rows_for_gated_act_gemm(gemm1_weight[expert])
        interleaved_sf = reorder_rows_for_gated_act_gemm(gemm1_scale[expert])
        gemm1_shuffled.append(
            shuffle_matrix_a(interleaved_w.view(torch.uint8), epilogue_tile_m)
        )
        gemm1_scale_shuffled.append(
            shuffle_matrix_sf_a(
                interleaved_sf.view(torch.uint8).reshape(two_n, -1),
                epilogue_tile_m,
            )
        )
        gemm2_shuffled.append(
            shuffle_matrix_a(gemm2_weight[expert].view(torch.uint8), epilogue_tile_m)
        )
        gemm2_scale_shuffled.append(
            shuffle_matrix_sf_a(
                gemm2_scale[expert].view(torch.uint8).reshape(hidden_size, -1),
                epilogue_tile_m,
            )
        )

    # Keep parameter identities stable for CUDA graph capture reuse.
    copy_or_rebind_param(
        layer, "w13_weight", torch.stack(gemm1_shuffled).view(torch.float8_e4m3fn)
    )
    copy_or_rebind_param(
        layer, "w2_weight", torch.stack(gemm2_shuffled).view(torch.float8_e4m3fn)
    )
    copy_or_rebind_param(
        layer,
        "w13_weight_scale_inv",
        torch.stack(gemm1_scale_shuffled).reshape_as(gemm1_scale).contiguous(),
    )
    copy_or_rebind_param(
        layer,
        "w2_weight_scale_inv",
        torch.stack(gemm2_scale_shuffled).reshape_as(gemm2_scale).contiguous(),
    )
    # Tag the rebound scales as UE8M0-formatted for downstream consumers.
    layer.w13_weight_scale_inv.format_ue8m0 = True
    layer.w2_weight_scale_inv.format_ue8m0 = True


def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
"""Prepare FP4 MoE weights/scales for FlashInfer TRT-LLM kernels.

Expand Down Expand Up @@ -197,6 +268,7 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):

# Block-quant path
block_quant: bool
use_mxfp8: bool = False
weight_block_k: int | None = None
w13_weight_scale_inv: torch.Tensor | None = None
w2_weight_scale_inv: torch.Tensor | None = None
Expand All @@ -209,13 +281,27 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
use_routing_scales_on_input: bool = False


def _pack_topk_for_flashinfer_routed(
topk_ids: torch.Tensor, topk_weights: torch.Tensor
) -> torch.Tensor:
"""Pack routed top-k tensors into FlashInfer's int32 format."""
packed_ids = topk_ids.to(torch.int32)
packed_weights = topk_weights.to(torch.bfloat16)
packed = (packed_ids << 16) | packed_weights.view(torch.int16).to(torch.int32)
# SGLang can mark padded tokens with -1 expert ids.
return packed.masked_fill_(packed_ids < 0, 0)


def fused_experts_none_to_flashinfer_trtllm_fp8(
dispatch_output: StandardDispatchOutput,
quant_info: FlashInferTrtllmFp8MoeQuantInfo,
runner_config: MoeRunnerConfig,
use_routed_topk: bool = False,
) -> StandardCombineInput:
from flashinfer.fused_moe import (
Fp8QuantizationType,
trtllm_fp8_block_scale_moe,
trtllm_fp8_block_scale_routed_moe,
trtllm_fp8_per_tensor_scale_moe,
)

Expand All @@ -228,64 +314,132 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(

hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert TopKOutputChecker.format_is_bypassed(topk_output)

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)
if TopKOutputChecker.format_is_bypassed(topk_output):
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)
else:
router_logits = None
topk_config = None
correction_bias = None

routing_method_type = quant_info.routing_method_type
fp8_quantization_type = (
Fp8QuantizationType.MxFp8
if quant_info.use_mxfp8
else Fp8QuantizationType.DeepSeekFp8
)
use_shuffled_weight = quant_info.use_mxfp8

if quant_info.block_quant:
assert quant_info.weight_block_k is not None
assert quant_info.w13_weight_scale_inv is not None
assert quant_info.w2_weight_scale_inv is not None

a_q, a_sf = per_token_group_quant_fp8(hidden_states, quant_info.weight_block_k)
a_sf_t = a_sf.t().contiguous()
if quant_info.use_mxfp8:
assert quant_info.weight_block_k == 32
from flashinfer import mxfp8_quantize

a_q, a_sf = mxfp8_quantize(hidden_states, False)
# FlashInfer TRT-LLM MxFP8 expects token-major activation scales:
# [num_tokens, hidden_size // 32] (no transpose).
a_sf_t = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1)
else:
a_q, a_sf = per_token_group_quant_fp8(
hidden_states, quant_info.weight_block_k
)
a_sf_t = a_sf.t().contiguous()

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
# FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
# It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# so we put the whole function under the ``use_symmetric_memory`` context manager.
# If the bug is fixed, we can only put the output tensor allocation under the context manager.
output = trtllm_fp8_block_scale_moe(
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale_inv,
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale_inv,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=(
topk_config.num_expert_group if topk_config.num_expert_group else 0
),
topk_group=topk_config.topk_group if topk_config.topk_group else 0,
intermediate_size=quant_info.intermediate_size,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=(
runner_config.routed_scaling_factor
if runner_config.routed_scaling_factor is not None
else 1.0
),
routing_method_type=routing_method_type,
use_shuffled_weight=False,
tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
)
if use_routed_topk:
assert (
runner_config.top_k is not None
), "runner_config.top_k is required for flashinfer_trtllm_routed."
assert TopKOutputChecker.format_is_standard(topk_output)
packed_topk_ids = _pack_topk_for_flashinfer_routed(
topk_ids=topk_output.topk_ids,
topk_weights=topk_output.topk_weights,
)

output = trtllm_fp8_block_scale_routed_moe(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale_inv,
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale_inv,
num_experts=quant_info.global_num_experts,
top_k=runner_config.top_k,
n_group=None,
topk_group=None,
intermediate_size=quant_info.intermediate_size,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=(
runner_config.routed_scaling_factor
if runner_config.routed_scaling_factor is not None
else 1.0
),
routing_method_type=(
RoutingMethodType.TopK
if routing_method_type == RoutingMethodType.DeepSeekV3
else routing_method_type
),
use_shuffled_weight=use_shuffled_weight,
weight_layout=0,
tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
fp8_quantization_type=fp8_quantization_type,
)
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)

# FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
# It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# so we put the whole function under the ``use_symmetric_memory`` context manager.
# If the bug is fixed, we can only put the output tensor allocation under the context manager.
output = trtllm_fp8_block_scale_moe(
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale_inv,
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale_inv,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=(
topk_config.num_expert_group
if topk_config.num_expert_group
else 0
),
topk_group=topk_config.topk_group if topk_config.topk_group else 0,
intermediate_size=quant_info.intermediate_size,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=(
runner_config.routed_scaling_factor
if runner_config.routed_scaling_factor is not None
else 1.0
),
routing_method_type=routing_method_type,
use_shuffled_weight=use_shuffled_weight,
weight_layout=0,
tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
fp8_quantization_type=fp8_quantization_type,
)
else:
assert quant_info.w13_input_scale is not None
assert quant_info.output1_scales_scalar is not None
Expand Down Expand Up @@ -577,3 +731,21 @@ def fused_experts_none_to_flashinfer_trtllm(
raise TypeError(
f"Unexpected quant_info type for flashinfer_trtllm: {type(quant_info)}"
)


@register_fused_func("none", "flashinfer_trtllm_routed")
def fused_experts_none_to_flashinfer_trtllm_routed(
    dispatch_output: StandardDispatchOutput,
    quant_info: MoeQuantInfo,
    runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
    """Run the routed FlashInfer TRT-LLM fused-MoE path.

    Delegates to the FP8 fused-experts implementation with
    ``use_routed_topk=True``; only FP8 quant info is supported.
    """
    if not isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo):
        raise TypeError(
            f"Unexpected quant_info type for flashinfer_trtllm_routed: {type(quant_info)}"
        )
    return fused_experts_none_to_flashinfer_trtllm_fp8(
        dispatch_output,
        quant_info,
        runner_config,
        use_routed_topk=True,
    )
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/moe/moe_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.runner_core = DeepGemmRunnerCore(config)
elif runner_backend.is_marlin():
self.runner_core = None # Marlin only supports fused path
elif runner_backend.is_flashinfer_trtllm():
elif (
runner_backend.is_flashinfer_trtllm()
or runner_backend.is_flashinfer_trtllm_routed()
):
self.runner_core = None # FlashInfer TRT-LLM only supports fused path
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/token_dispatcher/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def __init__(self, moe_runner_config: MoeRunnerConfig):
self.enable_flashinfer_cutlass_moe = (
get_moe_runner_backend().is_flashinfer_cutlass()
)
self.enable_flashinfer_trtllm_routed_moe = (
get_moe_runner_backend().is_flashinfer_trtllm_routed()
)
self.num_experts = moe_runner_config.num_experts
self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts
self.num_local_routed_experts = (
Expand Down Expand Up @@ -142,6 +145,7 @@ def dispatch(
if (
self.moe_ep_size > 1
and not self.enable_flashinfer_cutlass_moe
and not self.enable_flashinfer_trtllm_routed_moe
and TopKOutputChecker.format_is_standard(topk_output)
):
if self.local_expert_mapping is None:
Expand Down
Loading
Loading