diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py
new file mode 100644
index 000000000000..d0687c62b113
--- /dev/null
+++ b/tests/kernels/attention/test_aiter_flash_attn.py
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import pytest
+import torch
+
+import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
+from vllm.platforms import current_platform
+
+NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+HEAD_SIZES = [128, 256]
+BLOCK_SIZES = [16, 32]
+DTYPES = [torch.float16, torch.bfloat16]
+QDTYPES = [None]
+# one value large enough to test overflow in index calculation.
+# one value small enough to test the schema op check
+NUM_BLOCKS = [32768, 2048]
+
+
+def ref_paged_attn(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    query_lens: list[int],
+    kv_lens: list[int],
+    block_tables: torch.Tensor,
+    scale: float,
+    sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    num_seqs = len(query_lens)
+    block_tables = block_tables.cpu().numpy()
+    _, block_size, num_kv_heads, head_size = key_cache.shape
+
+    outputs: list[torch.Tensor] = []
+    start_idx = 0
+    for i in range(num_seqs):
+        query_len = query_lens[i]
+        kv_len = kv_lens[i]
+        q = query[start_idx:start_idx + query_len]
+        q *= scale
+
+        num_kv_blocks = (kv_len + block_size - 1) // block_size
+        block_indices = block_tables[i, :num_kv_blocks]
+
+        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
+        k = k[:kv_len]
+        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
+        v = v[:kv_len]
+
+        if q.shape[1] != k.shape[1]:
+            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
+            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
+        attn = torch.einsum("qhd,khd->hqk", q, k).float()
+        empty_mask = torch.ones(query_len, kv_len)
+        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
+        if sliding_window is not None:
+            sliding_window_mask = torch.triu(empty_mask,
+                                             diagonal=kv_len -
+                                             (query_len + sliding_window) +
+                                             1).bool().logical_not()
+            mask |= sliding_window_mask
+        if soft_cap is not None:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
+        attn.masked_fill_(mask, float("-inf"))
+        attn = torch.softmax(attn, dim=-1).to(v.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn, v)
+
+        outputs.append(out)
+        start_idx += query_len
+
+    return torch.cat(outputs, dim=0)
+
+
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="Only ROCm is supported")
+@pytest.mark.parametrize("seq_lens",
+                         [[(10, 1328), (5, 18),
+                           (129, 463)], [(8, 523), (24, 37), (3, 2011)]])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("q_dtype", QDTYPES)
+@torch.inference_mode()
+def test_varlen_with_paged_kv(
+    seq_lens: list[tuple[int, int]],
+    num_heads: tuple[int, int],
+    head_size: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+    num_blocks: int,
+    q_dtype: Optional[torch.dtype],
+) -> None:
+    torch.set_default_device("cuda")
+    current_platform.seed_everything(0)
+    num_seqs = len(seq_lens)
+    query_lens = [x[0] for x in seq_lens]
+    kv_lens = [x[1] for x in seq_lens]
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    max_query_len = max(query_lens)
+    max_kv_len = max(kv_lens)
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
+                   (-1, -1))
+    scale = head_size**-0.5
+
+    query = torch.randn(sum(query_lens),
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype)
+    key_cache = torch.randn(num_blocks,
+                            block_size,
+                            num_kv_heads,
+                            head_size,
+                            dtype=dtype)
+    value_cache = torch.randn_like(key_cache)
+    cu_query_lens = torch.tensor([0] + query_lens,
+                                 dtype=torch.int32).cumsum(dim=0,
+                                                           dtype=torch.int32)
+
+    cu_seq_lens = torch.tensor([0] + kv_lens,
+                               dtype=torch.int32).cumsum(dim=0,
+                                                         dtype=torch.int32)
+    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
+
+    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 num_blocks,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
+    output = torch.empty_like(query)
+
+    maybe_quantized_query = query
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
+    torch.ops.vllm.flash_attn_varlen_func(
+        maybe_quantized_query,
+        maybe_quantized_key_cache,
+        maybe_quantized_value_cache,
+        out=output,
+        cu_seqlens_q=cu_query_lens,
+        max_seqlen_q=max_query_len,
+        max_seqlen_k=max_kv_len,
+        softmax_scale=scale,
+        alibi_slopes=None,
+        window_size=window_size,
+        block_table=block_tables,
+        cu_seqlens_k=cu_seq_lens,
+        k_scale=k_descale,
+        v_scale=v_descale,
+    )
+
+    ref_output = ref_paged_attn(
+        query=query,
+        key_cache=key_cache,
+        value_cache=value_cache,
+        query_lens=query_lens,
+        kv_lens=kv_lens,
+        block_tables=block_tables,
+        scale=scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+    )
+
+    atol, rtol = 2e-2, 2e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
+        f"{torch.max(torch.abs(output - ref_output))}"
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 0739d2596676..85a5dc8c91c1 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -2,20 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import Optional
+from typing import ClassVar, Optional
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType,
-                                              is_quantized_kv_cache)
+                                              AttentionMetadata, AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+_PARTITION_SIZE_ROCM = 256
+
 if current_platform.is_rocm():
     import aiter
 
