From a08bf02f793cfe947e8d14f9a98741f33ec881a4 Mon Sep 17 00:00:00 2001 From: Li Date: Wed, 4 Mar 2026 21:44:14 -0800 Subject: [PATCH 1/2] [ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs - Fix _rocm_aiter_fused_moe_fake missing hidden_pad, intermediate_pad, bias1, bias2 parameters that the impl has. This causes a TypeError crash when torch.compile traces through MXFP4 MoE code paths (e.g. mxfp4.py) that pass these arguments to the custom op, because the fake function receives unexpected keyword arguments in FakeTensor mode. - Fix copy-paste error messages in AiterFlashAttentionImpl that incorrectly referenced FlashAttentionImpl (2 instances) - Fix wrong KV cache layout comment to match actual indexing - Fix wrong variable label in QuarkOCP_MX_MoEMethod log message - Fix method name typo qaunt -> quant in rocm_aiter_ops - Fix typo and grammar in rocm_aiter_fa.py Signed-off-by: Li Made-with: Cursor --- vllm/_aiter_ops.py | 6 +++++- .../layers/quantization/quark/quark_moe.py | 2 +- .../layers/quantization/quark/schemes/quark_ocp_mx.py | 2 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 9 +++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c4ba8053cc58..cf0da35f8e47 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake( a2_scale: torch.Tensor | None = None, num_local_tokens: torch.Tensor | None = None, output_dtype: torch.dtype | None = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: torch.Tensor | None = None, + bias2: torch.Tensor | None = None, ) -> torch.Tensor: if output_dtype is not None: return torch.empty_like(hidden_states, dtype=output_dtype) @@ -1700,7 +1704,7 @@ def gemm_a8wfp4( ) @staticmethod - def triton_fp4_gemm_dynamic_qaunt( + def triton_fp4_gemm_dynamic_quant( x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index b2b77e6688c1..93eb2f7f68ca 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -765,7 +765,7 @@ def __init__( if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " + f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, " f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 1b30f5b82c6a..0cc2cd4184df 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -43,7 +43,7 @@ # for envs checks which does not require @cache anymore. # triton kernel is torch compile compatible. # does not require direct registration. -# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`. +# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_quant`. @cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: return ( diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index d563fbcbcb0b..22f524cd990f 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -157,13 +157,13 @@ def cp_mha_gather_cache( total_tokens: int, ): assert kv_cache_layout in ["NHD", "SHUFFLE"], ( - "kv_cache_layout only support NHD, SHUFFLE" + "kv_cache_layout only supports NHD, SHUFFLE" ) head_dim = key.shape[2] x = 16 // key_cache.element_size() # assert dequant is True, "Currently, we only support "\ # "gather cache with dequant" - # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + # For k cache layout: [num_blocks, page_size, num_heads, head_dim] assert head_dim == key_cache.shape[3], ( "We assume your kv cache layout is [num_blocks, " "page_size, num_heads, head_dim], but got otherwise" @@ -832,7 +832,7 @@ def __init__( if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: raise NotImplementedError( - "Encoder self-attention is not implemented for FlashAttentionImpl" + "Encoder self-attention is not implemented for AiterFlashAttentionImpl" ) def extend_for_sliding_window( @@ -1047,7 +1047,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported for FlashAttentionImpl" + "fused output quantization is not yet supported " + "for AiterFlashAttentionImpl" ) if attn_metadata is None: From ee4213f04e9ef92a98d0177f1419fcd28d990754 Mon Sep 17 00:00:00 2001 From: Li Date: Thu, 5 Mar 2026 22:05:35 -0800 Subject: [PATCH 2/2] [ROCm] Replace local is_rocm_aiter_fp4_asm_gemm_enabled with rocm_aiter_ops Use the centralized rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled instead of the local is_rocm_aiter_fp4_asm_gemm_enabled in quark_ocp_mx.py, as suggested in review. Signed-off-by: Li Made-with: Cursor --- .../quark/schemes/quark_ocp_mx.py | 25 ++++--------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 0cc2cd4184df..0b0a224f3891 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -3,13 +3,12 @@ from collections.abc import Callable from fractions import Fraction -from functools import cache, partial +from functools import partial from typing import Any import torch import torch.nn.functional as F -from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( @@ -37,22 +36,6 @@ logger = init_logger(__name__) -# TODO: move registration of custom op to aiter_ops.py -# `from vllm._aiter_ops import rocm_aiter_ops` -# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` -# for envs checks which does not require @cache anymore. -# triton kernel is torch compile compatible. -# does not require direct registration. -# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_quant`. -@cache -def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM - and envs.VLLM_ROCM_USE_AITER - ) - - try: from aiter.ops.shuffle import shuffle_weight from aiter.ops.triton.gemm_afp4wfp4 import ( @@ -63,7 +46,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from vllm.utils.torch_utils import direct_register_custom_op - if is_rocm_aiter_fp4_asm_gemm_enabled(): + if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip def gemm_with_dynamic_quant( @@ -233,7 +216,9 @@ def __init__( self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4" ) - self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() + self.rocm_use_aiter_fp4_asm_gemm = ( + rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled() + ) if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None): # Currently need these kernels if not emulating