Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 185 additions & 2 deletions tests/kernels/moe/test_ocp_mx_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
)

TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE

HOPPER_MXFP4_BF16_AVAILABLE = (
current_platform.is_cuda()
and current_platform.is_device_capability(90)
Expand All @@ -34,9 +36,15 @@
shuffle_matrix_a,
shuffle_matrix_sf_a,
trtllm_fp4_block_scale_moe,
trtllm_fp8_block_scale_moe,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

if TRTLLM_GEN_MXFP8_AVAILABLE:
from flashinfer.fused_moe.core import (
Fp8QuantizationType,
get_w2_permute_indices_with_cache,
)


@dataclass
Expand Down Expand Up @@ -160,6 +168,7 @@ def reference_moe(
beta,
limit,
act_type,
is_gated,
):
# renormalize routing
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
Expand All @@ -170,7 +179,12 @@ def reference_moe(
mlp1_weight = w13[expert_indices, ...]
mlp1_bias = bias13[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if is_gated:
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
else:
# RELU2_NO_MUL: relu(x)^2
t = torch.relu(t)
t = t * t

if act_type == "mxfp8":
t_quantized, t_scale = mxfp8_quantize(
Expand Down Expand Up @@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
beta,
limit,
act_type,
is_gated=True,
)
ref_result[start_idx:end_idx].copy_(chunk_result)

Expand Down Expand Up @@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
beta,
limit,
"bf16",
is_gated=True,
)

from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
Expand Down Expand Up @@ -890,6 +906,7 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
beta,
limit,
"mxfp8",
is_gated=True,
)

# Prepare inputs for FlashInfer CUTLASS fused MoE
Expand Down Expand Up @@ -965,3 +982,169 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):

# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("is_gated", [True], ids=["gated"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is non-gated not supported?

Copy link
Contributor Author

@danisereb danisereb Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-gated MoE is not supported yet (in flashinfer 0.6.4), working on it:
flashinfer-ai/flashinfer#2707

