310 changes: 194 additions & 116 deletions tests/kernels/attention/test_prefix_prefill.py
@@ -8,10 +8,8 @@

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
import torch.nn.functional as F

from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
@@ -28,6 +26,74 @@
OPS = [chunked_prefill_paged_decode, context_attention_fwd]


def create_causal_attention_mask_for_sdpa(
query_lens: list[int],
seq_lens: list[int],
sliding_window: int = 0,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
total_queries = sum(query_lens)
total_keys = sum(seq_lens)

# Create a mask filled with -inf
mask = torch.full(
(total_queries, total_keys), float("-inf"), device=device, dtype=dtype
)

query_start = 0
key_start = 0

for query_len, seq_len in zip(query_lens, seq_lens):
query_end = query_start + query_len
key_end = key_start + seq_len
q_indices = torch.arange(query_len, device=device)
k_indices = torch.arange(seq_len, device=device)
q_pos_in_seq = seq_len - query_len + q_indices

valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]

if sliding_window > 0:
valid_mask &= k_indices[None, :] >= (
q_pos_in_seq[:, None] - sliding_window + 1
)

mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0

query_start = query_end
key_start = key_end

return mask

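# Worked example (an illustrative sketch; the lengths below are arbitrary and
# not taken from this diff): query_lens=[1, 2] with seq_lens=[2, 3] gives a
# (3, 5) block-diagonal additive mask, causal from the bottom right of each
# block.
_example_mask = create_causal_attention_mask_for_sdpa(
    [1, 2], [2, 3], device=torch.device("cpu"), dtype=torch.float32
)
# Expected pattern (0.0 = attend, -inf = masked):
#   [[ 0.0,  0.0, -inf, -inf, -inf],
#    [-inf, -inf,  0.0,  0.0, -inf],
#    [-inf, -inf,  0.0,  0.0,  0.0]]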

def create_alibi_causal_mask(
query_len: int,
seq_len: int,
alibi_slopes: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
query_pos = torch.arange(
seq_len - query_len, seq_len, device=device, dtype=torch.float32
)
key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)

rel_pos = key_pos[None, :] - query_pos[:, None]

# Apply ALiBi slopes: [num_heads, query_len, seq_len]
alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
alibi_bias = alibi_bias.to(dtype)

# Apply causal mask: prevent attending to future positions
# causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
causal_mask = key_pos[None, :] <= query_pos[:, None]
alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))

# Add batch dimension: [1, num_heads, query_len, seq_len]
# SDPA expects batch dimension even for single sequences
return alibi_bias.unsqueeze(0)

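# Worked example (an illustrative sketch; the slope and lengths are arbitrary
# and not taken from this diff): query_len=2, seq_len=3 and a single slope of
# 0.5 give a bias that decays linearly into the past and masks future keys.
_example_alibi = create_alibi_causal_mask(
    2, 3, torch.tensor([0.5]), torch.device("cpu"), torch.float32
)
# _example_alibi has shape (1, 1, 2, 3):
#   [[[-0.5,  0.0, -inf],
#     [-1.0, -0.5,  0.0]]]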

@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -52,6 +118,13 @@ def test_contexted_kv_attention(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)

if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

current_platform.seed_everything(0)
torch.set_default_device(device)

@@ -96,16 +169,16 @@ def test_contexted_kv_attention(
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
@@ -189,56 +262,57 @@ def test_contexted_kv_attention(

scale = float(1.0 / (head_size**0.5))

attn_op = xops.fmha.cutlass.FwOp()
# Reshape for SDPA: (seq_len, num_heads, head_size) ->
# (1, num_heads, seq_len, head_size)
query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, num_tokens, head_size
)

if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(
query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
)
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# Expand key and value for GQA/MQA to match query heads
key_sdpa = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)

attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens
value_sdpa = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)

attn_mask = create_causal_attention_mask_for_sdpa(
query_lens, seq_lens, sliding_window, device=device, dtype=dtype
)

output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")

# Reshape output back to (num_tokens, num_heads, head_size)
output_ref = output_ref.view(num_heads, num_tokens, head_size)
output_ref = output_ref.permute(1, 0, 2).contiguous()
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)

@@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)

if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

current_platform.seed_everything(0)
torch.set_default_device(device)

@@ -331,16 +412,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
@@ -423,78 +504,75 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))

# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat(
[
torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...],
],
dim=0,
)
seq_start += seq_len
query_start += query_len
query = query_pad

if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)

attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
# Prepare query, key, value for SDPA
# Expand key and value for GQA/MQA to match query heads
key_expanded = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value_expanded = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)

output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0

torch.cuda.synchronize()
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/v1/attention/backends/xformers.py#L343

query_start = 0
key_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(
query[:, seq_start:seq_end],
key[:, seq_start:seq_end],
value[:, seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale,
key_end = key_start + seq_len

# Get query, key, value for this sequence
q = query[query_start:query_end] # [query_len, num_heads, head_size]
k = key_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
v = value_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]

# Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
q_sdpa = (
q_sdpa.permute(1, 2, 0, 3)
.reshape(1, num_heads, query_len, head_size)
.contiguous()
)

k_sdpa = (
k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
v_sdpa = (
v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size

# Create ALiBi causal mask for this sequence using utility function
alibi_mask = create_alibi_causal_mask(
query_len, seq_len, alibi_slopes, device, dtype
)

# Compute attention
out = F.scaled_dot_product_attention(
q_sdpa,
k_sdpa,
v_sdpa,
attn_mask=alibi_mask,
dropout_p=0.0,
scale=scale,
)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
seq_start += seq_len
query_start += query_len

# Reshape output back to [query_len, num_heads, head_size]
out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
output_ref[query_start:query_end].copy_(out)

query_start = query_end
key_start = key_end

torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
