@@ -207,7 +207,7 @@ def test_deepseek_prefill(
207207
208208@pytest .mark .parametrize ("batch_size" , [1 , 4 , 8 , 16 ])
209209@pytest .mark .parametrize ("seq_len" , [11 , 12 , 99 , 1763 , 9999 , 32767 ])
210- @pytest .mark .parametrize ("page_size" , [1 ]) # [1 , 16])
210+ @pytest .mark .parametrize ("page_size" , [1 , 16 ])
211211@pytest .mark .parametrize ("num_qo_heads" , [1 , 4 , 8 ])
212212@pytest .mark .parametrize ("num_kv_heads" , [1 , 4 , 8 ])
213213@pytest .mark .parametrize ("causal" , [False , True ])
@@ -267,8 +267,7 @@ def test_batch_paged_prefill(
267267 kv_indptr = torch .arange (
268268 0 , batch_size * num_pages_per_request + 1 , num_pages_per_request
269269 ).int ()
270- # NOTE(Zihao): pad 256 elements to avoid out-of-bound because we didn't check the boundary in the kernel
271- kv_indices = torch .arange (0 , batch_size * num_pages_per_request + 256 ).int ()
270+ kv_indices = torch .arange (0 , batch_size * num_pages_per_request ).int ()
272271 last_page_len = torch .full ((batch_size ,), last_page_len , dtype = torch .int32 )
273272
274273 wrapper_sm80 .plan (
0 commit comments