diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index cf24630c509f..a991d8e675db 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -5,10 +5,17 @@ import pytest import torch -import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401 from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed -from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available + +# Import AITER backend if on ROCm and aiter is available +if current_platform.is_rocm(): + from vllm._aiter_ops import is_aiter_found_and_supported + + if is_aiter_found_and_supported(): + import aiter + + from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] @@ -102,8 +109,11 @@ def test_varlen_with_paged_kv( num_blocks: int, q_dtype: torch.dtype | None, ) -> None: - if not is_flash_attn_varlen_func_available(): - pytest.skip("flash_attn_varlen_func required to run this test.") + from vllm._aiter_ops import is_aiter_found_and_supported + + if not is_aiter_found_and_supported(): + pytest.skip("aiter package required for this test.") + torch.set_default_device("cuda") set_random_seed(0) num_seqs = len(seq_lens) @@ -129,6 +139,8 @@ def test_varlen_with_paged_kv( cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum( dim=0, dtype=torch.int32 ) + # Save kv_lens as list before converting to tensor + kv_lens_list = kv_lens kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size @@ -141,33 +153,83 @@ def test_varlen_with_paged_kv( maybe_quantized_query = query maybe_quantized_key_cache = key_cache maybe_quantized_value_cache = value_cache - k_descale = None - v_descale = None + k_scale_tensor = None + v_scale_tensor = None + dequant = False + if q_dtype is not None: # QKV are drawn from N(0, 1): no need for a fp8 scaling factor maybe_quantized_query = query.to(q_dtype) maybe_quantized_key_cache = key_cache.to(q_dtype) maybe_quantized_value_cache = value_cache.to(q_dtype) - + dequant = True scale_shape = (num_seqs, num_kv_heads) - k_descale = torch.ones(scale_shape, dtype=torch.float32) - v_descale = torch.ones(scale_shape, dtype=torch.float32) - torch.ops.vllm.flash_attn_varlen_func( - maybe_quantized_query, - maybe_quantized_key_cache, - maybe_quantized_value_cache, - out=output, + # For per-seq-per-head scales (matching AITER backend expectation) + k_scale_tensor = torch.ones(scale_shape, dtype=torch.float32) + v_scale_tensor = torch.ones(scale_shape, dtype=torch.float32) + + # Prepare metadata for cp_mha_gather_cache + # token_to_batch: maps each token to its batch index + token_to_batch = torch.zeros(sum(kv_lens_list), dtype=torch.int32) + seq_starts = torch.zeros(num_seqs, dtype=torch.int32) + + token_idx = 0 + for batch_idx, kv_len in enumerate(kv_lens_list): + token_to_batch[token_idx : token_idx + kv_len] = batch_idx + seq_starts[batch_idx] = 0 # Assuming all sequences start at 0 in their blocks + token_idx += kv_len + + # Allocate buffers for gathered KV + total_kv_tokens = sum(kv_lens_list) + gathered_key = torch.empty( + total_kv_tokens, num_kv_heads, head_size, dtype=maybe_quantized_key_cache.dtype + ) + gathered_value = torch.empty( + total_kv_tokens, + num_kv_heads, + head_size, + dtype=maybe_quantized_value_cache.dtype, + ) + + # Gather paged KV cache into contiguous tensors using triton kernel + cp_mha_gather_cache( + key_cache=maybe_quantized_key_cache, + value_cache=maybe_quantized_value_cache, + key=gathered_key, + value=gathered_value, + block_tables=block_tables, + k_scales=k_scale_tensor + if k_scale_tensor is not None + else torch.ones(1, dtype=torch.float32), + v_scales=v_scale_tensor + if v_scale_tensor is not None + else torch.ones(1, dtype=torch.float32), + cu_seqlens_kv=cu_seq_lens, + token_to_batch=token_to_batch, + seq_starts=seq_starts, + dequant=dequant, + kv_cache_layout="NHD", + total_tokens=total_kv_tokens, + ) + + # Call aiter flash attention with gathered KV + aiter.flash_attn_varlen_func( + q=maybe_quantized_query, + k=gathered_key, + v=gathered_value, cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_seq_lens, max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, + min_seqlen_q=1, + dropout_p=0.0, softmax_scale=scale, - alibi_slopes=None, + causal=True, window_size=window_size, - block_table=block_tables, - cu_seqlens_k=cu_seq_lens, - k_scale=k_descale, - v_scale=v_descale, + alibi_slopes=None, + return_lse=False, + out=output, ) ref_output = ref_paged_attn( @@ -175,7 +237,7 @@ def test_varlen_with_paged_kv( key_cache=key_cache, value_cache=value_cache, query_lens=query_lens, - kv_lens=kv_lens, + kv_lens=kv_lens_list, block_tables=block_tables, scale=scale, sliding_window=sliding_window, @@ -189,3 +251,8 @@ def test_varlen_with_paged_kv( torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), f"{torch.max(torch.abs(output - ref_output))}", ) + + # Log diff stats for tracking changes + print(f"Max abs diff: {torch.max(torch.abs(output - ref_output))}") + print(f"Mean diff: {torch.mean(torch.abs(output - ref_output))}") + print(f"Min diff: {torch.std(torch.abs(output - ref_output))}") diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 7228d92f7810..c544d2d3d195 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -14,7 +14,10 @@ rocm_aiter_sparse_attn_indexer_fake, ) -_FP8_DTYPE = current_platform.fp8_dtype() +# fp8_dtype is not cached. +# on ROCm the fp8_dtype always calls is_fp8_fnuz +# which is a host op, so we cache it once here. +FP8_DTYPE = current_platform.fp8_dtype() def is_aiter_found() -> bool: @@ -31,12 +34,22 @@ def is_aiter_found() -> bool: def is_aiter_found_and_supported() -> bool: - """Check if AITER is available AND enabled via environment variable. + """Check if AITER library is available and platform supports it. - Checks: platform (ROCm), device arch (gfx9), library existence, - and VLLM_ROCM_USE_AITER env variable. + Checks: platform (ROCm), device arch (gfx9), and library existence. + Does NOT check environment variables - that's handled by rocm_aiter_ops.is_enabled(). + + This function determines if aiter CAN be used, not if it SHOULD be used. + + Separation of concerns: + - This function: Can aiter work on this system? (platform + library availability) + - rocm_aiter_ops.is_enabled(): Should aiter be used by default? (adds env var check) + - Backend selection: Can explicitly request aiter regardless of env var + + This allows explicit backend selection via attention_config to work even when + VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery. """ - if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER: + if current_platform.is_rocm() and IS_AITER_FOUND: from vllm.platforms.rocm import on_gfx9 return on_gfx9() @@ -58,21 +71,6 @@ def wrapper(*args, **kwargs): return wrapper -# Can't use dtypes.fp8 directly inside an op -# because it returns wrong result on gfx942. -# This is a workaround to get the correct FP8 dtype. -# This might because that the get_gfx() is wrapped as a custom op. -if is_aiter_found_and_supported(): - from aiter import dtypes - - AITER_FP8_DTYPE = dtypes.fp8 -else: - # Placeholder when AITER is disabled - prevents NameError during module load. - # Note: When AITER is disabled, ops are not registered, so fake implementations - # referencing this variable won't actually be called at runtime. - AITER_FP8_DTYPE = _FP8_DTYPE - - def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -539,7 +537,7 @@ def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: import aiter as rocm_aiter - assert quant_dtype in [torch.int8, _FP8_DTYPE] + assert quant_dtype in [torch.int8, FP8_DTYPE] y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) @@ -581,7 +579,7 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl( ) -> tuple[torch.Tensor, torch.Tensor]: import aiter as rocm_aiter - assert quant_dtype in [torch.int8, _FP8_DTYPE] + assert quant_dtype in [torch.int8, FP8_DTYPE] y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) @@ -630,10 +628,10 @@ def _rocm_aiter_per_token_quant_impl( ) -> tuple[torch.Tensor, torch.Tensor]: from aiter.ops.quant import dynamic_per_token_scaled_quant - assert quant_dtype in [torch.int8, _FP8_DTYPE] + assert quant_dtype in [torch.int8, FP8_DTYPE] out_shape = x.shape - out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device) + out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device) if scale is None: scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) dynamic_per_token_scaled_quant( @@ -653,7 +651,7 @@ def _rocm_aiter_per_token_quant_fake( ) -> tuple[torch.Tensor, torch.Tensor]: out_shape = x.shape return ( - torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device), + torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device), torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), ) @@ -675,7 +673,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( None, None, group_size=group_size, - dtype_quant=AITER_FP8_DTYPE, + dtype_quant=FP8_DTYPE, res1=residual, ) return ( @@ -695,7 +693,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( M, N = x.shape scale_shape = (M, (N + group_size - 1) // group_size) return ( - torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty_like(x, dtype=FP8_DTYPE, device=x.device), torch.empty_like(residual, device=residual.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device), ) @@ -717,7 +715,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl( None, None, group_size=group_size, - dtype_quant=AITER_FP8_DTYPE, + dtype_quant=FP8_DTYPE, res1=None, ) return (x_quant, x_quant_scales) @@ -732,7 +730,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake( M, N = x.shape scale_shape = (M, (N + group_size - 1) // group_size) return ( - torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty_like(x, dtype=FP8_DTYPE, device=x.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device), ) @@ -745,7 +743,7 @@ def _rocm_aiter_group_fp8_quant_impl( from aiter import QuantType, get_hip_quant aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) - return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) + return aiter_per1x128_quant(x.contiguous(), quant_dtype=FP8_DTYPE) def _rocm_aiter_group_fp8_quant_fake( @@ -753,7 +751,7 @@ def _rocm_aiter_group_fp8_quant_fake( group_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: M, N = x.shape - x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) + x_fp8 = torch.empty((M, N), dtype=FP8_DTYPE, device=x.device) out_bs = torch.empty( ( M, @@ -775,7 +773,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_impl( x, activation="silu", group_size=group_size, - dtype_quant=AITER_FP8_DTYPE, + dtype_quant=FP8_DTYPE, ) @@ -786,7 +784,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( M, N = x.shape assert N % 2 == 0 N_half = N // 2 - x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) + x_fp8 = torch.empty((M, N_half), dtype=FP8_DTYPE, device=x.device) out_bs = torch.empty( ( M, @@ -986,7 +984,7 @@ def is_mha_enabled(cls) -> bool: @classmethod @if_aiter_supported def is_shuffle_kv_cache_enabled(cls) -> bool: - return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED + return cls._SHUFFLE_KV_CACHE_ENABLED @classmethod @if_aiter_supported @@ -1654,5 +1652,87 @@ def shuffle_weights( return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) + @staticmethod + def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int | None = None, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + causal: bool = False, + window_size: tuple[int, int] | None = None, + alibi_slopes: torch.Tensor | None = None, + return_lse: bool = False, + out: torch.Tensor | None = None, + ): + """ + Flash attention with variable length sequences. + + This function is NOT wrapped with @is_aiter_supported decorator + to allow explicit backend selection via attention_config to work + even when VLLM_ROCM_USE_AITER=0. + + Note: This performs lazy import of aiter.flash_attn_varlen_func + """ + from aiter import flash_attn_varlen_func + + return flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + min_seqlen_q=min_seqlen_q, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_lse=return_lse, + out=out, + ) + + @staticmethod + def pa_fwd_asm( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_tables_stride0: int, + K_QScale: torch.Tensor, + V_QScale: torch.Tensor, + out_: torch.Tensor, + ): + """ + Paged attention forward pass using assembly kernel. + + This function is NOT wrapped with @is_aiter_supported decorator + to allow explicit backend selection via attention_config to work + even when VLLM_ROCM_USE_AITER=0. + + Note: This performs lazy import of aiter.pa_fwd_asm + """ + from aiter import pa_fwd_asm + + return pa_fwd_asm( + Q=Q, + K=K, + V=V, + block_tables=block_tables, + context_lens=context_lens, + block_tables_stride0=block_tables_stride0, + K_QScale=K_QScale, + V_QScale=V_QScale, + out_=out_, + ) + rocm_aiter_ops.register_ops_once() diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 281d188557fd..ccf52aff20d9 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -8,6 +8,12 @@ logger = init_logger(__name__) +# Track whether upstream flash-attn is available on ROCm. +# Set during module initialization and never modified afterwards. +# This module-level flag avoids repeated import attempts and ensures +# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py). +_ROCM_FLASH_ATTN_AVAILABLE = False + if current_platform.is_cuda(): from vllm._custom_ops import reshape_and_cache_flash from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] @@ -26,6 +32,9 @@ elif current_platform.is_rocm(): try: from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] + + # Mark that upstream flash-attn is available on ROCm + _ROCM_FLASH_ATTN_AVAILABLE = True except ImportError: def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc] @@ -34,6 +43,15 @@ def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no "to be installed. Please install flash-attn first." ) + # ROCm doesn't use scheduler metadata (FA3 feature), provide stub + def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + return None + + # ROCm uses the C++ custom op for reshape_and_cache + from vllm import _custom_ops as ops + + reshape_and_cache_flash = ops.reshape_and_cache_flash + def get_flash_attn_version(requires_alibi: bool = False) -> int | None: # import here to avoid circular dependencies @@ -128,4 +146,30 @@ def flash_attn_supports_mla(): def is_flash_attn_varlen_func_available() -> bool: - return current_platform.is_cuda() or current_platform.is_xpu() + """Check if flash_attn_varlen_func is available. + + This function determines whether the flash_attn_varlen_func imported at module + level is a working implementation or a stub. + + Platform-specific sources: + - CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func + - XPU: ipex_ops.flash_attn_varlen_func + - ROCm: upstream flash_attn.flash_attn_varlen_func (if available) + + Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py) + which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is + handled separately via _aiter_ops.is_aiter_found_and_supported(). + + Returns: + bool: True if a working flash_attn_varlen_func implementation is available. + """ + if current_platform.is_cuda() or current_platform.is_xpu(): + # CUDA and XPU always have flash_attn_varlen_func available + return True + + if current_platform.is_rocm(): + # Use the flag set during module import to check if + # upstream flash-attn was successfully imported + return _ROCM_FLASH_ATTN_AVAILABLE + + return False diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 18194e05f9e9..28b5a7f419df 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -34,9 +34,6 @@ if current_platform.is_rocm(): from vllm.triton_utils import tl, triton - if rocm_aiter_ops.is_enabled(): - import aiter - def block_size(x, head_dim): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) @@ -798,7 +795,7 @@ def extend_for_sliding_window( total_tokens=swa_total_tokens, ) - aiter.flash_attn_varlen_func( + rocm_aiter_ops.flash_attn_varlen_func( q=query, k=key_fetched, v=value_fetched, @@ -848,7 +845,7 @@ def extend_forward( v_scale, ) return - out, lse = aiter.flash_attn_varlen_func( + out, lse = rocm_aiter_ops.flash_attn_varlen_func( q=query, k=key, v=value, @@ -895,7 +892,7 @@ def extend_forward( total_tokens=total_token_per_batch[chunk_idx], ) - suf_out, suf_lse = aiter.flash_attn_varlen_func( + suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func( q=query, k=key_fetched, v=value_fetched, @@ -1053,7 +1050,7 @@ def forward( prefill_key = key[num_decode_tokens + num_extend_tokens :] prefill_value = value[num_decode_tokens + num_extend_tokens :] - aiter.flash_attn_varlen_func( + rocm_aiter_ops.flash_attn_varlen_func( q=prefill_query, k=prefill_key, v=prefill_value, @@ -1159,7 +1156,7 @@ def forward( ) new_key_cache = key_cache.view_as(k_cache_template) new_value_cache = value_cache.view_as(v_cache_template) - aiter.pa_fwd_asm( + rocm_aiter_ops.pa_fwd_asm( Q=query[:num_decode_tokens], K=new_key_cache, V=new_value_cache, @@ -1188,6 +1185,10 @@ def forward( device=output.device, ) + # import so that aiter register the op to the namespace of + # torch.ops.aiter + import aiter # noqa: F401 + torch.ops.aiter.paged_attention_v1( output[:num_decode_tokens], workspace_buffer, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 82505645cfca..d4b38d67021a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -222,9 +222,13 @@ def __init__( RocmAttentionMetadata, ] # ROCM_AITER_FA is an optional backend - from vllm._aiter_ops import rocm_aiter_ops - - if rocm_aiter_ops.is_enabled() and find_spec( + # We check is_enabled() here to avoid importing the backend module during + # auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter + # import and JIT compilation warnings. Explicit backend selection via + # attention_config still works because the backend module is loaded + # directly when selected, not through this auto-discovery path. + # Check if backend module exists to allow explicit selection + if find_spec( AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) ): from vllm.v1.attention.backends.rocm_aiter_fa import (