Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,117 @@ def trtllm_fp8_block_scale_routed_moe_wrapper(
return trtllm_fp8_block_scale_routed_moe(**kwargs)


def _fake_fp4_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: Optional[torch.Tensor],
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm1_bias: Optional[torch.Tensor],
gemm1_alpha: Optional[torch.Tensor],
gemm1_beta: Optional[torch.Tensor],
gemm1_clamp_limit: Optional[torch.Tensor],
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
gemm2_bias: Optional[torch.Tensor],
output1_scale_scalar: Optional[torch.Tensor],
output1_scale_gate_scalar: Optional[torch.Tensor],
output2_scale_scalar: Optional[torch.Tensor],
num_experts: int,
top_k: int,
n_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float],
routing_method_type: int = 0,
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
activation_type: int = 3,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
return torch.empty(
(hidden_states.shape[0], gemm2_weights.shape[1]),
dtype=torch.bfloat16,
device=hidden_states.device,
)


@register_custom_op(fake_impl=_fake_fp4_block_scale_routed_moe)
def trtllm_fp4_block_scale_routed_moe_wrapper(
    topk_ids: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    hidden_states_scale: Optional[torch.Tensor],
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm1_bias: Optional[torch.Tensor],
    gemm1_alpha: Optional[torch.Tensor],
    gemm1_beta: Optional[torch.Tensor],
    gemm1_clamp_limit: Optional[torch.Tensor],
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    gemm2_bias: Optional[torch.Tensor],
    output1_scale_scalar: Optional[torch.Tensor],
    output1_scale_gate_scalar: Optional[torch.Tensor],
    output2_scale_scalar: Optional[torch.Tensor],
    num_experts: int,
    top_k: int,
    n_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routed_scaling_factor: Optional[float],
    routing_method_type: int = 0,
    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    activation_type: int = 3,
    tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
    """Custom-op wrapper around FlashInfer's FP4 block-scale routed MoE kernel.

    The flashinfer import is deferred to call time so environments without a
    compatible flashinfer build only fail when the op is actually invoked.
    The underlying kernel returns a sequence; this wrapper returns its first
    element (the output tensor).

    Raises:
        ImportError: if ``trtllm_fp4_block_scale_routed_moe`` is not available
            in the installed flashinfer version.
    """
    try:
        from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe
    except ImportError as e:
        raise ImportError(
            "Can't import trtllm_fp4_block_scale_routed_moe from flashinfer. "
            "Please check flashinfer version."
        ) from e

    # Forward every argument by keyword; names match the kernel signature 1:1.
    kernel_kwargs = dict(
        topk_ids=topk_ids,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=gemm1_weights,
        gemm1_weights_scale=gemm1_weights_scale,
        gemm1_bias=gemm1_bias,
        gemm1_alpha=gemm1_alpha,
        gemm1_beta=gemm1_beta,
        gemm1_clamp_limit=gemm1_clamp_limit,
        gemm2_weights=gemm2_weights,
        gemm2_weights_scale=gemm2_weights_scale,
        gemm2_bias=gemm2_bias,
        output1_scale_scalar=output1_scale_scalar,
        output1_scale_gate_scalar=output1_scale_gate_scalar,
        output2_scale_scalar=output2_scale_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=n_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        routing_method_type=routing_method_type,
        do_finalize=do_finalize,
        enable_pdl=enable_pdl,
        activation_type=activation_type,
        tune_max_num_tokens=tune_max_num_tokens,
    )
    outputs = trtllm_fp4_block_scale_routed_moe(**kernel_kwargs)
    return outputs[0]


def _fake_fp8_per_tensor_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
Expand Down
188 changes: 117 additions & 71 deletions python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe.flashinfer_trtllm_moe import (
trtllm_fp4_block_scale_routed_moe_wrapper,
trtllm_fp8_block_scale_moe_wrapper,
trtllm_fp8_block_scale_routed_moe_wrapper,
trtllm_fp8_per_tensor_scale_moe_wrapper,
Expand Down Expand Up @@ -275,15 +276,15 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
w13_weight.size(0), # num_experts
)

# Set flashinfer parameters
# Set flashinfer parameters in-place
copy_or_rebind_param(layer, "w13_weight", gemm1_weights_fp4_shuffled.contiguous())
copy_or_rebind_param(layer, "w2_weight", gemm2_weights_fp4_shuffled.contiguous())
copy_or_rebind_param(
layer, "gemm1_weights_fp4_shuffled", gemm1_weights_fp4_shuffled
layer, "w13_weight_scale", gemm1_scales_fp4_shuffled.contiguous()
)
copy_or_rebind_param(
layer, "gemm2_weights_fp4_shuffled", gemm2_weights_fp4_shuffled
layer, "w2_weight_scale", gemm2_scales_fp4_shuffled.contiguous()
)
copy_or_rebind_param(layer, "gemm1_scales_fp4_shuffled", gemm1_scales_fp4_shuffled)
copy_or_rebind_param(layer, "gemm2_scales_fp4_shuffled", gemm2_scales_fp4_shuffled)

# Compute additional scaling factor needed for TRT-LLM
w2_input_scale_quant = cast(torch.Tensor, layer.w2_input_scale_quant)
Expand All @@ -294,14 +295,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
(w2_input_scale_quant * g1_alphas).to(torch.float32),
)

# Clean up weights that won't be used by TRT-LLM
del (
layer.w2_weight,
layer.w2_weight_scale,
layer.w13_weight,
layer.w13_weight_scale,
)


@dataclass
class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
Expand Down Expand Up @@ -560,11 +553,10 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
class FlashInferTrtllmFp4MoeQuantInfo(MoeQuantInfo):
"""Quantization payload consumed by FlashInfer TRT-LLM FP4 MoE kernels."""

# Shuffled FP4 weights (processed by align_fp4_moe_weights_for_flashinfer_trtllm)
gemm1_weights_fp4_shuffled: torch.Tensor
gemm2_weights_fp4_shuffled: torch.Tensor
gemm1_scales_fp4_shuffled: torch.Tensor
gemm2_scales_fp4_shuffled: torch.Tensor
w13_weight: torch.Tensor
w2_weight: torch.Tensor
w13_weight_scale: torch.Tensor
w2_weight_scale: torch.Tensor

# Scaling factors
g1_scale_c: torch.Tensor
Expand Down Expand Up @@ -616,6 +608,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
dispatch_output: StandardDispatchOutput,
quant_info: FlashInferTrtllmFp4MoeQuantInfo,
runner_config: MoeRunnerConfig,
use_routed_topk: bool = False,
) -> StandardCombineInput:
"""FlashInfer TRTLLM FP4 MoE forward pass.

Expand All @@ -633,27 +626,26 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(

hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert TopKOutputChecker.format_is_bypassed(topk_output)
if TopKOutputChecker.format_is_bypassed(topk_output):
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)
else:
router_logits = None
topk_config = None
correction_bias = None

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
routing_method_type = quant_info.routing_method_type

# Quantize hidden states to FP4
hs_fp4, hs_scale_linear = quantize_hidden_states_fp4(
hidden_states, quant_info.w13_input_scale_quant
)

# DeepSeekV3 style routing requires float32 router logits
if routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(hidden_states.dtype)
)

with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
num_tokens = hs_fp4.shape[0]
hidden_size = (
Expand All @@ -663,46 +655,93 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
)

result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
*hs_scale_linear.shape[:-1], -1
),
gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=runner_config.routed_scaling_factor,
routing_method_type=(
routing_method_type
if routing_method_type is not None
else RoutingMethodType.Default
),
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
if use_routed_topk:
assert (
runner_config.top_k is not None
), "runner_config.top_k is required for flashinfer_trtllm_routed."
assert TopKOutputChecker.format_is_standard(topk_output)
packed_topk_ids = _pack_topk_for_flashinfer_routed(
topk_ids=topk_output.topk_ids,
topk_weights=topk_output.topk_weights,
)

result = trtllm_fp4_block_scale_routed_moe_wrapper(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
*hs_scale_linear.shape[:-1], -1
),
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=runner_config.top_k,
n_group=None,
topk_group=None,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=runner_config.routed_scaling_factor,
routing_method_type=(
RoutingMethodType.TopK
if routing_method_type == RoutingMethodType.DeepSeekV3
else routing_method_type
),
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
)
else:
assert TopKOutputChecker.format_is_bypassed(topk_output)

result = trtllm_fp4_block_scale_moe(
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
*hs_scale_linear.shape[:-1], -1
),
gemm1_weights=quant_info.w13_weight,
gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=quant_info.w2_weight,
gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=quant_info.g1_scale_c,
output1_scale_gate_scalar=quant_info.g1_alphas,
output2_scale_scalar=quant_info.g2_alphas,
num_experts=quant_info.global_num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=quant_info.intermediate_size_per_partition,
local_expert_offset=quant_info.local_expert_offset,
local_num_experts=quant_info.local_num_experts,
routed_scaling_factor=runner_config.routed_scaling_factor,
routing_method_type=(
routing_method_type
if routing_method_type is not None
else RoutingMethodType.Default
),
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]

return StandardCombineInput(hidden_states=result)

Expand Down Expand Up @@ -858,6 +897,13 @@ def fused_experts_none_to_flashinfer_trtllm_routed(
quant_info: MoeQuantInfo,
runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
if isinstance(quant_info, FlashInferTrtllmFp4MoeQuantInfo):
return fused_experts_none_to_flashinfer_trtllm_fp4(
dispatch_output,
quant_info,
runner_config,
use_routed_topk=True,
)
if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo):
return fused_experts_none_to_flashinfer_trtllm_fp8(
dispatch_output,
Expand Down
Loading
Loading