@@ -1,5 +1,5 @@
 import random
-from typing import Optional
+from typing import List, Optional
 
 from flash_attn.flash_attention import FlashAttention
 import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]
 
-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
         for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@ def test_attention() -> None:
 
 
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
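
Note: both reference paths in the hunks above call a ref_masked_attention helper that is defined earlier in this file and does not appear in the diff. The sketch below is only an illustration of what such a helper could look like, with shapes and the einsum layout inferred from the call sites (query/key/value of shape [num_tokens, num_heads, head_size] and an additive mask); it is not taken from the actual file.

def ref_masked_attention(
    query: torch.Tensor,   # [num_tokens, num_heads, head_size]
    key: torch.Tensor,     # [num_tokens, num_heads, head_size]
    value: torch.Tensor,   # [num_tokens, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,  # [num_tokens, num_tokens], additive
) -> torch.Tensor:
    # Scaled dot-product attention scores per head: [num_heads, q, k].
    attn = torch.einsum('qhd,khd->hqk', query * scale, key)
    if attn_mask is not None:
        attn = attn + attn_mask
    # Softmax in float32 for numerical stability, then cast back.
    attn = torch.softmax(attn.float(), dim=-1).to(value.dtype)
    # Weighted sum of values, back to [num_tokens, num_heads, head_size].
    return torch.einsum('hqk,khd->qhd', attn, value)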