diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py
index 65972d02f2f6..78cdbbbf7379 100644
--- a/tests/kernels/attention/test_prefix_prefill.py
+++ b/tests/kernels/attention/test_prefix_prefill.py
@@ -8,10 +8,8 @@
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
+import torch.nn.functional as F
 
-from tests.kernels.utils import make_alibi_bias
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
 
@@ -28,6 +26,74 @@
 OPS = [chunked_prefill_paged_decode, context_attention_fwd]
 
 
+def create_causal_attention_mask_for_sdpa(
+    query_lens: list[int],
+    seq_lens: list[int],
+    sliding_window: int = 0,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    total_queries = sum(query_lens)
+    total_keys = sum(seq_lens)
+
+    # Create a mask filled with -inf
+    mask = torch.full(
+        (total_queries, total_keys), float("-inf"), device=device, dtype=dtype
+    )
+
+    query_start = 0
+    key_start = 0
+
+    for query_len, seq_len in zip(query_lens, seq_lens):
+        query_end = query_start + query_len
+        key_end = key_start + seq_len
+        q_indices = torch.arange(query_len, device=device)
+        k_indices = torch.arange(seq_len, device=device)
+        q_pos_in_seq = seq_len - query_len + q_indices
+
+        valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
+
+        if sliding_window > 0:
+            valid_mask &= k_indices[None, :] >= (
+                q_pos_in_seq[:, None] - sliding_window + 1
+            )
+
+        mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
+
+        query_start = query_end
+        key_start = key_end
+
+    return mask
+
+
+def create_alibi_causal_mask(
+    query_len: int,
+    seq_len: int,
+    alibi_slopes: torch.Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    query_pos = torch.arange(
+        seq_len - query_len, seq_len, device=device, dtype=torch.float32
+    )
+    key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+    rel_pos = key_pos[None, :] - query_pos[:, None]
+
+    # Apply ALiBi slopes: [num_heads, query_len, seq_len]
+    alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
+    alibi_bias = alibi_bias.to(dtype)
+
+    # Apply causal mask: prevent attending to future positions
+    # causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
+    causal_mask = key_pos[None, :] <= query_pos[:, None]
+    alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
+
+    # Add batch dimension: [1, num_heads, query_len, seq_len]
+    # SDPA expects batch dimension even for single sequences
+    return alibi_bias.unsqueeze(0)
+
+
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -52,6 +118,13 @@ def test_contexted_kv_attention(
             "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
         )
 
+    if (
+        current_platform.is_rocm()
+        and op is chunked_prefill_paged_decode
+        and kv_cache_dtype == "fp8_e5m2"
+    ):
+        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 
@@ -96,16 +169,16 @@ def test_contexted_kv_attention(
     )
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
-    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = torch.arange(0, cache_size, dtype=torch.int32)
     values = values[torch.randperm(cache_size)]
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
-    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
-    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
+        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
    )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -189,56 +262,57 @@ def test_contexted_kv_attention(
 
     scale = float(1.0 / (head_size**0.5))
 
-    attn_op = xops.fmha.cutlass.FwOp()
+    # Reshape for SDPA: (seq_len, num_heads, head_size) ->
+    # (1, num_heads, seq_len, head_size)
+    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
+    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, num_tokens, head_size
+    )
 
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        query = query.view(
-            query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
-        )
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-    query = query.unsqueeze(0)
-    key = key.unsqueeze(0)
-    value = value.unsqueeze(0)
+    # Expand key and value for GQA/MQA to match query heads
+    key_sdpa = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
 
-    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
-        query_lens, seq_lens
+    value_sdpa = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
     )
-    if sliding_window > 0:
-        attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
+        1, num_heads, sum(seq_lens), head_size
+    )
+
+    attn_mask = create_causal_attention_mask_for_sdpa(
+        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
+    )
+
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     start_time = time.time()
-    output_ref = xops.memory_efficient_attention_forward(
-        query,
-        key,
-        value,
-        attn_bias=attn_bias,
-        p=0.0,
+    output_ref = F.scaled_dot_product_attention(
+        query_sdpa,
+        key_sdpa,
+        value_sdpa,
+        attn_mask=attn_mask,
+        dropout_p=0.0,
         scale=scale,
-        op=attn_op,
     )
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
-    output_ref = output_ref.reshape(output.shape)
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
+
+    # Reshape output back to (num_tokens, num_heads, head_size)
+    output_ref = output_ref.view(num_heads, num_tokens, head_size)
+    output_ref = output_ref.permute(1, 0, 2).contiguous()
 
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
             "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
         )
 
+    if (
+        current_platform.is_rocm()
+        and op is chunked_prefill_paged_decode
+        and kv_cache_dtype == "fp8_e5m2"
+    ):
+        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 
@@ -331,16 +412,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     )
     k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
     v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
-    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = torch.arange(0, cache_size, dtype=torch.int32)
     values = values[torch.randperm(cache_size)]
     block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
-    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
-    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
     b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
+        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
    )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -423,78 +504,75 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
 
     scale = float(1.0 / (head_size**0.5))
-    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
-    # we have to pad query tensor before MQA/GQA expanding.
-    if query.shape[0] != key.shape[0]:
-        query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
-        query_pad.uniform_(-1e-3, 1e-3)
-        seq_start = 0
-        query_start = 0
-        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-            seq_end = seq_start + seq_len
-            query_end = query_start + query_len
-            query_pad[seq_start:seq_end, ...] = torch.cat(
-                [
-                    torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
-                    query[query_start:query_end, ...],
-                ],
-                dim=0,
-            )
-            seq_start += seq_len
-            query_start += query_len
-        query = query_pad
-
-    if num_kv_heads != num_heads:
-        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-        # project the key and value tensors to the desired number of
-        # heads.
-        #
-        # see also: vllm/model_executor/layers/attention.py
-        key = key[:, :, None, :].expand(
-            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
-        )
-        value = value[:, :, None, :].expand(
-            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
-        )
-        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
-        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
-        # codebase. We save some time reshaping alibi matrix at runtime.
-        key = key.reshape(key.shape[0], -1, key.shape[-1])
-        value = value.reshape(value.shape[0], -1, value.shape[-1])
-    query = query.unsqueeze(0)
-    key = key.unsqueeze(0)
-    value = value.unsqueeze(0)
-
-    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    # Prepare query, key, value for SDPA
+    # Expand key and value for GQA/MQA to match query heads
+    key_expanded = key[:, :, None, :].expand(
+        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
+    )
+    value_expanded = value[:, :, None, :].expand(
+        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
+    )
+
     output_ref = torch.empty_like(output)
-    seq_start = 0
-    query_start = 0
+
+    torch.cuda.synchronize()
     start_time = time.time()
-    # Attention with alibi slopes.
-    # FIXME(DefTruth): Because xformers does not support dynamic sequence
-    # lengths with custom attention bias, we process each prompt one by
-    # one. This is inefficient, especially when we have many short prompts.
-    # modified from: vllm/v1/attention/backends/xformers.py#L343
+
+    query_start = 0
+    key_start = 0
     for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
-        seq_end = seq_start + seq_len
         query_end = query_start + query_len
-        out = xops.memory_efficient_attention_forward(
-            query[:, seq_start:seq_end],
-            key[:, seq_start:seq_end],
-            value[:, seq_start:seq_end],
-            attn_bias=attn_bias[i],
-            p=0.0,
-            scale=scale,
+        key_end = key_start + seq_len
+
+        # Get query, key, value for this sequence
+        q = query[query_start:query_end]  # [query_len, num_heads, head_size]
+        k = key_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+        v = value_expanded[
+            key_start:key_end
+        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
+
+        # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
+        q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
+        q_sdpa = (
+            q_sdpa.permute(1, 2, 0, 3)
+            .reshape(1, num_heads, query_len, head_size)
+            .contiguous()
+        )
+
+        k_sdpa = (
+            k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
+        )
+        v_sdpa = (
+            v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
         )
-        out = out.view_as(query[:, seq_start:seq_end]).view(
-            seq_len, num_heads, head_size
+
+        # Create ALiBi causal mask for this sequence using utility function
+        alibi_mask = create_alibi_causal_mask(
+            query_len, seq_len, alibi_slopes, device, dtype
+        )
+
+        # Compute attention
+        out = F.scaled_dot_product_attention(
+            q_sdpa,
+            k_sdpa,
+            v_sdpa,
+            attn_mask=alibi_mask,
+            dropout_p=0.0,
+            scale=scale,
         )
-        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
-        seq_start += seq_len
-        query_start += query_len
+
+        # Reshape output back to [query_len, num_heads, head_size]
+        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
+        output_ref[query_start:query_end].copy_(out)
+
+        query_start = query_end
+        key_start = key_end
+
+    torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
+    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
 
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
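
A minimal sketch (not part of the patch) of how the new `create_causal_attention_mask_for_sdpa` helper is meant to drive `F.scaled_dot_product_attention`: the packed query tokens attend to the packed keys through a block-diagonal, bottom-right-aligned causal mask, which is the role the removed `BlockDiagonalCausalFromBottomRightMask` played. The toy lengths and head counts below are illustrative, and the helper is assumed to be in scope (it is defined at the top of the test module in this diff).

```python
# Illustrative only: toy lengths, CPU tensors; assumes the
# create_causal_attention_mask_for_sdpa helper from the patch is in scope.
import torch
import torch.nn.functional as F

query_lens = [2, 3]  # new tokens per sequence (packed)
seq_lens = [4, 5]    # context + new tokens per sequence (packed)
num_heads, head_size = 2, 8
device, dtype = torch.device("cpu"), torch.float32

# (sum(query_lens), sum(seq_lens)) additive mask: 0 where attention is
# allowed, -inf on off-diagonal blocks and future positions.
mask = create_causal_attention_mask_for_sdpa(
    query_lens, seq_lens, sliding_window=0, device=device, dtype=dtype
)
assert mask.shape == (sum(query_lens), sum(seq_lens))

# SDPA layout is (batch, heads, tokens, head_size); the 2-D mask broadcasts
# over the batch and head dimensions.
q = torch.randn(1, num_heads, sum(query_lens), head_size, dtype=dtype)
k = torch.randn(1, num_heads, sum(seq_lens), head_size, dtype=dtype)
v = torch.randn(1, num_heads, sum(seq_lens), head_size, dtype=dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 2, 5, 8])
```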
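Similarly, a sketch (again not part of the patch) of the per-sequence path used in the ALiBi test: `create_alibi_causal_mask` returns a `[1, num_heads, query_len, seq_len]` additive bias whose non-causal entries are `-inf`, so each sequence can be run through SDPA on its own. The slopes and lengths here are illustrative stand-ins for what the test computes via `_get_alibi_slopes`.

```python
# Illustrative only; assumes create_alibi_causal_mask from the patch is in scope.
import torch
import torch.nn.functional as F

num_heads, head_size = 4, 16
query_len, seq_len = 3, 7  # 4 context tokens + 3 new tokens
device, dtype = torch.device("cpu"), torch.float32
# Geometric slopes as a stand-in for the test's _get_alibi_slopes helper.
alibi_slopes = torch.tensor([2.0**-i for i in range(1, num_heads + 1)], device=device)

bias = create_alibi_causal_mask(query_len, seq_len, alibi_slopes, device, dtype)
assert bias.shape == (1, num_heads, query_len, seq_len)

q = torch.randn(1, num_heads, query_len, head_size, dtype=dtype)
k = torch.randn(1, num_heads, seq_len, head_size, dtype=dtype)
v = torch.randn(1, num_heads, seq_len, head_size, dtype=dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 4, 3, 16])
```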