diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c8366ecce543..348c301a51f7 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -5,7 +5,7 @@ import torch from torch._ops import OpOverload - +from vllm.platforms.rocm import on_gfx12x import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool: VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery. """ if current_platform.is_rocm() and IS_AITER_FOUND: - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x - return on_gfx9() + return on_gfx9() or on_gfx12x() return False @@ -1769,11 +1769,12 @@ def group_fp8_quant( group_size: int = 128, ) -> tuple[torch.Tensor, torch.Tensor]: assert group_size == 128, "Group size must be 128" + return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size) @staticmethod def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: return (n, k) in [ (1024, 8192), (2112, 7168), (3072, 1536), @@ -1910,7 +1911,70 @@ def flash_attn_varlen_func( even when VLLM_ROCM_USE_AITER=0. Note: This performs lazy import of aiter.flash_attn_varlen_func + For gfx12x (RDNA4) GPUs, uses mha_v3 Triton kernel for better performance. 
""" + # Check if we should use mha_v3 for gfx12x + # mha_v3 provides better performance on RDNA4 architectures + if on_gfx12x(): + # Check if mha_v3 can be used (no alibi, no dropout, no custom window size) + # window_size=None means infinite context window (default behavior) + can_use_mha_v3 = ( + alibi_slopes is None + and dropout_p == 0.0 + and (window_size is None or window_size == (-1, -1)) + and not return_lse # mha_v3 doesn't support returning LSE + ) + + if can_use_mha_v3: + try: + from aiter.ops.triton.attention.mha_v3 import ( + flash_attn_varlen_func as mha_v3_varlen_func, + ) + + # mha_v3 has a different signature - convert parameters + # mha_v3 expects window_size as a tuple, default is (-1, -1) for infinite window + mha_v3_window_size = ( + (-1, -1) if window_size is None else window_size + ) + + mha_v3_out = mha_v3_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=causal, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=mha_v3_window_size, + attention_chunk=0, + softcap=0.0, + deterministic=False, + sm_margin=0, + ) + + if out is not None: + out.copy_(mha_v3_out) + return out + return mha_v3_out + + except Exception as e: + # Fall through to default implementation if mha_v3 fails + # Log the error for debugging (only once to avoid spam) + import warnings + + warnings.warn( + f"mha_v3 failed, falling back to default: {e}", + RuntimeWarning, + stacklevel=2, + ) + pass + + # Default: use standard flash_attn_varlen_func (CK-based or Triton-based) from aiter import flash_attn_varlen_func return flash_attn_varlen_func( diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json new file mode 100644 index 
000000000000..0a3645526d6b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json @@ -0,0 +1,69 @@ +{ + "triton_version": "3.5.1+rocm7.2.0.gita272dfa8", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..6d5a373d725d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,57 @@ +{ + "triton_version": 
"3.5.1+rocm7.2.0.gita272dfa8", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 68393f768dcc..a0052a8922a1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -921,13 +921,15 @@ def _supports_quant_scheme( ) -> bool: p = current_platform if p.is_rocm(): - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x is_rocm_on_gfx9 = on_gfx9() + is_rocm_on_gfx12 = on_gfx12x() else: is_rocm_on_gfx9 = False + is_rocm_on_gfx12 = False - device_supports_fp8 = is_rocm_on_gfx9 or ( + device_supports_fp8 = is_rocm_on_gfx9 or is_rocm_on_gfx12 or ( p.is_cuda() and p.has_device_capability((8, 9)) ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 023cdd0b4340..a9e2a2e2133f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1953,14 +1953,18 @@ def _supports_quant_scheme( ) -> bool: p = current_platform if p.is_rocm(): - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x is_rocm_on_gfx9 = on_gfx9() + is_rocm_on_gfx12 = on_gfx12x() else: is_rocm_on_gfx9 = False + is_rocm_on_gfx12 = False - device_supports_fp8 = is_rocm_on_gfx9 or ( - p.is_cuda() and p.has_device_capability((8, 9)) + device_supports_fp8 = ( + is_rocm_on_gfx9 + or is_rocm_on_gfx12 + or (p.is_cuda() and p.has_device_capability((8, 9))) ) if not device_supports_fp8: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 56d654961908..74eb076719e4 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -144,7 +144,8 @@ def _get_gcn_arch() -> str: # These are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch() -_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"]) +_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11"]) +_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"]) _ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"]) _ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) _ON_GFX942 = "gfx942" in _GCN_ARCH @@ -226,6 +227,10 @@ def on_gfx1x() -> bool: return _ON_GFX1X +def on_gfx12x() -> bool: + return _ON_GFX12X + + def on_mi3xx() -> bool: return _ON_MI3XX @@ -286,7 +291,7 @@ def use_rocm_custom_paged_attention( @cache def flash_attn_triton_available() -> bool: - if not on_gfx1x(): + if not on_gfx1x() and not on_gfx12x(): return False try: from importlib.util import find_spec @@ -434,7 +439,7 @@ def get_attn_backend_cls( # Priority 2: Check for AITER MHA (Flash Attention) # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1) - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and (on_gfx9() or on_gfx12x()): logger.info("Using Aiter Flash Attention backend.") return AttentionBackendEnum.ROCM_AITER_FA.get_path() @@ -511,6 +516,6 @@ def get_vit_attn_backend( # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend if ( - on_gfx1x() + (on_gfx1x() or on_gfx12x()) and flash_attn_triton_available() and (dtype == torch.float16 or dtype == torch.bfloat16) ):