@@ -72,6 +72,24 @@ def run_bench(
     measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
     ms_old = np.median(measurements)
 
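+    # Persistent BatchAttention: handle the mixed prefill + decode batch with a
+    # single wrapper, planned once over all requests before timing.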
+    wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD")
+    wrapper_persistent.plan(
+        q_indptr.to(device),
+        kv_indptr.to(device),
+        torch.arange(num_blocks, dtype=torch.int32, device=device),
+        seq_lens.to(device),
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        head_dim,
+        page_block_size,
+        causal=causal,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+    )
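+    # Warm-up run before timing; run() returns a pair, of which only the
+    # attention output is kept here.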
+    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
+    ms_persistent = np.median(measurements_persistent)
     if len(p_kv_lens) == 1:
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
@@ -123,9 +141,46 @@ def run_bench(
             )
         )
         ms_pod = np.median(measurements)
+
+        # Sequential two kernels: single prefill + batch decode (tensor cores)
+        # Prefill using single_prefill_with_kv_cache
+        def _run_single_prefill():
+            return flashinfer.prefill.single_prefill_with_kv_cache(
+                q_p,
+                k_p,
+                v_p,
+                causal=causal,
+                pos_encoding_mode="NONE",
+                backend="fa2",
+            )
+
+        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
+        ms_prefill = np.median(measurements_prefill)
+
+        # Batch decode using tensor cores
+        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
+        )
+        wrapper_decode.plan(
+            d_kv_indptr.to(device),
+            kv_indices_d.to(device),
+            last_page_len_d,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_block_size,
+            data_type=torch.bfloat16,
+            q_data_type=torch.bfloat16,
+        )
+        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
+        ms_decode = np.median(measurements_decode)
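+        # Naive schedule: prefill and decode launched back-to-back, so the total
+        # latency is the sum of the two measured kernel times.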
+        ms_seq_two_kernels = ms_prefill + ms_decode
+
     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
     if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
+        print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
+    print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
     )
@@ -137,77 +192,29 @@ def run_bench(
     if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
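+        # Same bandwidth estimate for the remaining paths: total Q + KV bytes moved,
+        # divided by elapsed time, using 1024**3 bytes per GB as in the print above.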
+        bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3)
+        print(
+            f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s"
+        )
+    bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
+    print(
+        f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s"
+    )
 
 
 if __name__ == "__main__":
     np.random.seed(42)
     torch.random.manual_seed(42)
 
     # Irregular sequence lengths for prefill and decode
-    d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
-    d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
-    p_q_configs = [[17] * 1, [10000], [17] * 1, []]
-    p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]
-
-    # construct random length testcases
-    for _ in range(1):
-        bsz = 256
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(1000, 8192, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
-
-    for _ in range(1):
-        bsz = 128
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(2000, 16000, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
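+    # Four fixed workloads: 128 decode requests each, paired with one prefill request.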
+    d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
+    d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
+    p_q_configs = [[2048], [4096], [4096], [6000]]
+    p_kv_configs = [[2048], [4096], [4096], [7000]]
 
     page_block_size = 1
-    num_kv_heads = 4
-    num_qo_heads = 28
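+    # 32 query heads sharing 8 KV heads (GQA group size 4).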
+    num_kv_heads = 8
+    num_qo_heads = 32
     head_dim = 128
 
     for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate(