
Commit a1b3de8

Refactor the test code for attention kernels (#13)
1 parent 64e0e38 commit a1b3de8

File tree: 1 file changed (+53, −19 lines)


tests/kernels/attention.py

Lines changed: 53 additions & 19 deletions
@@ -1,5 +1,5 @@
 import random
-from typing import Optional
+from typing import List, Optional

 from flash_attn.flash_attention import FlashAttention
 import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]

-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
         for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@ def test_attention() -> None:


 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
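
For readers skimming the diff: the new ref_multi_query_kv_attention helper walks cumulative sequence lengths (cu_seq_lens has num_seqs + 1 entries, so cu_seq_lens[i]:cu_seq_lens[i + 1] selects the i-th sequence's tokens in the packed tensors) and applies a causal mask built with torch.triu to each slice. The standalone sketch below only illustrates that indexing and masking; the small shapes, the CPU device, and the inline scaled-dot-product reference are assumptions for illustration and are not the commit's ref_masked_attention, which the test runs on CUDA.

# Illustrative sketch only (not part of the commit): packed variable-length
# sequences indexed by cumulative sequence lengths, with a causal mask applied
# per sequence. The plain scaled-dot-product below stands in for
# ref_masked_attention.
import torch

seq_lens = [3, 5]                    # two packed sequences (assumed sizes)
cu_seq_lens = [0, 3, 8]              # [0] + running sums of seq_lens
num_heads, head_size = 2, 4          # small assumed dimensions
total_tokens = sum(seq_lens)

query = torch.randn(total_tokens, num_heads, head_size)
key = torch.randn(total_tokens, num_heads, head_size)
value = torch.randn(total_tokens, num_heads, head_size)
scale = 1.0 / (head_size ** 0.5)

outputs = []
for i in range(len(cu_seq_lens) - 1):
    start_idx, end_idx = cu_seq_lens[i], cu_seq_lens[i + 1]
    seq_len = end_idx - start_idx
    # Causal mask: large negative values above the diagonal, as in the commit.
    attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
    q = query[start_idx:end_idx]     # [seq_len, num_heads, head_size]
    k = key[start_idx:end_idx]
    v = value[start_idx:end_idx]
    # Masked scaled-dot-product attention for this sequence only.
    attn = torch.einsum('qhd,khd->hqk', q, k) * scale + attn_mask
    attn = torch.softmax(attn, dim=-1)
    outputs.append(torch.einsum('hqk,khd->qhd', attn, v))

ref_output = torch.cat(outputs, dim=0)   # [total_tokens, num_heads, head_size]
print(ref_output.shape)                  # torch.Size([8, 2, 4])

In the test itself, cu_seq_lens arrives as a tensor, which is why the commit converts it with cu_seq_lens.cpu().tolist() before calling the list-based helper.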
