Skip to content
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/token_dispatcher/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def __init__(self, moe_runner_config: MoeRunnerConfig):
backend = get_moe_runner_backend()
self.enable_flashinfer_cutlass_moe = backend.is_flashinfer_cutlass()
self.enable_flashinfer_mxfp4_moe = backend.is_flashinfer_mxfp4()
self.enable_flashinfer_cutlass_wmxfp4a16_moe = (
backend.is_flashinfer_cutlass_wmxfp4a16()
)
self.enable_flashinfer_trtllm_routed_moe = backend.is_flashinfer_trtllm_routed()
# Skip local expert mapping when the backend handles EP with global expert IDs:
# - cutlass / cutedsl / trtllm_routed handle EP internally
Expand All @@ -102,6 +105,7 @@ def __init__(self, moe_runner_config: MoeRunnerConfig):
self.enable_flashinfer_mxfp4_moe
and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get()
)
or self.enable_flashinfer_cutlass_wmxfp4a16_moe
)
self.num_experts = moe_runner_config.num_experts
self.num_local_experts = moe_runner_config.num_local_experts
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER_TRTLLM_ROUTED = "flashinfer_trtllm_routed"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTLASS_WMXFP4A16 = "flashinfer_cutlass_wmxfp4a16"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
CUTLASS = "cutlass"
MARLIN = "marlin"
Expand Down Expand Up @@ -107,6 +108,9 @@ def is_flashinfer_cutedsl(self):
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4

def is_flashinfer_cutlass_wmxfp4a16(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS_WMXFP4A16

def is_cutlass(self):
return self == MoeRunnerBackend.CUTLASS

Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,15 @@ def get_quant_method(
)

return Mxfp4FlashinferTrtllmMoEMethod(fp8_method, prefix=prefix)
if (
self.is_fp4_experts
and get_moe_runner_backend().is_flashinfer_cutlass_wmxfp4a16()
):
from sglang.srt.layers.quantization.wmxfp4a16_flashinfer_cutlass_moe import (
Wmxfp4A16FlashinferCutlassMoEMethod,
)

return Wmxfp4A16FlashinferCutlassMoEMethod(fp8_method, prefix=prefix)
return fp8_method
elif isinstance(layer, RadixAttention):
return Fp8KVCacheMethod(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput


from sglang.srt.environ import envs
from sglang.srt.utils.common import get_bool_env_var

Expand Down Expand Up @@ -132,7 +133,6 @@ def __init__(self, fp8_method, prefix: str):

def create_moe_runner(self, layer, moe_runner_config):
self.moe_runner_config = moe_runner_config

swiglu_limit = moe_runner_config.swiglu_limit
assert (
swiglu_limit is not None
Expand Down
Loading
Loading