Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 109 additions & 56 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,13 +559,17 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
dispatch_output: StandardDispatchOutput,
quant_info: FlashInferTrtllmFp4MoeQuantInfo,
runner_config: MoeRunnerConfig,
use_routed_topk: bool = False,
) -> StandardCombineInput:
"""FlashInfer TRTLLM FP4 MoE forward pass.

This function handles the FP4 TRTLLM MoE path that was previously in
ModelOptNvFp4FusedMoEMethod.apply.
"""
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
from flashinfer.fused_moe import (
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
)

from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
from sglang.srt.layers.moe.topk import TopKOutputChecker
Expand All @@ -576,25 +580,13 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(

hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert TopKOutputChecker.format_is_bypassed(topk_output)

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
routing_method_type = quant_info.routing_method_type

# Quantize hidden states to FP4
hs_fp4, hs_scale_linear = quantize_hidden_states_fp4(
hidden_states, quant_info.w13_input_scale_quant
)

# DeepSeekV3 style routing requires float32 router logits
if routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
hs_scale = hs_scale_linear.view(torch.float8_e4m3fn).reshape(
*hs_scale_linear.shape[:-1], -1
)

with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
Expand All @@ -603,49 +595,103 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
hs_fp4.shape[-1] * 2 if hs_fp4.dtype == torch.uint8 else hs_fp4.shape[-1]
)
symm_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
num_tokens, hidden_size, dtype=hidden_states.dtype, device=hs_fp4.device
)

result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
*hs_scale_linear.shape[:-1], -1
),
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=runner_config.routed_scaling_factor,
routing_method_type=(
routing_method_type
if routing_method_type is not None
else RoutingMethodType.Default
),
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
if use_routed_topk:
assert TopKOutputChecker.format_is_standard(topk_output)

packed_topk_ids = _pack_topk_for_flashinfer_routed(
topk_output.topk_ids, topk_output.topk_weights
)
result = trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=topk_output.topk_ids.shape[1],
n_group=0,
topk_group=0,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1, # Unused, but must be 1 to pass validation.
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
routing_method_type = quant_info.routing_method_type

# DeepSeekV3 style routing requires float32 router logits
if routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=runner_config.routed_scaling_factor,
routing_method_type=(
routing_method_type
if routing_method_type is not None
else RoutingMethodType.Default
),
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]

return StandardCombineInput(hidden_states=result)

Expand Down Expand Up @@ -801,6 +847,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed(
quant_info: MoeQuantInfo,
runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo):
return fused_experts_none_to_flashinfer_trtllm_fp4(
dispatch_output,
quant_info,
runner_config,
use_routed_topk=True,
)
if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo):
return fused_experts_none_to_flashinfer_trtllm_fp8(
dispatch_output,
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
)
self.enable_flashinfer_trtllm_moe = (
get_moe_runner_backend().is_flashinfer_trtllm()
or get_moe_runner_backend().is_flashinfer_trtllm_routed()
)
self._cache_permute_indices = {}

Expand Down Expand Up @@ -1904,6 +1905,10 @@ def create_moe_runner(
self.runner = MoeRunner(
MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config
)
elif get_moe_runner_backend().is_flashinfer_trtllm_routed():
self.runner = MoeRunner(
MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED, moe_runner_config
)

def apply(
self,
Expand Down
Loading