From 7a6e293d3fbb89d9dbc96d6498ab4792d90797da Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 29 Apr 2026 20:55:02 +0000 Subject: [PATCH 1/2] Refactor W4A8 (w_mxfp4_a_fp8) to use oracle backend selection - Add oracle backend selection for MXFP4 MOE - Add unittest cases, fix w4a8 weight re-assign - Refactor kernel selection and move out aiter kernel Co-Authored-By: Claude Opus 4.6 Signed-off-by: Bowen Bao --- tests/kernels/moe/test_ocp_mx_moe.py | 396 +++++++++++++++++- .../experts/gpt_oss_triton_kernels_moe.py | 200 +++++++-- .../layers/fused_moe/oracle/mxfp4.py | 139 +++++- .../layers/quantization/mxfp4.py | 6 +- .../layers/quantization/quark/quark_moe.py | 266 +++--------- 5 files changed, 757 insertions(+), 250 deletions(-) diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index aefc35324d86..8ed7757f6553 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -28,6 +28,25 @@ and has_flashinfer() ) +# ROCm platform and dependencies +ROCM_AVAILABLE = current_platform.is_rocm() +ROCM_TRITON_KERNELS_AVAILABLE = False +ROCM_AITER_AVAILABLE = False +ROCM_GFX950 = False + +if ROCM_AVAILABLE: + from vllm._aiter_ops import rocm_aiter_ops + from vllm.platforms.rocm import on_gfx950 + from vllm.utils.import_utils import has_triton_kernels + + ROCM_TRITON_KERNELS_AVAILABLE = has_triton_kernels() + ROCM_GFX950 = on_gfx950() + ROCM_AITER_AVAILABLE = rocm_aiter_ops.is_enabled() + + if ROCM_AITER_AVAILABLE: + from aiter.ops.triton.moe.quant_moe import upcast_from_mxfp + from aiter.ops.triton.quant import dynamic_mxfp4_quant + if TRTLLM_GEN_MXFP4_AVAILABLE: from flashinfer import ( fp4_quantize, @@ -111,6 +130,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None): # Note we add an extra bias of 1 to the linear layer + # Uses chunked layout: first half is gate, second half is up x_glu, x_linear = torch.chunk(x, 2, dim=-1) if limit is not None: x_glu = x_glu.clamp(max=limit) @@ -119,6 +139,16 @@ def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = Non return out_glu * (x_linear + beta) +def swigluoai(x, alpha: float = 1.702, limit: float = 7.0): + # OAI swiglu uses interleaved layout: gate/up alternating + # See SwigluOAIAndMul in vllm/model_executor/layers/activation.py + gate, up = x[..., ::2], x[..., 1::2] + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1) * glu + + fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6] @@ -168,8 +198,20 @@ def reference_moe( beta, limit, act_type, - is_gated, + activation: str = "swiglu", + use_interleaved_layout: bool = False, ): + """ + Reference MoE implementation for accuracy testing. + + Args: + activation: One of "swiglu", "silu", "relu2". Controls the activation + function used after the first MLP. + use_interleaved_layout: If True, uses interleaved gate/up layout + (gate=x[..., ::2], up=x[..., 1::2]) as used by SWIGLUOAI. + If False, uses chunked layout (gate, up = chunk(x, 2)) as used + by standard swiglu/silu. + """ # renormalize routing experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) expert_weights = torch.nn.functional.softmax(experts.values, dim=1) @@ -179,12 +221,21 @@ def reference_moe( mlp1_weight = w13[expert_indices, ...] mlp1_bias = bias13[expert_indices, ...] 
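+    # einsum "beck,bk->bec": b=token, e=selected expert, c=output feature,
+    # k=hidden dim (each selected expert's first MLP applied to its token).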
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias - if is_gated: - t = swiglu(t, alpha=alpha, beta=beta, limit=limit) - else: + + # Apply activation + if activation in ("swiglu", "silu"): + if use_interleaved_layout: + # SWIGLUOAI: interleaved gate/up layout + t = swigluoai(t, alpha=alpha, limit=limit) + else: + # Standard swiglu/silu: chunked layout + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + elif activation == "relu2": # RELU2_NO_MUL: relu(x)^2 t = torch.relu(t) t = t * t + else: + raise ValueError(f"Unknown activation: {activation}") if act_type == "mxfp8": t_quantized, t_scale = mxfp8_quantize( @@ -585,7 +636,8 @@ def test_trtllm_gen_mxfp4_fused_moe( beta, limit, act_type, - is_gated=True, + activation="swiglu", + use_interleaved_layout=False, ) ref_result[start_idx:end_idx].copy_(chunk_result) @@ -722,7 +774,8 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( beta, limit, "bf16", - is_gated=True, + activation="swiglu", + use_interleaved_layout=False, ) from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe @@ -908,7 +961,8 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor): beta, limit, "mxfp8", - is_gated=True, + activation="swiglu", + use_interleaved_layout=False, ) # Prepare inputs for FlashInfer CUTLASS fused MoE @@ -1080,7 +1134,8 @@ def test_trtllm_gen_mxfp8_block_scale_moe( beta=0.0, limit=None, act_type="mxfp8", - is_gated=is_gated, + activation="swiglu" if is_gated else "relu2", + use_interleaved_layout=False, ) # Shuffle weights/scales with the same indexed layout used by TRTLLM kernels. @@ -1150,3 +1205,328 @@ def test_trtllm_gen_mxfp8_block_scale_moe( # Block-scale MXFP8 kernels are approximate; require majority close. check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8) + + +# ----------------------------------------------------------------------------- +# ROCm Oracle-based kernel execution tests +# ----------------------------------------------------------------------------- +# TODO: Further tighten the accuracy threshold. +# - More accurate ref moe to include activation quantization +# - Check aiter kernel accuracy. E.g., quant / dequant details. +ROCM_BACKEND_CONFIGS = { + "TRITON": { + "activation": "SWIGLUOAI", + "rtol": 0.3, + "percent": 0.95, + "requires_aiter": False, + "requires_gfx950": False, + }, + "TRITON_UNFUSED": { + "activation": "SWIGLUOAI", + "rtol": 0.3, + "percent": 0.95, + "requires_aiter": False, + "requires_gfx950": False, + }, + "AITER_MXFP4_BF16": { + "activation": "SILU", + "rtol": 1.0, + "percent": 0.7, + "requires_aiter": True, + "requires_gfx950": True, + }, + "AITER_MXFP4_FP8": { + "activation": "SWIGLUOAI", + "rtol": 0.5, + "percent": 0.9, + "requires_aiter": True, + "requires_gfx950": True, + }, +} + + +@pytest.mark.parametrize("backend_name", list(ROCM_BACKEND_CONFIGS.keys())) +@pytest.mark.parametrize("topk", [4]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("num_tokens,hidden_size,intermediate_size", [(16, 256, 256)]) +@pytest.mark.skipif( + not ROCM_AVAILABLE, + reason="ROCm is required for this test", +) +@torch.inference_mode() +def test_rocm_mxfp4_moe_oracle( + backend_name: str, + topk: int, + num_experts: int, + num_tokens: int, + hidden_size: int, + intermediate_size: int, +): + """ + Test ROCm MXFP4 MoE using oracle functions. 
+ + This test validates that the oracle functions work end-to-end: + - select_mxfp4_moe_backend() selects a valid backend + - convert_to_mxfp4_moe_kernel_format() converts weights without error + - make_mxfp4_moe_quant_config() builds a valid quant config + - make_mxfp4_moe_kernel() creates a kernel that runs without error + - The kernel output is within accuracy tolerance of reference + """ + config = ROCM_BACKEND_CONFIGS[backend_name] + + # Check platform requirements + if not ROCM_TRITON_KERNELS_AVAILABLE: + pytest.skip("triton_kernels required for quantization") + if config["requires_aiter"] and not ROCM_AITER_AVAILABLE: + pytest.skip(f"Backend {backend_name} requires AITER") + if config["requires_gfx950"] and not ROCM_GFX950: + pytest.skip(f"Backend {backend_name} requires GFX950") + + from vllm.config import VllmConfig, set_current_vllm_config + from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + backend_to_kernel_cls, + convert_to_mxfp4_moe_kernel_format, + make_mxfp4_moe_kernel, + make_mxfp4_moe_quant_config, + ) + from vllm.v1.worker.workspace import init_workspace_manager + + # Initialize workspace manager (needed for modular kernels) + init_workspace_manager(torch.accelerator.current_device_index()) + + # Map string to enum + backend = Mxfp4MoeBackend[backend_name] + + # Get experts class from oracle + experts_cls_list = backend_to_kernel_cls(backend) + if experts_cls_list is None or len(experts_cls_list) == 0: + pytest.skip(f"Backend {backend_name} not available") + + # Use first experts class + experts_cls = experts_cls_list[0] + + torch.manual_seed(42) + dtype = torch.bfloat16 + device = "cuda:0" + + # Create MoE config with Renormalize routing (required by monolithic kernels) + from vllm.model_executor.layers.fused_moe import FusedMoEConfig + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + RoutingMethodType, + ) + + moe_config = FusedMoEConfig( + num_experts=num_experts, + experts_per_token=topk, + hidden_dim=hidden_size, + intermediate_size_per_partition=intermediate_size, + num_local_experts=num_experts, + num_logical_experts=num_experts, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + activation=MoEActivation[config["activation"]], + in_dtype=dtype, + device="cuda", + routing_method=RoutingMethodType.Renormalize, + ) + + # Create float weights in checkpoint format: + # w13: [num_experts, 2*intermediate_size, hidden_size] + # w2: [num_experts, hidden_size, intermediate_size] + w13_float = torch.randn( + num_experts, 2 * intermediate_size, hidden_size, dtype=dtype, device=device + ) + w2_float = torch.randn( + num_experts, hidden_size, intermediate_size, dtype=dtype, device=device + ) + + # dynamic_mxfp4_quant expects 2D input, so reshape 3D weights + # w13: [E, 2*I, H] -> [E*2*I, H] -> quantize -> [E, 2*I, H//2] + # w2: [E, H, I] -> [E*H, I] -> quantize -> [E, H, I//2] + w13_2d = w13_float.reshape(-1, hidden_size) + w13_quant_2d, w13_scale_2d = dynamic_mxfp4_quant(w13_2d) + w13_quant = w13_quant_2d.reshape(num_experts, 2 * intermediate_size, -1) + w13_scale = w13_scale_2d.reshape(num_experts, 2 * intermediate_size, -1) + + w2_2d = w2_float.reshape(-1, intermediate_size) + w2_quant_2d, w2_scale_2d = dynamic_mxfp4_quant(w2_2d) + w2_quant = w2_quant_2d.reshape(num_experts, hidden_size, -1) + w2_scale = w2_scale_2d.reshape(num_experts, hidden_size, -1) + + w13_bias = torch.randn( + num_experts, 2 * 
intermediate_size, dtype=dtype, device=device + ) + w2_bias = torch.randn(num_experts, hidden_size, dtype=dtype, device=device) + + # Create static input scales for W4A8 backend (AITER_MXFP4_FP8) + w13_input_scale: torch.Tensor | None = None + w2_input_scale: torch.Tensor | None = None + if backend_name == "AITER_MXFP4_FP8": + # Static FP8 scales: one scale per expert + w13_input_scale = torch.ones(num_experts, dtype=torch.float32, device=device) + w2_input_scale = torch.ones(num_experts, dtype=torch.float32, device=device) + + # Create mock layer for oracle functions + class MockLayer: + w13_weight: torch.Tensor + w2_weight: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor + w13_input_scale: torch.Tensor | None + w2_input_scale: torch.Tensor | None + + layer = MockLayer() + layer.w13_weight = w13_quant + layer.w2_weight = w2_quant + layer.w13_weight_scale = w13_scale + layer.w2_weight_scale = w2_scale + layer.w13_input_scale = w13_input_scale + layer.w2_input_scale = w2_input_scale + + # Convert weights using oracle + w13_conv, w2_conv, w13_scale_conv, w2_scale_conv, w13_bias_conv, w2_bias_conv = ( + convert_to_mxfp4_moe_kernel_format( + mxfp4_backend=backend, + layer=layer, # type: ignore[arg-type] + w13_weight=w13_quant, + w2_weight=w2_quant, + w13_weight_scale=w13_scale, + w2_weight_scale=w2_scale, + w13_bias=w13_bias, + w2_bias=w2_bias, + ) + ) + + # Build quant config using oracle + quant_config = make_mxfp4_moe_quant_config( + mxfp4_backend=backend, + w1_scale=w13_scale_conv, + w2_scale=w2_scale_conv, + w1_bias=w13_bias_conv, + w2_bias=w2_bias_conv, + a1_scale=w13_input_scale, + a2_scale=w2_input_scale, + ) + + # Select activation based on backend + activation_name = str(config["activation"]) + activation = MoEActivation[activation_name] + + # Build kernel using oracle + assert quant_config is not None, "Failed to create quant config" + with set_current_vllm_config(VllmConfig()): + kernel = make_mxfp4_moe_kernel( + moe_quant_config=quant_config, + moe_config=moe_config, + mxfp4_backend=backend, + experts_cls=experts_cls, + routing_tables=None, + shared_experts=None, + ) + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + router_logits = torch.randn( + num_tokens, num_experts, dtype=torch.float32, device=device + ) + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True) + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + + # Run kernel - use appropriate method based on impl type + if kernel.is_monolithic: + # Monolithic impl uses router_logits + out = kernel.apply_monolithic( + hidden_states=x, + w1=w13_conv, + w2=w2_conv, + router_logits=router_logits, + activation=activation, + global_num_experts=num_experts, + expert_map=None, + apply_router_weight_on_input=False, + ) + else: + # Modular impl uses topk_weights and topk_ids + out = kernel.apply( + hidden_states=x, + w1=w13_conv, + w2=w2_conv, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=num_experts, + expert_map=None, + apply_router_weight_on_input=False, + ) + + # Verify output is valid (no NaN/Inf) and has expected shape + assert out.shape == (num_tokens, hidden_size), f"Unexpected shape: {out.shape}" + assert not torch.any(torch.isnan(out)), "Output contains NaN" + assert not torch.any(torch.isinf(out)), "Output contains Inf" + + # Verify output has reasonable magnitude (not all zeros) + assert out.abs().max() > 0.01, "Output is effectively zero" + + # Dequantize 
weights for reference computation + w13_dq = upcast_from_mxfp( + w13_quant.view(torch.uint8), w13_scale, torch.bfloat16, axis=-1 + ) + w2_dq = upcast_from_mxfp( + w2_quant.view(torch.uint8), w2_scale, torch.bfloat16, axis=-1 + ) + + # Determine activation type and layout + # SWIGLUOAI uses interleaved layout (gate/up alternating) + # SILU uses chunked layout (first half gate, second half up) + use_interleaved = activation == MoEActivation.SWIGLUOAI + if activation in [MoEActivation.SWIGLUOAI, MoEActivation.SILU]: + act_name = "swiglu" + else: + act_name = "relu2" + + ref = reference_moe( + router_logits, + topk, + num_experts, + x.to(torch.float32), + w13_dq.to(torch.float32), + w13_bias.to(torch.float32), + w2_dq.to(torch.float32), + w2_bias.to(torch.float32), + alpha=1.702 if activation == MoEActivation.SWIGLUOAI else 1.0, + beta=1.0 if activation == MoEActivation.SWIGLUOAI else 0.0, + limit=7.0 if activation == MoEActivation.SWIGLUOAI else None, + act_type="bf16", + activation=act_name, + use_interleaved_layout=use_interleaved, + ) + + # Compute and print accuracy statistics + diff = (ref.float() - out.float()).abs() + rel_diff = diff / (ref.float().abs() + 1e-6) + + print(f"\n[{backend_name}] Accuracy statistics:") + print( + f" Reference: min={ref.min():.4f}, max={ref.max():.4f}, mean={ref.mean():.4f}" + ) + print( + f" Output: min={out.min():.4f}, max={out.max():.4f}, mean={out.mean():.4f}" + ) + print( + f" Abs diff: min={diff.min():.4f}, max={diff.max():.4f}, " + f"mean={diff.mean():.4f}" + ) + print( + f" Rel diff: min={rel_diff.min():.4f}, max={rel_diff.max():.4f}, " + f"mean={rel_diff.mean():.4f}" + ) + + # Check what percentage of values are within various tolerances + for rtol in [0.1, 0.5, 1.0, 2.0]: + within_tol = (diff <= rtol * out.float().abs()).float().mean() + print(f" Within rtol={rtol}: {within_tol * 100:.1f}%") + + # Check accuracy using per-backend thresholds + check_accuracy(ref, out, atol=0.1, rtol=config["rtol"], percent=config["percent"]) diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index ac317ac7762c..85fda18b7e9e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kFp8StaticTensorSym, kMxfp4Static, ) from vllm.platforms import current_platform @@ -269,7 +270,7 @@ def pack_bitmatrix( tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) -def triton_kernel_moe_forward( +def aiter_triton_kernel_w4a8_moe_forward( hidden_states: torch.Tensor, w1, # Tensor or triton_kernels.Tensor w2, # Tensor or triton_kernels.Tensor @@ -285,36 +286,54 @@ def triton_kernel_moe_forward( unpadded_K_w1=None, unpadded_N_w2=None, unpadded_K_w2=None, -) -> torch.Tensor: - if ( +): + assert ( quant_config is not None and quant_config.use_mxfp4_w4a8 and rocm_aiter_ops.is_enabled() - ): - from aiter.ops.triton.moe_routing.routing import routing as aiter_routing + ) + from aiter.ops.triton.moe_routing.routing import routing as aiter_routing - routing_data, gather_idx, scatter_idx = aiter_routing( - gating_output, topk, sm_first=not renormalize - ) - return triton_kernel_fused_mxfp4_w4a8_experts( - None, - hidden_states, - w1, - w2, - routing_data, - gather_idx, - scatter_idx, - 
activation=activation.value, - quant_config=quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - unpadded_N_w1=unpadded_N_w1, - unpadded_K_w1=unpadded_K_w1, - unpadded_N_w2=unpadded_N_w2, - unpadded_K_w2=unpadded_K_w2, - ) + routing_data, gather_idx, scatter_idx = aiter_routing( + gating_output, topk, sm_first=not renormalize + ) + return triton_kernel_fused_mxfp4_w4a8_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation.value, + quant_config=quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + unpadded_N_w1=unpadded_N_w1, + unpadded_K_w1=unpadded_K_w1, + unpadded_N_w2=unpadded_N_w2, + unpadded_K_w2=unpadded_K_w2, + ) + +def triton_kernel_moe_forward( + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: MoEActivation = MoEActivation.SWIGLUOAI, + quant_config: FusedMoEQuantConfig | None = None, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +) -> torch.Tensor: from triton_kernels.topk import topk as topk_fn sm_first = not renormalize @@ -1153,3 +1172,132 @@ def apply( quant_config=self.quant_config, apply_router_weight_on_input=apply_router_weight_on_input, ) + + +class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): + """ + Monolithic MXFP4 W4A8 expert using AITER triton kernels. + + This backend uses: + - aiter.ops.triton.moe_routing.routing for routing + - aiter.ops.triton.moe_op_gemm_a8w4.moe_gemm_a8w4 for computation + + Weight format: MXFP4 weights with GFX950 swizzle + Activation: Static FP8 quantization + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + self.topk = moe_config.experts_per_token + self.renormalize = moe_config.routing_method in ( + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + # Requires AITER and GFX950 + if not rocm_aiter_ops.is_enabled(): + return False + from vllm.platforms.rocm import on_gfx950 + + return on_gfx950() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # W4A8: MXFP4 weights with static FP8 activations + SUPPORTED_W_A = [ + (kMxfp4Static, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + # Only SILU activation (swiglu) is supported + return activation == MoEActivation.SWIGLUOAI + + @staticmethod + def _supports_parallel_config( + moe_parallel_config: FusedMoEParallelConfig, + ) -> bool: + return ( + not moe_parallel_config.use_all2all_kernels + and not moe_parallel_config.enable_eplb + and moe_parallel_config.dp_size <= 1 + ) + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + 
activation_key: QuantKey | None, + ) -> bool: + return routing_method in [ + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + return True + + def supports_expert_map(self) -> bool: + return False # Expert parallelism not yet supported + + @property + def expects_unquantized_inputs(self) -> bool: + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + assert self.moe_config.intermediate_size_per_partition_unpadded is not None + assert self.moe_config.hidden_dim_unpadded is not None + return aiter_triton_kernel_w4a8_moe_forward( + hidden_states=hidden_states, + w1=w1, + w2=w2, + gating_output=router_logits, + topk=self.topk, + renormalize=self.renormalize, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + unpadded_N_w1=self.moe_config.intermediate_size_per_partition_unpadded * 2, + unpadded_K_w1=self.moe_config.hidden_dim_unpadded, + unpadded_N_w2=self.moe_config.hidden_dim_unpadded, + unpadded_K_w2=self.moe_config.intermediate_size_per_partition_unpadded, + ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..ac1d8d368817 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -19,6 +19,7 @@ FusedMoEQuantConfig, FusedMoEQuantDesc, mxfp4_mxfp8_moe_quant_config, + mxfp4_w4a8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) @@ -26,9 +27,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8Dynamic128Sym, + kFp8StaticTensorSym, kMxfp4Static, kMxfp8Dynamic, ) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import all_close_1d from vllm.platforms import current_platform from vllm.utils.import_utils import has_triton_kernels from vllm.utils.math_utils import round_up @@ -59,8 +62,9 @@ class Mxfp4MoeBackend(Enum): # Marlin BATCHED_MARLIN = "BATCHED_MARLIN" MARLIN = "MARLIN" - # ROCm AITER - AITER = "AITER" + # ROCm AITER backends + AITER_MXFP4_BF16 = "AITER_MXFP4_BF16" # W4A16: CK kernel + AITER_MXFP4_FP8 = "AITER_MXFP4_FP8" # W4A8: triton kernel # Triton TRITON = "TRITON" TRITON_UNFUSED = "TRITON_UNFUSED" @@ -70,6 +74,13 @@ class Mxfp4MoeBackend(Enum): EMULATION = "EMULATION" +# AITER backends group +AITER_BACKENDS = ( + Mxfp4MoeBackend.AITER_MXFP4_BF16, + Mxfp4MoeBackend.AITER_MXFP4_FP8, +) + + # Backends that share the same TRTLLM weight format TRTLLM_BACKENDS = ( Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, @@ -144,13 +155,20 @@ def backend_to_kernel_cls( return [BatchedMarlinExperts] - elif backend == Mxfp4MoeBackend.AITER: + elif backend == Mxfp4MoeBackend.AITER_MXFP4_BF16: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( AiterExperts, ) return [AiterExperts] + elif 
backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: + from vllm.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe import ( # noqa: E501 + AiterW4A8ExpertsMonolithic, + ) + + return [AiterW4A8ExpertsMonolithic] + elif backend == Mxfp4MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.experts.xpu_moe import XPUExpertsMXFp4 @@ -178,7 +196,8 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: "triton": Mxfp4MoeBackend.TRITON, "triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED, "marlin": Mxfp4MoeBackend.MARLIN, - "aiter": Mxfp4MoeBackend.AITER, + "aiter": Mxfp4MoeBackend.AITER_MXFP4_BF16, # W4A16 + "aiter_mxfp4_fp8": Mxfp4MoeBackend.AITER_MXFP4_FP8, # W4A8 "xpu": Mxfp4MoeBackend.XPU, "emulation": Mxfp4MoeBackend.EMULATION, } @@ -197,7 +216,8 @@ def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: """ _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, - Mxfp4MoeBackend.AITER, + Mxfp4MoeBackend.AITER_MXFP4_BF16, + Mxfp4MoeBackend.AITER_MXFP4_FP8, Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, # TRITON_UNFUSED has bug with MTP support @@ -238,16 +258,28 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): return kMxfp8Dynamic - return None + if backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: + return kFp8StaticTensorSym + return None # BF16 activation -def select_gpt_oss_mxfp4_moe_backend( +def select_mxfp4_moe_backend( config: FusedMoEConfig, + activation_key: QuantKey | None = None, ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: """ Select the primary MXFP4 MoE backend. + + Args: + config: MoE configuration + activation_key: Optional activation quantization key. If provided, + overrides the default activation key for backend selection. + Use kFp8StaticTensorSym for W4A8 scheme. + Note: Shape-specific fallbacks may still occur at runtime. 
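+
+    Example (the W4A8 call site in QuarkOCP_MX_MoEMethod, shown here as a
+    usage sketch):
+
+        backend, experts_cls = select_mxfp4_moe_backend(
+            moe, activation_key=kFp8StaticTensorSym
+        )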
""" + # If activation_key is explicitly provided (e.g., W4A8), use it + requested_activation_key = activation_key device_capability = current_platform.get_device_capability() triton_kernels_supported = ( has_triton_kernels() @@ -316,11 +348,17 @@ def _return_or_raise( and requested_backend == Mxfp4MoeBackend.MARLIN ): requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN + # Use requested_activation_key if provided, otherwise use backend default + act_key = ( + requested_activation_key + if requested_activation_key is not None + else _backend_activation_key(requested_backend) + ) return _return_or_raise( requested_backend, config, kMxfp4Static, - _backend_activation_key(requested_backend), + act_key, activation_format, ) @@ -392,10 +430,15 @@ def _return_or_raise( ) for backend in AVAILABLE_BACKENDS: - activation_key = _backend_activation_key(backend) + # Use requested_activation_key if provided, otherwise use backend default + act_key = ( + requested_activation_key + if requested_activation_key is not None + else _backend_activation_key(backend) + ) for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( - k_cls, config, kMxfp4Static, activation_key, activation_format + k_cls, config, kMxfp4Static, act_key, activation_format ) if supported: logger.info_once(_make_log_backend(backend)) @@ -422,7 +465,7 @@ def _return_or_raise( return Mxfp4MoeBackend.NONE, None -def select_mxfp4_moe_backend( +def select_deepseek_v4_mxfp4_moe_backend( config: FusedMoEConfig, ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: """ @@ -806,7 +849,7 @@ def _interleave_mxfp4_cutlass_sm90(w): w2_bias, ) - elif mxfp4_backend == Mxfp4MoeBackend.AITER: + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_BF16: from vllm._aiter_ops import rocm_aiter_ops if w13_bias is not None: @@ -868,6 +911,63 @@ def _interleave_mxfp4_cutlass_sm90(w): w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: + # W4A8: MXFP4 weights + static FP8 activations (triton kernel) + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + from triton_kernels.numerics import InFlexData + + if w13_bias is not None: + w13_bias = w13_bias.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.to(torch.float32) + + # Process static FP8 input scales (reduce to scalar, warn if not uniform) + w13_input_scale = layer.w13_input_scale + w2_input_scale = layer.w2_input_scale + if w13_input_scale is None or w2_input_scale is None: + raise ValueError( + "W4A8 (AITER_MXFP4_FP8) requires static input scales, but found " + "w13_input_scale or w2_input_scale is None." + ) + if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." 
+ ) + w13_input_scale = w13_input_scale.max().to(torch.float32) + w2_input_scale = w2_input_scale.max().to(torch.float32) + + # Swizzle weights for GFX950 + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(w13_weight, w13_weight_scale) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(w2_weight, w2_weight_scale) + + # Create InFlexData for activation scales + lhs_data13 = InFlexData(scale=w13_input_scale) + lhs_data2 = InFlexData(scale=w2_input_scale) + + # Create PrecisionConfig with both weight and activation info + w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, + flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13), + ) + w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, + flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2), + ) + + del layer.w13_weight + del layer.w2_weight + + return ( + w13_weight, + w2_weight, + w13_precision_config, + w2_precision_config, + w13_bias, + w2_bias, + ) + elif mxfp4_backend in TRITON_BACKENDS: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -1175,6 +1275,8 @@ def make_mxfp4_moe_quant_config( swiglu_limit: float | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, ) -> FusedMoEQuantConfig | None: """Create a FusedMoEQuantConfig for the given MXFP4 backend.""" if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: @@ -1208,6 +1310,17 @@ def make_mxfp4_moe_quant_config( gemm1_beta=gemm1_beta, gemm1_clamp_limit=swiglu_limit, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: + # W4A8: MXFP4 weights + static FP8 activations + return mxfp4_w4a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + block_shape=None, + ) elif mxfp4_backend in ( Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, @@ -1215,7 +1328,7 @@ def make_mxfp4_moe_quant_config( Mxfp4MoeBackend.TRITON_UNFUSED, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.AITER, + Mxfp4MoeBackend.AITER_MXFP4_BF16, ): return mxfp4_w4a16_moe_quant_config( w1_bias=w1_bias, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 0a516831c4ec..530aac8de0e1 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -24,7 +24,7 @@ make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, - select_gpt_oss_mxfp4_moe_backend, + select_deepseek_v4_mxfp4_moe_backend, select_mxfp4_moe_backend, ) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod @@ -140,7 +140,7 @@ class GptOssMxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.weight_dtype = "gpt_oss_mxfp4" - self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe) + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size @@ -466,7 +466,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.weight_dtype = "mxfp4" - self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) + self.mxfp4_backend, self.experts_cls = select_deepseek_v4_mxfp4_moe_backend(moe) self.max_capture_size = ( 
get_current_vllm_config().compilation_config.max_cudagraph_capture_size diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 1eeca142343b..a14bfbc9c19b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -35,19 +35,19 @@ make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, - select_gpt_oss_mxfp4_moe_backend, + select_mxfp4_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - _swizzle_mxfp4, -) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8StaticTensorSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, @@ -62,7 +62,6 @@ __all__ = [ "QuarkMoEMethod", "QuarkOCP_MX_MoEMethod", - "QuarkOCP_MX_MoEMethod_OSS", ] @@ -94,22 +93,9 @@ def get_moe_method( elif quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config): - emulate = not current_platform.supports_mx() or not ( - rocm_aiter_ops.is_fused_moe_enabled() - ) - if ( - input_config is not None - and input_config.get("dtype") == "fp8_e4m3" - and not input_config.get("is_dynamic") - and not emulate - ): - return QuarkOCP_MX_MoEMethod_OSS( - weight_config, input_config, module.moe_config - ) - else: - return QuarkOCP_MX_MoEMethod( - weight_config, input_config, module.moe_config - ) + # All OCP MX schemes (W4A16, W4A8, etc.) 
handled by QuarkOCP_MX_MoEMethod + # Backend selection happens inside via oracle + return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_static_tensor_w8a8( weight_config, input_config ) or quant_config._is_dynamic_per_token_w8a8(weight_config, input_config): @@ -993,7 +979,7 @@ def __init__( self.experts_cls: type[mk.FusedMoEExperts] | None = None self.moe_kernel: mk.FusedMoEKernel | None = None - # Used for triton kernel precision configs + # Used for triton kernel precision configs (W4A8, TRITON backends) self.w13_precision_config = None self.w2_precision_config = None @@ -1002,6 +988,17 @@ def __init__( else: self.static_input_scales = False + # Select backend based on OCP MX scheme + if self.ocp_mx_scheme == "w_mxfp4": + # W4A16: weight-only MXFP4 + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) + elif self.ocp_mx_scheme == "w_mxfp4_a_fp8" and self.static_input_scales: + # W4A8: MXFP4 weights + static FP8 activations + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend( + moe, activation_key=kFp8StaticTensorSym + ) + + # Validation for unsupported schemes if any( self.ocp_mx_scheme.endswith(a_scheme) for a_scheme in ["a_mxfp4", "a_mxfp6_e3m2", "a_mxfp6_e2m3"] @@ -1026,7 +1023,7 @@ def __init__( ) # TODO: Remove once all OCP MX schemes use the kernel abstraction - _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4") + _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8") self.emulate = ( not current_platform.supports_mx() or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES @@ -1034,9 +1031,6 @@ def __init__( self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe ) - if self.ocp_mx_scheme == "w_mxfp4": - self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe) - if self.emulate: # We use the same code path between MXFP4/MXFP6 emulation. self.mxfp4_backend = Mxfp4MoeBackend.EMULATION @@ -1046,7 +1040,12 @@ def __init__( if self.mxfp4_backend != Mxfp4MoeBackend.NONE: self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0] - if self.emulate: + # Log backend selection + if self.mxfp4_backend != Mxfp4MoeBackend.NONE: + logger.info_once( + f"Using {self.mxfp4_backend.value} backend for {self.ocp_mx_scheme}" + ) + elif self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, " @@ -1056,10 +1055,6 @@ def __init__( "QDQ (quantize and dequantize) will be used, with the linear " "layers computed in high precision." ) - else: - logger.warning_once( - "The current mode supports native MoE MXFP4 computation" - ) def maybe_roundup_sizes( self, @@ -1204,6 +1199,11 @@ def create_weights( layer.w2_input_scale = None def process_weights_after_loading(self, layer): + # For MXFP4 schemes with native backend, use oracle + if self.mxfp4_backend != Mxfp4MoeBackend.NONE: + self._setup_kernel(layer) + return + if self.static_input_scales and self.input_dtype == "fp8": # firstly, process activations if fp8 static input if layer.w13_input_scale is None or layer.w2_input_scale is None: @@ -1252,14 +1252,6 @@ def process_weights_after_loading(self, layer): w2_input_scale, requires_grad=False ) - # For w_mxfp4, use oracle functions - if self.emulate or ( - self.ocp_mx_scheme == "w_mxfp4" - and self.mxfp4_backend != Mxfp4MoeBackend.NONE - ): - self._setup_kernel_via_oracle(layer) - return - # TODO(bowenbao): gradually migrate to oracles. 
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes from aiter.utility.fp4_utils import e8m0_shuffle @@ -1298,46 +1290,48 @@ def process_weights_after_loading(self, layer): self.moe_quant_config = self.get_fused_moe_quant_config(layer) torch.accelerator.empty_cache() - def _setup_kernel_via_oracle(self, layer: FusedMoE): - """Setup kernel using oracle functions for w_mxfp4 scheme.""" - w13 = layer.w13_weight - w2 = layer.w2_weight - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale + def _setup_kernel(self, layer: FusedMoE): + """Setup kernel using oracle functions for MXFP4 schemes (W4A16, W4A8).""" w13_bias = getattr(layer, "w13_bias", None) w2_bias = getattr(layer, "w2_bias", None) - # Convert weights to kernel format + # Convert weights to kernel format (handles all backend-specific logic) w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( mxfp4_backend=self.mxfp4_backend, layer=layer, - w13_weight=w13, - w2_weight=w2, - w13_weight_scale=w13_scale, - w2_weight_scale=w2_scale, + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + w13_weight_scale=layer.w13_weight_scale, + w2_weight_scale=layer.w2_weight_scale, w13_bias=w13_bias, w2_bias=w2_bias, ) ) - # For TRITON backends, weights are wrapped tensors from triton_kernels - # that don't support .detach(). Manually assign parameters. - if self.mxfp4_backend not in TRITON_BACKENDS: - replace_parameter(layer, "w13_weight", w13) - replace_parameter(layer, "w2_weight", w2) - replace_parameter(layer, "w13_weight_scale", w13_scale) - replace_parameter(layer, "w2_weight_scale", w2_scale) - else: + # Handle weight/scale assignment based on backend type + if self.mxfp4_backend in TRITON_BACKENDS or self.mxfp4_backend in ( + Mxfp4MoeBackend.AITER_MXFP4_FP8, + ): + # Triton-based backends: w13/w2 are triton_kernels.tensor.Tensor + # Store on layer for apply(), scales are PrecisionConfig layer.w13_weight = w13 layer.w2_weight = w2 self.w13_precision_config = w13_scale self.w2_precision_config = w2_scale + else: + # Standard backends: replace parameters + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) if w13_bias is not None and w2_bias is not None: replace_parameter(layer, "w13_bias", w13_bias) replace_parameter(layer, "w2_bias", w2_bias) + torch.accelerator.empty_cache() + # Build quant config and kernel self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None and self.experts_cls is not None: @@ -1353,22 +1347,26 @@ def _setup_kernel_via_oracle(self, layer: FusedMoE): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - # For w_mxfp4 with oracle backend, use oracle function - if self.ocp_mx_scheme == "w_mxfp4" and self.mxfp4_backend not in ( - Mxfp4MoeBackend.NONE, - Mxfp4MoeBackend.EMULATION, - ): - w1_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - if self.mxfp4_backend in TRITON_BACKENDS: + # For oracle-based backends (W4A16, W4A8), use make_mxfp4_moe_quant_config + if self.mxfp4_backend not in (Mxfp4MoeBackend.NONE, Mxfp4MoeBackend.EMULATION): + # Determine scale source based on backend type + if self.mxfp4_backend in TRITON_BACKENDS or self.mxfp4_backend in ( + Mxfp4MoeBackend.AITER_MXFP4_FP8, + ): w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config + else: + w1_scale = 
layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + return make_mxfp4_moe_quant_config( mxfp4_backend=self.mxfp4_backend, w1_scale=w1_scale, w2_scale=w2_scale, w1_bias=getattr(layer, "w13_bias", None), w2_bias=getattr(layer, "w2_bias", None), + a1_scale=getattr(layer, "w13_input_scale", None), + a2_scale=getattr(layer, "w2_input_scale", None), ) # Emulation and other schemes @@ -1421,7 +1419,7 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: - # For oracle kernel or emulation kernel + # For oracle-based kernels (W4A16, W4A8) or emulation kernel if self.moe_kernel is not None: return self.moe_kernel.apply( hidden_states=x, @@ -1473,135 +1471,3 @@ def apply_monolithic( expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - - -class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): - def __init__( - self, - weight_config: dict[str, Any], - input_config: dict[str, Any], - moe: FusedMoEConfig, - ): - super().__init__(weight_config, input_config, moe) - - def process_weights_after_loading(self, layer): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - - w13_bias = layer.w13_bias.to(torch.float32) - w2_bias = layer.w2_bias.to(torch.float32) - - layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False) - layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False) - - # FIXME warp need to be adjusted based on batch size - # only apply to batched mode - if self.moe.use_ep: - num_warps = 4 if self.moe.max_num_tokens <= 512 else 8 - else: - num_warps = 8 - - w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( - layer.w13_weight, layer.w13_weight_scale, num_warps - ) - w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( - layer.w2_weight, layer.w2_weight_scale, num_warps - ) - - self.w13_weight_triton_tensor = w13_weight - self.w2_weight_triton_tensor = w2_weight - - # need to delete the original weights to save memory on single GPU - del layer.w13_weight - del layer.w2_weight - layer.w13_weight = None - layer.w2_weight = None - torch.accelerator.empty_cache() - - if self.static_input_scales: - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.w13_input_scale) or not all_close_1d( - layer.w2_input_scale - ): - logger.warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer." 
- ) - - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max().to(torch.float32), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max().to(torch.float32), requires_grad=False - ) - - from triton_kernels.numerics import InFlexData - - lhs_data13 = InFlexData(scale=layer.w13_input_scale) - lhs_data2 = InFlexData(scale=layer.w2_input_scale) - - self.w13_precision_config = PrecisionConfig( - weight_scale=w13_scale, - flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13), - ) - - self.w2_precision_config = PrecisionConfig( - weight_scale=w2_scale, - flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2), - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return mxfp4_w4a8_moe_quant_config( - w1_scale=self.w13_precision_config, - w2_scale=self.w2_precision_config, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - block_shape=None, - ) - - @property - def is_monolithic(self) -> bool: - return True - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - input_ids: torch.Tensor | None = None, - ) -> torch.Tensor: - if layer.enable_eplb: - raise NotImplementedError( - f"EPLB not supported for {self.__class__.__name__} yet." - ) - - from vllm.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe import ( # noqa: E501 - triton_kernel_moe_forward, - ) - - assert self.moe.hidden_dim_unpadded is not None - assert self.moe.intermediate_size_per_partition_unpadded is not None - return triton_kernel_moe_forward( - hidden_states=x, - w1=self.w13_weight_triton_tensor, - w2=self.w2_weight_triton_tensor, - gating_output=router_logits, - topk=layer.top_k, - renormalize=layer.renormalize, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - unpadded_N_w1=self.moe.intermediate_size_per_partition_unpadded * 2, - unpadded_K_w1=self.moe.hidden_dim_unpadded, - unpadded_N_w2=self.moe.hidden_dim_unpadded, - unpadded_K_w2=self.moe.intermediate_size_per_partition_unpadded, - ) From a8691a6f9f9abbc84de0065d7ee77c5b240922ec Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Thu, 30 Apr 2026 19:09:54 +0000 Subject: [PATCH 2/2] move aiter moe to dedicated file Signed-off-by: Bowen Bao --- .../fused_moe/experts/aiter_mxfp4_w4a8_moe.py | 292 ++++++++++++++++++ .../experts/gpt_oss_triton_kernels_moe.py | 271 ---------------- .../layers/fused_moe/oracle/mxfp4.py | 6 +- 3 files changed, 295 insertions(+), 274 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py diff --git a/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py b/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py new file mode 100644 index 000000000000..3906a7e057ca --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/aiter_mxfp4_w4a8_moe.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + 
RoutingMethodType, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, + kMxfp4Static, +) + +__all__ = [ + "AiterW4A8ExpertsMonolithic", + "aiter_triton_kernel_w4a8_moe_forward", +] + + +def aiter_triton_kernel_w4a8_moe_forward( + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: MoEActivation = MoEActivation.SWIGLUOAI, + quant_config: FusedMoEQuantConfig | None = None, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +): + assert ( + quant_config is not None + and quant_config.use_mxfp4_w4a8 + and rocm_aiter_ops.is_enabled() + ) + from aiter.ops.triton.moe_routing.routing import routing as aiter_routing + + routing_data, gather_idx, scatter_idx = aiter_routing( + gating_output, topk, sm_first=not renormalize + ) + return triton_kernel_fused_mxfp4_w4a8_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation.value, + quant_config=quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + unpadded_N_w1=unpadded_N_w1, + unpadded_K_w1=unpadded_K_w1, + unpadded_N_w2=unpadded_N_w2, + unpadded_K_w2=unpadded_K_w2, + ) + + +def triton_kernel_fused_mxfp4_w4a8_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + routing_data, # RoutingData + gather_indx, # GatherIndx + scatter_indx, # ScatterIndx + activation: str = "silu", + quant_config: FusedMoEQuantConfig | None = None, + swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + a1q_scale: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +) -> torch.Tensor: + assert quant_config is not None + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 + assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 + + # Shape check: weights are padded (e.g. hidden_size padded for + # GFX950 swizzle). 
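+    # (Layout note, inferred from the asserts below: the swizzled
+    # triton_kernels weights are laid out [E, K, N], so hidden_size sits at
+    # w1.shape[-2] and w2.shape[-1] must equal w1.shape[1].)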
+ assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + E, _, N = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + gammas = routing_data.gate_scal if routing_data else None + + from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4 + from aiter.ops.triton.quant_moe import downcast_to_static_fp8 + + assert quant_config.w1_precision is not None, ( + "w1_precision in quant config can't be None" + ) + assert quant_config.w2_precision is not None, ( + "w2_precision in quant config can't be None" + ) + + hidden_states = downcast_to_static_fp8( + hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale + ) + + intermediate_cache1 = moe_gemm_a8w4( + hidden_states, + w1.storage.data, + None, + quant_config.w1_precision.weight_scale.storage.data, + quant_config.w1_precision.flex_ctx.lhs_data.scale, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + quant_config.w1_bias, + routing_data, + gather_indx=gather_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + out_dtype=torch.float8_e4m3fn, + apply_swiglu=True, + alpha=swiglu_alpha, + limit=swiglu_limit, + unpadded_N=unpadded_N_w1, + unpadded_K=unpadded_K_w1, + ) + + intermediate_cache3 = moe_gemm_a8w4( + intermediate_cache1, + w2.storage.data, + None, + quant_config.w2_precision.weight_scale.storage.data, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + None, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE", + unpadded_N=unpadded_N_w2, + unpadded_K=unpadded_K_w2, + ) + + return intermediate_cache3 + + +class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): + """ + Monolithic MXFP4 W4A8 expert using AITER triton kernels. 
+
+    This backend uses:
+    - aiter.ops.triton.moe_routing.routing for routing
+    - aiter.ops.triton.moe_op_gemm_a8w4.moe_gemm_a8w4 for computation
+
+    Weight format: MXFP4 weights with GFX950 swizzle
+    Activation: Static FP8 quantization
+    """
+
+    def __init__(
+        self,
+        moe_config: FusedMoEConfig,
+        quant_config: FusedMoEQuantConfig,
+    ):
+        super().__init__(moe_config, quant_config)
+        self.topk = moe_config.experts_per_token
+        self.renormalize = moe_config.routing_method in (
+            RoutingMethodType.Renormalize,
+            RoutingMethodType.RenormalizeNaive,
+        )
+
+    @staticmethod
+    def activation_format() -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard
+
+    @staticmethod
+    def _supports_current_device() -> bool:
+        # Requires AITER and GFX950
+        if not rocm_aiter_ops.is_enabled():
+            return False
+        from vllm.platforms.rocm import on_gfx950
+
+        return on_gfx950()
+
+    @staticmethod
+    def _supports_no_act_and_mul() -> bool:
+        return False
+
+    @staticmethod
+    def _supports_quant_scheme(
+        weight_key: QuantKey | None,
+        activation_key: QuantKey | None,
+    ) -> bool:
+        # W4A8: MXFP4 weights with static FP8 activations
+        SUPPORTED_W_A = [
+            (kMxfp4Static, kFp8StaticTensorSym),
+        ]
+        return (weight_key, activation_key) in SUPPORTED_W_A
+
+    @staticmethod
+    def _supports_activation(activation: MoEActivation) -> bool:
+        # Only the SWIGLUOAI activation (gpt-oss-style swiglu) is supported
+        return activation == MoEActivation.SWIGLUOAI
+
+    @staticmethod
+    def _supports_parallel_config(
+        moe_parallel_config: FusedMoEParallelConfig,
+    ) -> bool:
+        return (
+            not moe_parallel_config.use_all2all_kernels
+            and not moe_parallel_config.enable_eplb
+            and moe_parallel_config.dp_size <= 1
+        )
+
+    @staticmethod
+    def _supports_routing_method(
+        routing_method: RoutingMethodType,
+        weight_key: QuantKey | None,
+        activation_key: QuantKey | None,
+    ) -> bool:
+        return routing_method in [
+            RoutingMethodType.Renormalize,
+            RoutingMethodType.RenormalizeNaive,
+        ]
+
+    @staticmethod
+    def _supports_router_logits_dtype(
+        router_logits_dtype: torch.dtype | None,
+        routing_method: RoutingMethodType,
+    ) -> bool:
+        return True
+
+    def supports_expert_map(self) -> bool:
+        return False  # Expert parallelism not yet supported
+
+    @property
+    def expects_unquantized_inputs(self) -> bool:
+        return True
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        activation: MoEActivation,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        apply_router_weight_on_input: bool,
+        # grouped topk + fused topk bias parameters
+        num_expert_group: int | None = None,
+        e_score_correction_bias: torch.Tensor | None = None,
+        routed_scaling_factor: float | None = None,
+        topk_group: int | None = None,
+    ) -> torch.Tensor:
+        assert self.moe_config.intermediate_size_per_partition_unpadded is not None
+        assert self.moe_config.hidden_dim_unpadded is not None
+        return aiter_triton_kernel_w4a8_moe_forward(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            gating_output=router_logits,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            quant_config=self.quant_config,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            unpadded_N_w1=self.moe_config.intermediate_size_per_partition_unpadded * 2,
+            unpadded_K_w1=self.moe_config.hidden_dim_unpadded,
+            unpadded_N_w2=self.moe_config.hidden_dim_unpadded,
+            unpadded_K_w2=self.moe_config.intermediate_size_per_partition_unpadded,
+        )
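Reviewer note: a condensed sketch of how AiterW4A8ExpertsMonolithic is wired
up through the oracle helpers, mirroring test_rocm_mxfp4_moe_oracle from the
first patch. The `layer`, `moe_config`, `x`, `router_logits`, and
`num_experts` objects are set up exactly as in that test and are illustrative
only:

    from vllm.model_executor.layers.fused_moe.activation import MoEActivation
    from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
        Mxfp4MoeBackend,
        backend_to_kernel_cls,
        convert_to_mxfp4_moe_kernel_format,
        make_mxfp4_moe_kernel,
        make_mxfp4_moe_quant_config,
    )

    backend = Mxfp4MoeBackend.AITER_MXFP4_FP8
    experts_cls = backend_to_kernel_cls(backend)[0]  # AiterW4A8ExpertsMonolithic

    # Swizzle MXFP4 weights for GFX950 and fold the static FP8 input scales
    # into PrecisionConfigs.
    w13, w2, w13_pc, w2_pc, w13_bias, w2_bias = convert_to_mxfp4_moe_kernel_format(
        mxfp4_backend=backend,
        layer=layer,
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        w13_weight_scale=layer.w13_weight_scale,
        w2_weight_scale=layer.w2_weight_scale,
        w13_bias=w13_bias,
        w2_bias=w2_bias,
    )
    quant_config = make_mxfp4_moe_quant_config(
        mxfp4_backend=backend,
        w1_scale=w13_pc,
        w2_scale=w2_pc,
        w1_bias=w13_bias,
        w2_bias=w2_bias,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
    kernel = make_mxfp4_moe_kernel(
        moe_quant_config=quant_config,
        moe_config=moe_config,
        mxfp4_backend=backend,
        experts_cls=experts_cls,
        routing_tables=None,
        shared_experts=None,
    )
    # Monolithic kernels take raw router logits and do routing internally.
    out = kernel.apply_monolithic(
        hidden_states=x,
        w1=w13,
        w2=w2,
        router_logits=router_logits,
        activation=MoEActivation.SWIGLUOAI,
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )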
diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
index 85fda18b7e9e..e10514debd08 100644
--- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
@@ -5,7 +5,6 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
-from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
@@ -22,7 +21,6 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
-    kFp8StaticTensorSym,
     kMxfp4Static,
 )
 from vllm.platforms import current_platform
@@ -270,53 +268,6 @@ def pack_bitmatrix(
     tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)


-def aiter_triton_kernel_w4a8_moe_forward(
-    hidden_states: torch.Tensor,
-    w1,  # Tensor or triton_kernels.Tensor
-    w2,  # Tensor or triton_kernels.Tensor
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    activation: MoEActivation = MoEActivation.SWIGLUOAI,
-    quant_config: FusedMoEQuantConfig | None = None,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    expert_map: torch.Tensor | None = None,
-    unpadded_N_w1=None,
-    unpadded_K_w1=None,
-    unpadded_N_w2=None,
-    unpadded_K_w2=None,
-):
-    assert (
-        quant_config is not None
-        and quant_config.use_mxfp4_w4a8
-        and rocm_aiter_ops.is_enabled()
-    )
-    from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
-
-    routing_data, gather_idx, scatter_idx = aiter_routing(
-        gating_output, topk, sm_first=not renormalize
-    )
-    return triton_kernel_fused_mxfp4_w4a8_experts(
-        None,
-        hidden_states,
-        w1,
-        w2,
-        routing_data,
-        gather_idx,
-        scatter_idx,
-        activation=activation.value,
-        quant_config=quant_config,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
-        unpadded_N_w1=unpadded_N_w1,
-        unpadded_K_w1=unpadded_K_w1,
-        unpadded_N_w2=unpadded_N_w2,
-        unpadded_K_w2=unpadded_K_w2,
-    )
-
-
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
@@ -490,99 +441,6 @@ def triton_kernel_fused_experts(
     return output_tensor


-# This is a triton implementation of the fused_experts function
-def triton_kernel_fused_mxfp4_w4a8_experts(
-    output_tensor: torch.Tensor,
-    hidden_states: torch.Tensor,
-    w1,  # Tensor or triton_kernels.Tensor
-    w2,  # Tensor or triton_kernels.Tensor
-    routing_data,  # RoutingData
-    gather_indx,  # GatherIndx
-    scatter_indx,  # ScatterIndx
-    activation: str = "silu",
-    quant_config: FusedMoEQuantConfig | None = None,
-    swiglu_alpha: float = 1.702,
-    swiglu_limit: float = 7.0,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    expert_map: torch.Tensor | None = None,
-    a1q_scale: torch.Tensor | None = None,
-    unpadded_N_w1=None,
-    unpadded_K_w1=None,
-    unpadded_N_w2=None,
-    unpadded_K_w2=None,
-) -> torch.Tensor:
-    assert quant_config is not None
-    # type check, uint8 means mxfp4
-    assert hidden_states.dtype == torch.bfloat16
-    assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
-    assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
-
-    # Shape check: weights are padded (e.g. hidden_size padded for
-    # GFX950 swizzle).
-    assert hidden_states.shape[-1] == w1.shape[-2]
-    assert w2.shape[-1] == w1.shape[1]
-
-    E, _, N = w1.shape
-
-    if global_num_experts == -1:
-        global_num_experts = E
-
-    gammas = routing_data.gate_scal if routing_data else None
-
-    from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
-    from aiter.ops.triton.quant_moe import downcast_to_static_fp8
-
-    assert quant_config.w1_precision is not None, (
-        "w1_precision in quant config can't be None"
-    )
-    assert quant_config.w2_precision is not None, (
-        "w2_precision in quant config can't be None"
-    )
-
-    hidden_states = downcast_to_static_fp8(
-        hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
-    )
-
-    intermediate_cache1 = moe_gemm_a8w4(
-        hidden_states,
-        w1.storage.data,
-        None,
-        quant_config.w1_precision.weight_scale.storage.data,
-        quant_config.w1_precision.flex_ctx.lhs_data.scale,
-        quant_config.w2_precision.flex_ctx.lhs_data.scale,
-        quant_config.w1_bias,
-        routing_data,
-        gather_indx=gather_indx,
-        gammas=gammas if apply_router_weight_on_input else None,
-        swizzle_mx_scale="CDNA4_SCALE",
-        out_dtype=torch.float8_e4m3fn,
-        apply_swiglu=True,
-        alpha=swiglu_alpha,
-        limit=swiglu_limit,
-        unpadded_N=unpadded_N_w1,
-        unpadded_K=unpadded_K_w1,
-    )
-
-    intermediate_cache3 = moe_gemm_a8w4(
-        intermediate_cache1,
-        w2.storage.data,
-        None,
-        quant_config.w2_precision.weight_scale.storage.data,
-        quant_config.w2_precision.flex_ctx.lhs_data.scale,
-        None,
-        quant_config.w2_bias,
-        routing_data,
-        scatter_indx=scatter_indx,
-        gammas=None if apply_router_weight_on_input else gammas,
-        swizzle_mx_scale="CDNA4_SCALE",
-        unpadded_N=unpadded_N_w2,
-        unpadded_K=unpadded_K_w2,
-    )
-
-    return intermediate_cache3
-
-
 def make_routing_data(
     topk_ids: torch.Tensor,
     topk_weights: torch.Tensor,
@@ -1172,132 +1030,3 @@ def apply(
             quant_config=self.quant_config,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
-
-
-class AiterW4A8ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
-    """
-    Monolithic MXFP4 W4A8 expert using AITER triton kernels.
- - This backend uses: - - aiter.ops.triton.moe_routing.routing for routing - - aiter.ops.triton.moe_op_gemm_a8w4.moe_gemm_a8w4 for computation - - Weight format: MXFP4 weights with GFX950 swizzle - Activation: Static FP8 quantization - """ - - def __init__( - self, - moe_config: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, - ): - super().__init__(moe_config, quant_config) - self.topk = moe_config.experts_per_token - self.renormalize = moe_config.routing_method in ( - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ) - - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - @staticmethod - def _supports_current_device() -> bool: - # Requires AITER and GFX950 - if not rocm_aiter_ops.is_enabled(): - return False - from vllm.platforms.rocm import on_gfx950 - - return on_gfx950() - - @staticmethod - def _supports_no_act_and_mul() -> bool: - return False - - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - # W4A8: MXFP4 weights with static FP8 activations - SUPPORTED_W_A = [ - (kMxfp4Static, kFp8StaticTensorSym), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - @staticmethod - def _supports_activation(activation: MoEActivation) -> bool: - # Only SILU activation (swiglu) is supported - return activation == MoEActivation.SWIGLUOAI - - @staticmethod - def _supports_parallel_config( - moe_parallel_config: FusedMoEParallelConfig, - ) -> bool: - return ( - not moe_parallel_config.use_all2all_kernels - and not moe_parallel_config.enable_eplb - and moe_parallel_config.dp_size <= 1 - ) - - @staticmethod - def _supports_routing_method( - routing_method: RoutingMethodType, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - return routing_method in [ - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - - @staticmethod - def _supports_router_logits_dtype( - router_logits_dtype: torch.dtype | None, - routing_method: RoutingMethodType, - ) -> bool: - return True - - def supports_expert_map(self) -> bool: - return False # Expert parallelism not yet supported - - @property - def expects_unquantized_inputs(self) -> bool: - return True - - def apply( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - router_logits: torch.Tensor, - activation: MoEActivation, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - apply_router_weight_on_input: bool, - # grouped topk + fused topk bias parameters - num_expert_group: int | None = None, - e_score_correction_bias: torch.Tensor | None = None, - routed_scaling_factor: float | None = None, - topk_group: int | None = None, - ) -> torch.Tensor: - assert self.moe_config.intermediate_size_per_partition_unpadded is not None - assert self.moe_config.hidden_dim_unpadded is not None - return aiter_triton_kernel_w4a8_moe_forward( - hidden_states=hidden_states, - w1=w1, - w2=w2, - gating_output=router_logits, - topk=self.topk, - renormalize=self.renormalize, - global_num_experts=global_num_experts, - expert_map=expert_map, - quant_config=self.quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - unpadded_N_w1=self.moe_config.intermediate_size_per_partition_unpadded * 2, - unpadded_K_w1=self.moe_config.hidden_dim_unpadded, - unpadded_N_w2=self.moe_config.hidden_dim_unpadded, - unpadded_K_w2=self.moe_config.intermediate_size_per_partition_unpadded, - ) 
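With the removals above, the AITER W4A8 path now lives in its own module (see the oracle diff below for the new import path). A caller-side sketch of the relocated entry points; re-exporting aiter_triton_kernel_w4a8_moe_forward from the moved module is an assumption, based on the class's apply() calling it unqualified:

from vllm.model_executor.layers.fused_moe.experts.aiter_mxfp4_w4a8_moe import (
    AiterW4A8ExpertsMonolithic,
    aiter_triton_kernel_w4a8_moe_forward,  # assumed to move alongside the class
)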
diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index ac1d8d368817..2352064cbeb0 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -163,7 +163,7 @@ def backend_to_kernel_cls( return [AiterExperts] elif backend == Mxfp4MoeBackend.AITER_MXFP4_FP8: - from vllm.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe import ( # noqa: E501 + from vllm.model_executor.layers.fused_moe.experts.aiter_mxfp4_w4a8_moe import ( AiterW4A8ExpertsMonolithic, ) @@ -196,8 +196,8 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: "triton": Mxfp4MoeBackend.TRITON, "triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED, "marlin": Mxfp4MoeBackend.MARLIN, - "aiter": Mxfp4MoeBackend.AITER_MXFP4_BF16, # W4A16 - "aiter_mxfp4_fp8": Mxfp4MoeBackend.AITER_MXFP4_FP8, # W4A8 + "aiter": Mxfp4MoeBackend.AITER_MXFP4_BF16, + "aiter_mxfp4_fp8": Mxfp4MoeBackend.AITER_MXFP4_FP8, "xpu": Mxfp4MoeBackend.XPU, "emulation": Mxfp4MoeBackend.EMULATION, }
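For context, the _supports_* oracle pattern these kernel classes plug into can be illustrated with a self-contained toy. This is a simplified stand-in for the backend selection logic, not vLLM code; the class and function names below are hypothetical:

# Each candidate class advertises _supports_* predicates; selection walks the
# candidates and returns the first whose checks pass on the current device.
class _UnsupportedExperts:
    @staticmethod
    def _supports_current_device() -> bool:
        return False  # e.g. AITER missing or not GFX950


class _SupportedW4A8Experts:
    @staticmethod
    def _supports_current_device() -> bool:
        return True  # stand-in for the AITER + GFX950 check


def select_kernel_cls(candidates: list[type]) -> type:
    for cls in candidates:
        if cls._supports_current_device():
            return cls
    raise RuntimeError("no supported MXFP4 MoE kernel backend")


assert select_kernel_cls([_UnsupportedExperts, _SupportedW4A8Experts]) is _SupportedW4A8Experts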