11 changes: 11 additions & 0 deletions tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
@@ -0,0 +1,11 @@
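# GSM8K eval for amd/Qwen3.5-35B-A3B-MXFP4 with the ROCm AITER MoE path enabled (TP=2).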
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
accuracy_threshold: 0.89
tolerance: 0.03
num_questions: 1319
num_fewshot: 5
server_args: >-
  --max-model-len 4096
  --tensor-parallel-size 2
  --gpu-memory-utilization 0.35
env:
  VLLM_ROCM_USE_AITER: "1"
@@ -1,8 +1,10 @@
model_name: "amd/Qwen3.5-35B-A3B-MXFP4"
accuracy_threshold: 0.82
accuracy_threshold: 0.89
tolerance: 0.03
num_questions: 1319
num_fewshot: 5
server_args: >-
  --max-model-len 4096
  --tensor-parallel-size 2
  --moe-backend emulation
  --gpu-memory-utilization 0.35
3 changes: 2 additions & 1 deletion tests/evals/gsm8k/configs/models-mi3xx.txt
@@ -3,4 +3,5 @@ DeepSeek-R1-DP_MI325.yaml
DeepSeek-V3.2-TP_MI325.yaml
DeepSeek-V3.2-DP_MI325.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
3 changes: 2 additions & 1 deletion tests/evals/gsm8k/configs/models-qwen35-mi355.txt
@@ -1,2 +1,3 @@
Qwen3.5-35B-A3B-DEP2.yaml
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-AITER-TP2.yaml
Qwen3.5-35B-A3B-MXFP4-EMU-TP2.yaml
88 changes: 86 additions & 2 deletions tests/quantization/test_gfx950_moe.py
@@ -1,6 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for MXFP4 MoE oracle backend selection on mi355x (GFX950).

These tests run on real hardware — no mocks. Skipped on non-GFX950 platforms.
"""

def test_mi355_moe():
print("TODO: add tests for Mi355 MoE quantization")
import pytest
import torch

from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
    Mxfp4MoeBackend,
    select_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kMxfp4Dynamic,
)
from vllm.platforms import current_platform

ROCM_AVAILABLE = current_platform.is_rocm()
ROCM_GFX950 = False
ROCM_AITER_AVAILABLE = False

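# ROCm-specific imports are guarded below so this module still imports on non-ROCm hosts.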
if ROCM_AVAILABLE:
    from vllm._aiter_ops import rocm_aiter_ops
    from vllm.platforms.rocm import on_gfx950

    ROCM_GFX950 = on_gfx950()
    ROCM_AITER_AVAILABLE = rocm_aiter_ops.is_fused_moe_enabled()


def _make_w4a4_moe_config(moe_backend: str = "auto") -> FusedMoEConfig:
    from vllm.model_executor.layers.fused_moe.activation import MoEActivation

    return FusedMoEConfig(
        num_experts=8,
        experts_per_token=2,
        hidden_dim=256,
        intermediate_size_per_partition=256,
        num_local_experts=8,
        num_logical_experts=8,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation=MoEActivation.SILU,
        in_dtype=torch.bfloat16,
        device="cuda",
        routing_method=RoutingMethodType.Renormalize,
        moe_backend=moe_backend,
    )


@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(not ROCM_AITER_AVAILABLE, reason="Requires AITER enabled")
def test_w4a4_dispatches_to_aiter():
"""With AITER enabled + GFX950, W4A4 selects AITER_MXFP4_MXFP4."""
config = _make_w4a4_moe_config()
backend, experts_cls = select_mxfp4_moe_backend(
config, activation_key=kMxfp4Dynamic
)
assert backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4
assert experts_cls is not None


@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
@pytest.mark.skipif(
    ROCM_AITER_AVAILABLE,
    reason="Test requires AITER disabled (unset VLLM_ROCM_USE_AITER)",
)
def test_w4a4_raises_without_aiter_and_no_moe_backend():
    """Without AITER and no --moe-backend, raises NotImplementedError
    with hint to use --moe-backend emulation."""
    config = _make_w4a4_moe_config()
    with pytest.raises(NotImplementedError, match="--moe-backend emulation"):
        select_mxfp4_moe_backend(config, activation_key=kMxfp4Dynamic)


@pytest.mark.skipif(not ROCM_GFX950, reason="Requires GFX950 (mi355x)")
def test_w4a4_dispatches_to_emulation_with_moe_backend():
"""With --moe-backend emulation, W4A4 selects EMULATION."""
config = _make_w4a4_moe_config(moe_backend="emulation")
backend, experts_cls = select_mxfp4_moe_backend(
config, activation_key=kMxfp4Dynamic
)
assert backend == Mxfp4MoeBackend.EMULATION
assert experts_cls is not None
76 changes: 74 additions & 2 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -29,6 +29,7 @@
    QuantKey,
    kFp8Dynamic128Sym,
    kFp8StaticTensorSym,
    kMxfp4Dynamic,
    kMxfp4Static,
    kMxfp8Dynamic,
)
@@ -68,6 +69,7 @@ class Mxfp4MoeBackend(Enum):
    # Keep the legacy name as an alias while the ROCm split backend rename settles.
    AITER = "AITER_MXFP4_BF16"
    AITER_MXFP4_FP8 = "AITER_MXFP4_FP8"  # W4A8: triton kernel
    AITER_MXFP4_MXFP4 = "AITER_MXFP4_MXFP4"  # W4A4: CK kernel
    # Triton
    TRITON = "TRITON"
    TRITON_UNFUSED = "TRITON_UNFUSED"
@@ -83,6 +85,7 @@ class Mxfp4MoeBackend(Enum):
AITER_BACKENDS = (
    Mxfp4MoeBackend.AITER_MXFP4_BF16,
    Mxfp4MoeBackend.AITER_MXFP4_FP8,
    Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
)


@@ -187,6 +190,13 @@ def backend_to_kernel_cls(

        return [AiterW4A8ExpertsMonolithic]

    elif backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterMxfp4Experts,
        )

        return [AiterMxfp4Experts]

    elif backend == Mxfp4MoeBackend.XPU:
        from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4

@@ -217,6 +227,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"marlin": Mxfp4MoeBackend.MARLIN,
"aiter": Mxfp4MoeBackend.AITER_MXFP4_BF16,
"aiter_mxfp4_fp8": Mxfp4MoeBackend.AITER_MXFP4_FP8,
"aiter_mxfp4_mxfp4": Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
"xpu": Mxfp4MoeBackend.XPU,
"emulation": Mxfp4MoeBackend.EMULATION,
}
@@ -237,6 +248,7 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.AITER_MXFP4_BF16,
        Mxfp4MoeBackend.AITER_MXFP4_FP8,
        Mxfp4MoeBackend.AITER_MXFP4_MXFP4,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        # TRITON_UNFUSED has bug with MTP support
@@ -245,7 +257,6 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]:
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.XPU,
        Mxfp4MoeBackend.EMULATION,
    ]
    return _AVAILABLE_BACKENDS

@@ -281,6 +292,8 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
        return kMxfp8Dynamic
    if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8:
        return kFp8StaticTensorSym
    if backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
        return kMxfp4Dynamic
    return None  # BF16 activation


@@ -480,7 +493,12 @@ def _return_or_raise(

    if current_platform.is_cuda() or current_platform.is_rocm():
        raise NotImplementedError(
            "No MXFP4 MoE backend supports the deployment configuration."
            "No MXFP4 MoE backend supports the deployment configuration. "
            f"weight_key=kMxfp4Static, activation_key={activation_key}. "
            "Native backends require specific hardware. "
            "Set `VLLM_LOGGING_LEVEL=DEBUG` to see detailed unsupported reasons. "
            "To use the emulation backend for research/debugging, pass "
            "--moe-backend emulation."
        )

    return Mxfp4MoeBackend.NONE, None
@@ -898,6 +916,49 @@ def _interleave_mxfp4_cutlass_sm90(w):
            w2_bias,
        )

    elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
        from vllm._aiter_ops import rocm_aiter_ops

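        # Cast biases (when present) to float32 for the AITER CK kernel path.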
        if w13_bias is not None:
            w13_bias = w13_bias.data.to(torch.float32)
        if w2_bias is not None:
            w2_bias = w2_bias.data.to(torch.float32)

        # e8m0_shuffle on weight scales (GFX950 swizzle layout)
        from aiter.utility.fp4_utils import e8m0_shuffle

        s0, s1, _ = w13_weight_scale.shape
        w13_weight_scale.data = e8m0_shuffle(w13_weight_scale.view(s0 * s1, -1)).view(
            s0, s1, -1
        )

        s0, s1, _ = w2_weight_scale.shape
        w2_weight_scale.data = e8m0_shuffle(w2_weight_scale.view(s0 * s1, -1)).view(
            s0, s1, -1
        )

        # View as native FP4 dtype
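        # (Guarded with getattr: older PyTorch builds may not expose
        # torch.float4_e2m1fn_x2.)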
        fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
        if fp4_dtype is not None:
            w13_weight.data = w13_weight.data.view(fp4_dtype)
            w2_weight.data = w2_weight.data.view(fp4_dtype)

        # Shuffle weights for AITER CK kernel
        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
            w13_weight, w2_weight
        )
        shuffled_w13.is_shuffled = True
        shuffled_w2.is_shuffled = True

        return (
            shuffled_w13,
            shuffled_w2,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16:
        from vllm._aiter_ops import rocm_aiter_ops

@@ -1452,6 +1513,17 @@ def make_mxfp4_moe_quant_config(
            w2_bias=w2_bias,
            block_shape=None,
        )
    elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_MXFP4:
        return ocp_mx_moe_quant_config(
            quant_dtype="mxfp4",
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            gemm1_alpha=gemm1_alpha,
            gemm1_beta=gemm1_beta,
            gemm1_clamp_limit=swiglu_limit,
        )
    elif mxfp4_backend in (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
30 changes: 30 additions & 0 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -25,6 +25,7 @@
    kFp8Static128BlockSym,
    kFp8StaticChannelSym,
    kFp8StaticTensorSym,
    kMxfp4Dynamic,
    kMxfp4Static,
)

@@ -327,6 +328,21 @@ def expects_unquantized_inputs(self) -> bool:
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def is_supported_config(
        cls, moe_config, weight_key, activation_key, activation_format
    ):
        is_supported, reason = super().is_supported_config(
            cls, moe_config, weight_key, activation_key, activation_format
        )
        if not is_supported and not rocm_aiter_ops.is_fused_moe_enabled():
            reason = (
                f"{reason}. AITER MoE is not enabled — "
                "set VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE=1 "
                "to enable it"
            )
        return is_supported, reason

    @staticmethod
    def _supports_current_device() -> bool:
        return rocm_aiter_ops.is_fused_moe_enabled()
@@ -439,3 +455,17 @@ def apply(
            output_dtype=output.dtype,
        )
        output.copy_(result)


class AiterMxfp4Experts(AiterExperts):
"""MXFP4 W4A4 variant: MXFP4 weights + dynamic MXFP4 activations."""

@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return (weight_key, activation_key) == (
kMxfp4Static,
kMxfp4Dynamic,
)