@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP8_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp8_block_scale_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    is_gated: bool,
):
    """Compare TRTLLM-Gen MXFP8 block-scale fused MoE against a dequantized reference.

    Builds random bf16 weights/activations, quantizes them to MXFP8
    (1 scale per 32 elements), runs ``reference_moe`` on the dequantized
    tensors, then runs ``trtllm_fp8_block_scale_moe`` on the shuffled
    quantized tensors and checks the two agree within a loose tolerance.

    Note: only ``is_gated=True`` is parametrized above; non-gated MoE is
    not yet supported by the flashinfer kernel (see flashinfer#2707).
    """
    # Fixed seed so quantization inputs (and hence tolerances) are reproducible.
    torch.manual_seed(42)
    device = "cuda:0"

    # Gated MLPs store gate+up projections fused, doubling the w13 row count.
    inter_size = intermediate_size * (2 if is_gated else 1)

    # Divide by 20 to keep values small enough that MXFP8 quantization
    # error stays within the test tolerance.
    hidden_states = (
        torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20
    )
    w13 = (
        torch.randn(
            num_experts,
            inter_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 20
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 20
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )
    # The kernel consumes bf16 routing logits; the reference uses fp32.
    router_logits_kernel = router_logits.to(torch.bfloat16)

    # Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
    # (mxfp8_quantize may return flattened scales — TODO confirm; the
    # ndim checks below reshape them into the expected 3-D layout.)
    w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False)
    w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False)
    if w13_scale.ndim == 1:
        w13_scale = w13_scale.view(
            num_experts,
            inter_size,
            hidden_size // 32,
        )
    if w2_scale.ndim == 1:
        w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32)

    # Quantize activations to MXFP8.
    hidden_states_q, hidden_states_scale = mxfp8_quantize(
        hidden_states, is_sf_swizzled_layout=False
    )
    if hidden_states_scale.ndim == 1:
        hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32)

    # Reference output using dequantized tensors + MXFP8 intermediate quantization.
    # Dequantize-then-compute mirrors what the fused kernel does numerically.
    w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32)
    w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32)
    hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to(
        torch.float32
    )
    # The kernel has no bias inputs, so the reference uses zero biases.
    bias13 = torch.zeros(
        num_experts,
        intermediate_size * (2 if is_gated else 1),
        device=device,
    )
    bias2 = torch.zeros(num_experts, hidden_size, device=device)
    ref = reference_moe(
        router_logits_kernel.to(torch.float32),
        topk,
        num_experts,
        hidden_states_ref,
        w13_ref,
        bias13,
        w2_ref,
        bias2,
        alpha=1.0,
        beta=0.0,
        limit=None,
        act_type="mxfp8",
        is_gated=is_gated,
    )

    # Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
    epilogue_tile_m = 128
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    for i in range(num_experts):
        w13_rows = intermediate_size * (2 if is_gated else 1)
        w13_interleaved = w13_q[i].clone().reshape(w13_rows, -1)
        w13_scale_interleaved = w13_scale[i].clone().reshape(w13_rows, -1)
        if is_gated:
            # Gated GEMM expects gate/up rows interleaved, not concatenated;
            # the scale rows must be permuted identically to the weight rows.
            w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved)
            w13_scale_interleaved = reorder_rows_for_gated_act_gemm(
                w13_scale_interleaved
            )
        # Shuffling operates on raw bytes, hence the uint8 round-trip views.
        gemm1_weights_shuffled.append(
            shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m)
            .contiguous()
            .view(w13_q.dtype)
        )
        gemm2_weights_shuffled.append(
            shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m)
            .contiguous()
            .view(w2_q.dtype)
        )

        # Scale factors use the dedicated SF shuffle variant.
        gemm1_scales_shuffled.append(
            shuffle_matrix_sf_a(
                w13_scale_interleaved.view(torch.uint8).reshape(w13_rows, -1),
                epilogue_tile_m,
            )
            .contiguous()
            .view(w13_scale.dtype)
        )
        gemm2_scales_shuffled.append(
            shuffle_matrix_sf_a(
                w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m
            )
            .contiguous()
            .view(w2_scale.dtype)
        )

    out = trtllm_fp8_block_scale_moe(
        routing_logits=router_logits_kernel,
        routing_bias=None,
        hidden_states=hidden_states_q,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=torch.stack(gemm1_weights_shuffled),
        gemm1_weights_scale=torch.stack(gemm1_scales_shuffled),
        gemm2_weights=torch.stack(gemm2_weights_shuffled),
        gemm2_weights_scale=torch.stack(gemm2_scales_shuffled),
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=1,  # renormalize routing
        use_shuffled_weight=True,
        weight_layout=0,  # MajorK
        fp8_quantization_type=Fp8QuantizationType.MxFp8,
    )

    # Block-scale MXFP8 kernels are approximate; require majority close.
    check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8)
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,17 +1204,26 @@ def weight_loader(
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern()
quant_method = getattr(param, "quant_method", None)

# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# weights scales.
# Input scales are always per-tensor.
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
# NOTE: ModelOpt MXFP8 MoE uses block scales in weight_scale
# tensors (quant_method=BLOCK), so those must not be treated
# as per-tensor scalars here.
is_block_weight_scale = (
"weight_scale" in weight_name
and quant_method == FusedMoeWeightScaleSupported.BLOCK.value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, I think MXFP8 should be a GROUP scale rather than BLOCK since it is (1, 32)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I followed what ModelOpt NVFP4 ModelOptNvFp4FusedMoE uses - FusedMoeWeightScaleSupported.BLOCK.
I assume NVFP4 would also have to use GROUP since it is (1, 16).

)
is_per_tensor = (
"weight_scale_2" in weight_name
if uses_weight_scale_2
else "weight_scale" in weight_name
) or "input_scale" in weight_name
is_per_tensor = is_per_tensor and not is_block_weight_scale
if is_per_tensor:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
Expand Down
44 changes: 44 additions & 0 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig

logger = init_logger(__name__)


class MxFp8MoeBackend(Enum):
    """Kernel backends available for MXFP8 fused MoE."""

    # FlashInfer's TRTLLM-Gen fused MoE kernels (currently the only option).
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"


def select_mxfp8_moe_backend(
    config: FusedMoEConfig,
) -> MxFp8MoeBackend:
    """Choose the MXFP8 MoE backend for the given fused-MoE config.

    An explicit ``config.moe_backend`` value (anything other than
    ``"auto"``) is honored if recognized; otherwise the single supported
    backend is auto-selected.

    Raises:
        NotImplementedError: if LoRA is enabled on the config.
        ValueError: if an explicitly requested backend is unknown.
    """
    if config.is_lora_enabled:
        raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")

    # Names accepted for an explicit user request.
    name_to_backend = {
        "flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
    }

    requested = config.moe_backend
    if requested != "auto":
        selected = name_to_backend.get(requested)
        if selected is None:
            raise ValueError(
                f"moe_backend='{requested}' is not supported for MXFP8 MoE. "
                f"Expected one of {list(name_to_backend.keys())}."
            )
        logger.info_once(
            "Using '%s' MxFp8 MoE backend (user-requested).",
            selected.value,
        )
        return selected

    # Auto-select: only one backend available for now.
    selected = MxFp8MoeBackend.FLASHINFER_TRTLLM
    logger.info_once("Using '%s' MxFp8 MoE backend.", selected.value)
    return selected
Loading