25 changes: 18 additions & 7 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -14,13 +14,9 @@
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFusedMoE,
FusedMoE,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul

@@ -741,6 +736,22 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE

# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
if quantization == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFP4MoE,
)

return FlashInferFP4MoE
except:
pass

if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
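To make the new dispatch order in `get_moe_impl_class` easier to follow, here is a self-contained sketch of the same logic. The dict keys mirror `global_server_args_dict`; the `pick_moe_impl` function and its return strings are illustrative stand-ins, not part of this PR.

```python
# Hypothetical, simplified restatement of the dispatch order above.
# Assumes a plain dict in place of sglang's global_server_args_dict.
def pick_moe_impl(args: dict) -> str:
    if args.get("moe_a2a_backend") == "deepep":
        return "DeepEPMoE"
    # New in this PR: FP4 + TRTLLM is detected directly, before any EP checks.
    if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    return "EPMoE or FusedMoE, depending on expert-parallel world size"


# A modelopt_fp4 deployment with the TRTLLM flag now selects the FP4 class
# even when expert parallelism is not enabled.
assert pick_moe_impl(
    {"enable_flashinfer_trtllm_moe": True, "quantization": "modelopt_fp4"}
) == "FlashInferFP4MoE"
```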
189 changes: 173 additions & 16 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,13 +1,14 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import importlib.util
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
from packaging import version as pkg_version

from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
@@ -22,29 +23,66 @@
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
)

if is_flashinfer_available():
from flashinfer import (
RoutingMethodType,
fp4_quantize,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()


# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
try:
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
except ImportError:
trtllm_fp4_block_scale_moe = None

logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
def _is_fp4_quantization_enabled():
"""Check if ModelOpt FP4 quantization is enabled."""
try:
# Use the same simple check that works for class selection
quantization = global_server_args_dict.get("quantization")
return quantization == "modelopt_fp4"
except:
return False


def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim


class FusedMoeWeightScaleSupported(Enum):
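The `_get_tile_tokens_dim` helper above sizes the kernel's CTA tile from an assumed-uniform routing of tokens to experts; a standalone sketch with worked examples follows (values are illustrative, and `next_power_of_2` is re-implemented locally so the snippet runs without sglang).

```python
def next_power_of_2(n: int) -> int:
    # Local stand-in for the imported helper: smallest power of two >= n (1 for n <= 1).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Tokens per expert under a perfectly even routing assumption...
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # ...rounded up to a power of two and clamped to the kernel's supported 8-64 range.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


# 512 tokens, top_k=8, 256 experts -> 16 tokens per expert -> tile of 16.
assert tile_tokens_dim(512, 8, 256) == 16
# Tiny batches clamp up to 8; very large batches clamp down to 64.
assert tile_tokens_dim(1, 8, 256) == 8
assert tile_tokens_dim(65536, 8, 256) == 64
```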
@@ -157,10 +195,6 @@ def __init__(
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_cutlass_moe = (
self.enable_flashinfer_cutlass_moe
)
assert self.quant_method is not None

self.quant_config = quant_config
@@ -747,7 +781,130 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
routed_scaling_factor=self.routed_scaling_factor,
)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

return final_hidden_states


class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer."""

def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)

# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)

super().__init__(*args, **kwargs)

# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor

# ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass
# ---------------------------------------------------------------------
def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
"""
Quantize hidden states using global scale factor from quantization method.

Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
Only block scales are computed at runtime for efficiency.

Returns (packed_fp4_uint8, block_scale_float8_e4m3fn_runtime)
"""

# flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
# Only the block scales are computed at runtime
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
hidden_states,
self.w13_input_scale_quant,
16, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)

hs_fp4 = hs_fp4_bytes.reshape(
hidden_states.shape[0], hidden_states.shape[1] // 2
)
hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

return hs_fp4, hs_sf

def forward(self, hidden_states: torch.Tensor, topk_output):
"""Forward pass using FP4 TRTLLM kernel.

Args:
hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
"""

# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)

_, router_logits = topk_output

hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)

router_logits = router_logits.to(torch.float32)

result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
output1_scale_scalar=self.g1_scale_c.data,
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts
),
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
)[0]

return result


def get_fused_moe_impl_class():
"""Factory function to get the appropriate FusedMoE implementation class."""
if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
# Use FP4 variant when FP4 quantization is enabled
return FlashInferFP4MoE
elif should_use_flashinfer_trtllm_moe():
# Use regular FlashInfer variant for non-FP4 FlashInfer cases
return FlashInferFusedMoE
else:
# Default case
return FusedMoE
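A note on the shapes in `_quantize_hidden_states_fp4`: NVFP4 packs two 4-bit values into each uint8 byte, which is why the packed tensor is reshaped to `hidden_size // 2`, and with `sf_vec_size=16` one float8_e4m3fn block scale covers 16 consecutive elements; the single global scale is loaded ahead of time by the quantization method. A standalone arithmetic sketch (no flashinfer required; numbers are illustrative):

```python
def fp4_buffer_shapes(num_tokens: int, hidden_size: int, sf_vec_size: int = 16):
    # Two FP4 values are packed into each uint8 byte.
    packed_shape = (num_tokens, hidden_size // 2)
    # One float8_e4m3fn block scale per sf_vec_size consecutive elements.
    num_block_scales = num_tokens * hidden_size // sf_vec_size
    return packed_shape, num_block_scales


# Four tokens with hidden size 7168 -> a (4, 3584) packed uint8 tensor plus
# 1792 runtime block scales, alongside the preloaded global scale.
assert fp4_buffer_shapes(4, 7168) == ((4, 3584), 1792)
```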
16 changes: 16 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
@@ -1,4 +1,20 @@
import importlib.util
from enum import Enum
from functools import lru_cache

from packaging import version as pkg_version

from sglang.srt.managers.schedule_batch import global_server_args_dict


@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result


class MoeA2ABackend(Enum):
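The helper moved into `moe/utils.py` gates the TRTLLM path on the server flag plus a flashinfer floor of 0.2.9rc1; note that a missing flashinfer does not disable the flag here, since availability is handled separately at the import sites in `fused_moe_triton/layer.py`. A minimal sketch of how `packaging.version` evaluates that floor:

```python
from packaging import version as pkg_version

MIN_FLASHINFER = pkg_version.parse("0.2.9rc1")

# Pre-releases compare below their final release, so 0.2.9rc1 itself passes
# and anything older than the rc is rejected.
for installed in ("0.2.8", "0.2.9rc1", "0.2.9", "0.3.0"):
    print(installed, pkg_version.parse(installed) >= MIN_FLASHINFER)
# -> 0.2.8 False, 0.2.9rc1 True, 0.2.9 True, 0.3.0 True
```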