@@ -32,38 +33,54 @@ def _vllm_layout_trans_kernel(
         b_seq_lens_loc,
         block_table,
         block_table_stride_0,
+        k_scale,
+        v_scale,
+        output_dtype: tl.constexpr,
         E_DIM: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
     ):
         batch_idx = tl.program_id(0)
         block_idx = tl.program_id(1)
-        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                      tl.arange(0, 2))
-        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-        seq_len = batch_token_end - batch_token_start
 
         batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                       tl.arange(0, 2))
        batch_query_start, batch_query_end = tl.split(batch_query_indexes)
         query_len = batch_query_end - batch_query_start
+
         if query_len <= 1:
             return
+
+        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
+                                      tl.arange(0, 2))
+        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
+        seq_len = batch_token_end - batch_token_start
+
         if block_idx * BLOCK_SIZE < seq_len:
             block_mask = (block_idx * BLOCK_SIZE +
                           tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
 
             kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                             block_idx)
+                             block_idx).to(tl.int64)
 
             kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
                 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
             k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
+            if k_vals.dtype.is_fp8():
+                k_vals = (k_vals.to(tl.float32) *
+                          tl.load(k_scale)).to(output_dtype)
+            else:
+                k_vals = k_vals.to(output_dtype)
+
             v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
-
+            if v_vals.dtype.is_fp8():
+                v_vals = (v_vals.to(tl.float32) *
+                          tl.load(v_scale)).to(output_dtype)
+            else:
+                v_vals = v_vals.to(output_dtype)
             kv_values_off = batch_token_start * E_DIM + \
                 block_idx * BLOCK_SIZE * E_DIM + \
                 tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
@@ -72,29 +89,44 @@ def _vllm_layout_trans_kernel(
             tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
     def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                          k_buffer, v_buffer, max_seq_len, total_tokens):
