Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,24 +795,4 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_ascend_fuseep():
return NpuFuseEPMoE

if get_moe_runner_backend().is_flashinfer_trtllm():
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
# FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

return FlashInferFP4MoE
elif (
quant_config is None
or quant_config.get_name() == "fp8"
or quant_config.get_name() == "mxfp8"
or quant_config.get_name() == "modelopt_fp8"
or quant_config.get_name() == "compressed_tensors"
):
# FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors
return FusedMoE

if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
return FusedMoE
319 changes: 0 additions & 319 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardDispatcher,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import (
BypassedTopKOutput,
Expand All @@ -66,16 +65,11 @@
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
round_up,
)
from sglang.srt.utils.custom_op import register_custom_op

if is_flashinfer_available():
from flashinfer import fp4_quantize

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
Expand Down Expand Up @@ -1146,267 +1140,6 @@ def clear_overlap_args(self) -> None:
self.meta_overlap_args = None


class FlashInferFusedMoE(FusedMoE):
    """FusedMoE variant that runs the FlashInfer TRTLLM bf16 MoE kernel.

    Routing is *bypassed*: the layer receives raw router logits and the
    TRTLLM kernel performs top-k expert selection internally, so only
    bypassed-format ``TopKOutput`` is accepted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        """Dispatch either through the piecewise-CUDA-graph custom op or
        directly to :meth:`forward_impl`.

        Args:
            hidden_states: Input activations for this MoE layer.
            topk_output: Must be in bypassed format (raw router logits plus
                a ``TopKConfig``); the kernel does the routing itself.
        """
        assert TopKOutputChecker.format_is_bypassed(
            topk_output
        ), "Only bypassed topk output is supported for flashinfer trtllm moe"

        if is_in_piecewise_cuda_graph():
            # Inside piecewise CUDA graph capture, go through the registered
            # custom op so the call is opaque to torch.compile and the layer
            # is looked up again from the forward context at replay time.
            return flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl(
                hidden_states,
                topk_output.router_logits,
                topk_output.topk_config.top_k,
                topk_output.topk_config.topk_group,
                topk_output.topk_config.num_expert_group,
                topk_output.topk_config.correction_bias,
                topk_output.topk_config.renormalize,
                self.layer_id,
            )
        else:
            return self.forward_impl(hidden_states, topk_output)

    def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        """Run the MoE computation.

        Unquantized layers use FlashInfer's ``trtllm_bf16_moe`` kernel;
        quantized layers fall back to the quant method's ``apply``.
        Raises ImportError if the installed flashinfer lacks the bf16 kernel.
        """
        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only silu is supported for flashinfer trtllm moe"
        assert self.quant_method is not None
        assert (
            topk_output.topk_config.renormalize
        ), "Renormalize is required for flashinfer trtllm moe"
        assert (
            self.num_fused_shared_experts == 0
        ), "Fused shared experts are not supported for flashinfer trtllm moe"
        assert (
            self.moe_runner_config.is_gated
        ), "Only gated MoEs are supported for flashinfer trtllm moe"

        router_logits = topk_output.router_logits
        topk_config = topk_output.topk_config
        correction_bias = topk_config.correction_bias
        # NOTE: a previously-assigned local for
        # self.moe_runner_config.routed_scaling_factor was removed here --
        # neither the bf16 kernel call nor the quant path below consumed it.

        if isinstance(self.quant_method, UnquantizedFusedMoEMethod):
            # lazy import: only needed on the unquantized bf16 path
            try:
                from flashinfer.fused_moe import trtllm_bf16_moe
            except ImportError as e:
                raise ImportError(
                    "Can't import trtllm_bf16_moe from flashinfer. "
                    "Please check flashinfer version to use bf16 with flashinfer_trtllm backend."
                ) from e

            # Allocate output inside symmetric memory context
            with use_symmetric_memory(
                get_tp_group(), disabled=not is_allocation_symmetric()
            ):
                # TODO: Now trtllm_bf16_moe doesn't support inplace output,
                # we can move this out when it support that.
                symm_output = torch.empty(
                    hidden_states.shape[0],
                    hidden_states.shape[1],
                    dtype=hidden_states.dtype,
                    device=hidden_states.device,
                )

            # Move kernel call outside context manager to avoid graph breaks
            # during torch.compile for piecewise cuda graph
            moe_result = trtllm_bf16_moe(
                routing_logits=router_logits,
                routing_bias=correction_bias,
                hidden_states=hidden_states,
                gemm1_weights=self.w13_weight,
                gemm2_weights=self.w2_weight,
                num_experts=self.num_experts,
                top_k=topk_config.top_k,
                n_group=topk_config.num_expert_group,
                topk_group=topk_config.topk_group,
                intermediate_size=self.intermediate_size_per_partition,
                local_expert_offset=self.moe_ep_rank * self.num_local_experts,
                local_num_experts=self.num_local_experts,
                routing_method_type=self.routing_method_type,
                tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]),
            )
            # Copy result to symmetric memory output
            symm_output.copy_(moe_result)
            final_hidden_states = symm_output

        else:

            final_hidden_states = self.quant_method.apply(
                layer=self,
                dispatch_output=StandardDispatchOutput(
                    hidden_states=hidden_states,
                    hidden_states_scale=None,
                    topk_output=topk_output,
                ),
            ).hidden_states

        # NOTE for symmetric memory tagging:
        # We do not create the context in this function.
        # Instead, we create the context and tagging inside each FusedMoEMethodBase
        # This can allow fine-grained tagging.

        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states


class FlashInferFP4MoE(FusedMoE):
    """FP4 TRTLLM MoE implementation using FlashInfer.

    Hidden states are quantized to packed NVFP4 each forward pass and fed to
    FlashInfer's ``trtllm_fp4_block_scale_moe`` kernel. Routing is bypassed:
    the kernel selects experts from raw router logits. Must be paired with
    ``ModelOptNvFp4FusedMoEMethod`` (asserted in ``forward_impl``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # ---------------------------------------------------------------------
    # Helper: quantize hidden states to FP4 each forward pass
    # ---------------------------------------------------------------------
    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
        """
        Quantize hidden states using global scale factor from quantization method.

        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
        Only block scales are computed at runtime for efficiency.

        Returns (hs_fp4, hs_sf):
            hs_fp4: packed FP4 values as uint8, shape [seq_len, hidden_size // 2]
            hs_sf: runtime block scales as float8_e4m3fn, shape
                [seq_len, hidden_size // 16]
        """

        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
        # Only the block scales are computed at runtime
        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
            hidden_states,
            self.w13_input_scale_quant,
            16,  # sf_vec_size
            False,  # use_ue8m0
            False,  # is_sf_swizzled_layout
        )

        seq_len, hidden_size = hidden_states.shape
        # Two FP4 values are packed per uint8 byte, hence hidden_size // 2.
        hs_fp4 = hs_fp4_bytes.reshape(seq_len, hidden_size // 2)
        # TRT-LLM expects hidden state scales shaped as [seq_len, hidden_size // 16]
        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(
            seq_len, hidden_size // 16
        )

        return hs_fp4, hs_sf

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        """Dispatch through the piecewise-CUDA-graph custom op when capturing,
        otherwise call :meth:`forward_impl` directly.
        """
        assert TopKOutputChecker.format_is_bypassed(
            topk_output
        ), "Only bypassed topk output is supported for flashinfer fp4 moe"

        if is_in_piecewise_cuda_graph():
            # Opaque custom-op call so torch.compile does not trace into the
            # kernel; the layer is re-resolved from the forward context.
            return flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl(
                hidden_states,
                topk_output.router_logits,
                topk_output.topk_config.top_k,
                topk_output.topk_config.topk_group,
                topk_output.topk_config.num_expert_group,
                topk_output.topk_config.correction_bias,
                self.layer_id,
            )
        else:
            return self.forward_impl(hidden_states, topk_output)

    def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        """Forward pass using FP4 TRTLLM kernel.

        Args:
            hidden_states: Input tensor
            topk_output: TopKOutput object with Bypassed format

        Returns:
            bf16 output tensor of shape [num_tokens, hidden_size], allocated
            in symmetric memory.
        """
        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe

        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)

        assert (
            self.moe_runner_config.is_gated
        ), "Only gated MoEs are supported for flashinfer fp4 moe"

        assert TopKOutputChecker.format_is_bypassed(topk_output)

        router_logits = topk_output.router_logits
        topk_config = topk_output.topk_config

        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
        routing_method_type = self.routing_method_type
        assert (
            routing_method_type is not None
        ), "flashinfer trtllm moe nvfp4 backend has not been adapted for the current moe layer, you can set routing_method_type (See definition of RoutingMethodType please) for the moe layer explicitly for a quick adaptation."

        # DeepSeekV3 style routing requires float32 router logits,
        # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6
        if routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        # Kernel expects the routing bias in the activation dtype.
        correction_bias = (
            None
            if topk_config.correction_bias is None
            else topk_config.correction_bias.to(hidden_states.dtype)
        )

        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            num_tokens = hs_fp4.shape[0]
            # Packed uint8 FP4 holds two values per byte; recover the logical
            # hidden size for the output allocation.
            hidden_size = (
                hs_fp4.shape[-1] * 2
                if hs_fp4.dtype == torch.uint8
                else hs_fp4.shape[-1]
            )
            symm_output = torch.empty(
                num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
            )
            result = trtllm_fp4_block_scale_moe(
                routing_logits=router_logits,
                routing_bias=correction_bias,
                hidden_states=hs_fp4,
                # NOTE(review): the .view here is redundant — the helper already
                # returns float8_e4m3fn scales — but harmless.
                hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
                    *hs_scale_linear.shape[:-1], -1
                ),
                gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=self.g1_scale_c.data,
                output1_scale_gate_scalar=self.g1_alphas.data,
                output2_scale_scalar=self.g2_alphas.data,
                num_experts=self.num_experts,
                top_k=topk_config.top_k,
                n_group=topk_config.num_expert_group,
                topk_group=topk_config.topk_group,
                intermediate_size=self.intermediate_size_per_partition,
                local_expert_offset=self.moe_ep_rank * self.num_local_experts,
                local_num_experts=self.num_local_experts,
                routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
                # Respect the routing method configured for this layer (e.g., Renormalize for Qwen3),
                # instead of always assuming DeepSeekV3.
                routing_method_type=(
                    self.routing_method_type
                    if self.routing_method_type is not None
                    else RoutingMethodType.Default
                ),
                do_finalize=True,
                tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
                output=symm_output,
            )[0]  # kernel returns a tuple; first element is the output tensor

        return result