-        H_KV = v_buffer.shape[2]
-        D = v_buffer.shape[3]
-        BLOCK_SIZE = v_buffer.shape[1]
-        dtype = k_buffer.dtype
-        k_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-        v_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
+                          k_cache, v_cache, max_seq_len, k_scale, v_scale,
+                          output_dtype, total_tokens):
+        H_KV = v_cache.shape[2]
+        D = v_cache.shape[3]
+        BLOCK_SIZE = v_cache.shape[1]
+
+        k_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=k_cache.device,
+        )
+        v_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=v_cache.device,
+        )
         grid = (block_table.shape[0],
                 (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
 
-        _vllm_layout_trans_kernel[grid](k_buffer,
-                                        v_buffer,
+        if output_dtype == torch.float16:
+            output_dtype = tl.float16
+        elif output_dtype == torch.bfloat16:
+            output_dtype = tl.bfloat16
+        else:
+            raise ValueError(f"Unsupported output dtype: {output_dtype}")
+
+        _vllm_layout_trans_kernel[grid](k_cache,
+                                        v_cache,
                                         k_values,
                                         v_values,
                                         b_query_lens_loc,
                                         b_seq_lens_loc,
                                         block_table,
                                         block_table.stride(0),
+                                        k_scale,
+                                        v_scale,
+                                        output_dtype=output_dtype,
                                         E_DIM=H_KV * D,
                                         BLOCK_SIZE=BLOCK_SIZE)
@@ -107,16 +139,22 @@ def flash_attn_varlen_func_impl(
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        total_tokens: int = 0,
     ) -> torch.Tensor:
+        if total_tokens == 0:
+            total_tokens = int(cu_seqlens_k[-1].item())
         k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                                 k_cache, v_cache, max_seqlen_k, total_tokens)
+                                 k_cache, v_cache, max_seqlen_k, k_scale,
+                                 v_scale, q.dtype, total_tokens)
+
         output = aiter.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -141,19 +179,21 @@ def flash_attn_varlen_func_fake(
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        total_tokens: int = 0,
     ) -> torch.Tensor:
         return torch.empty(q.shape[0],
                            q.shape[1],
                            v_cache.shape[-2],
-                           dtype=torch.float8_e4m3fnuz,
-                           device="cuda")
+                           dtype=q.dtype,
+                           device=q.device)
 
     direct_register_custom_op("flash_attn_varlen_func",
                               flash_attn_varlen_func_impl, ["out"],
@@ -163,7 +203,33 @@ def flash_attn_varlen_func_fake(
 logger = init_logger(__name__)
 
 
-class AiterFlashAttentionMetadataBuilder:
+@dataclass
+class AiterFlashAttentionMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ---------------------|
+    # |-- query_len ---|
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    block_table: torch.Tensor
+
+    # For cascade attention.
+    use_cascade: bool
+    common_prefix_len: int
+    total_tokens: int
+
+
+class AiterFlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
@@ -180,14 +246,23 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
+        self.total_tokens: int = 0
 
     def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False
 
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        self.total_tokens = self.model_config.max_model_len \
+            * self.vllm_config.scheduler_config.max_num_partial_prefills
+        res = self.build(common_prefix_len=0,
+                         common_attn_metadata=common_attn_metadata)
+        self.total_tokens = 0
+        return res
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -195,43 +270,29 @@ def build(self,
 
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
-        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                  dtype=torch.int32,
-                                  device=self.device)
-        torch.cumsum(seq_lens,
-                     dim=0,
-                     dtype=cu_seq_lens.dtype,
-                     out=cu_seq_lens[1:])
+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
+            return None
 
         use_cascade = common_prefix_len > 0
 
-        cu_prefix_query_lens = None
-        prefix_kv_lens = None
-        suffix_kv_lens = None
-
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
-            cu_prefix_query_lens=cu_prefix_query_lens,
-            prefix_kv_lens=prefix_kv_lens,
-            suffix_kv_lens=suffix_kv_lens,
+            total_tokens=self.total_tokens,
         )
         return attn_metadata
@@ -254,7 +315,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [64, 128, 256]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
@@ -295,34 +356,6 @@ def get_kv_cache_shape(
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
 
-@dataclass
-class AiterFlashAttentionMetadata:
-    # NOTE(sang): Definition of context_len, query_len, and seq_len.
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ---------------------|
-    # |-- query_len ---|
-
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    max_query_len: int
-    query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    cu_seq_lens: torch.Tensor
-    total_tokens: int
-    block_table: torch.Tensor
-    slot_mapping: torch.Tensor
-
-    # For cascade attention.
-    use_cascade: bool
-    common_prefix_len: int
-    cu_prefix_query_lens: Optional[torch.Tensor]
-    prefix_kv_lens: Optional[torch.Tensor]
-    suffix_kv_lens: Optional[torch.Tensor]
-
-
 class AiterFlashAttentionImpl(AttentionImpl):
 
     def __init__(
@@ -366,10 +399,6 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "AiterFlashAttention does not support fp8 kv-cache on this "
-                "device.")
 
     def forward(
         self,
@@ -440,12 +469,6 @@ def forward(
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
             value_cache = value_cache.view(torch.float8_e4m3fnuz)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
 
         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
@@ -455,8 +478,16 @@ def forward(
             block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
-                cu_seq_lens = attn_metadata.cu_seq_lens
-                total_tokens = attn_metadata.total_tokens
+
+                cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=query.device)
+
+                torch.cumsum(seqused_k,
+                             dim=0,
+                             dtype=cu_seq_lens.dtype,
+                             out=cu_seq_lens[1:])
+
                 torch.ops.vllm.flash_attn_varlen_func(
                     query[:num_actual_tokens],
                     key_cache,
@@ -465,29 +496,31 @@ def forward(
                     cu_seqlens_q=cu_seqlens_q,
                     max_seqlen_q=max_seqlen_q,
                     max_seqlen_k=max_seqlen_k,
-                    total_tokens=total_tokens,
                     softmax_scale=self.scale,
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=cu_seq_lens)
+                    cu_seqlens_k=cu_seq_lens,
+                    k_scale=layer._k_scale,
+                    v_scale=layer._v_scale,
+                    total_tokens=attn_metadata.total_tokens,
+                )
 
             _, num_heads, head_size = query.shape
-            _PARTITION_SIZE_ROCM = 256
+
+            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
             num_seqs = seqused_k.shape[0]
-            nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8
             max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                   1) // _PARTITION_SIZE_ROCM
 
             workspace_buffer = torch.empty(
                 (num_seqs * num_heads * max_num_partitions * head_size) *
-                nbyes_per_qo_elem + 2 *
+                nbytes_per_qo_elem + 2 *
                 (num_seqs * num_heads * max_num_partitions) * 4,
                 dtype=torch.uint8,
                 device=output.device,
             )
 
-            aiter.paged_attention_v1(
+            torch.ops.aiter.paged_attention_v1(
                 output[:num_actual_tokens],
                 workspace_buffer,
                 query[:num_actual_tokens],