@register_custom_op(out_shape="hidden_states")
def moe_forward_piecewise_cuda_graph_impl(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -1449,55 +1182,3 @@ def fused_moe_bypassed_piecewise_cuda_graph_impl(
forward_context = get_forward_context()
moe_layer = forward_context.moe_layers[layer_id]
return moe_layer.forward_impl(hidden_states, topk_output)


@register_custom_op(out_shape="hidden_states")
def flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    renormalize: bool,
    layer_id: int,
) -> torch.Tensor:
    """Piecewise-CUDA-graph entry point for the FlashInfer bf16 MoE path.

    Registered as a custom op so torch.compile treats the MoE layer as an
    opaque call. Rebuilds the bypassed top-k output from flat tensor/scalar
    arguments, resolves the MoE layer from the forward context by id, and
    delegates to its ``forward_impl``.
    """
    routing_config = TopKConfig(
        top_k=top_k,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        correction_bias=correction_bias,
        renormalize=renormalize,
    )
    bypassed_output = BypassedTopKOutput(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=routing_config,
    )
    target_layer = get_forward_context().moe_layers[layer_id]
    return target_layer.forward_impl(hidden_states, bypassed_output)


@register_custom_op(out_shape="hidden_states")
def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    layer_id: int,
) -> torch.Tensor:
    """Piecewise-CUDA-graph entry point for the FlashInfer FP4 MoE path.

    Registered as a custom op so torch.compile treats the MoE layer as an
    opaque call. Rebuilds the bypassed top-k output (no ``renormalize`` flag
    on this path), resolves the MoE layer from the forward context by id,
    and delegates to its ``forward_impl``.
    """
    routing_config = TopKConfig(
        top_k=top_k,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        correction_bias=correction_bias,
    )
    bypassed_output = BypassedTopKOutput(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=routing_config,
    )
    target_layer = get_forward_context().moe_layers[layer_id]
    return target_layer.forward_impl(hidden_states, bypassed_output)
Loading
Loading