99
1010from vllm .platforms import current_platform
1111
12- NUM_HEADS = [(16 , 16 ), ( 32 , 8 ), ( 64 , 8 ), (6 , 1 )]
12+ NUM_HEADS = [(32 , 8 ), (6 , 1 )]
1313HEAD_SIZES = [128 , 256 ]
1414BLOCK_SIZES = [16 , 32 ]
15- DTYPES = [torch .float16 , torch . bfloat16 ]
15+ DTYPES = [torch .bfloat16 ]
1616NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
17+ SOFT_CAPS = [None , 30.0 ]
18+ SLIDING_WINDOWS = [None , 64 ]
1719
1820
1921def ref_paged_attn (
@@ -76,8 +78,8 @@ def ref_paged_attn(
7678@pytest .mark .parametrize ("head_size" , HEAD_SIZES )
7779@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
7880@pytest .mark .parametrize ("dtype" , DTYPES )
79- @pytest .mark .parametrize ("soft_cap" , [ None , 30.0 , 50.0 ] )
80- @pytest .mark .parametrize ("sliding_window" , [ None , 64 ] )
81+ @pytest .mark .parametrize ("soft_cap" , SOFT_CAPS )
82+ @pytest .mark .parametrize ("sliding_window" , SLIDING_WINDOWS )
8183@torch .inference_mode
8284def test_flashinfer_decode_with_paged_kv (
8385 kv_lens : list [int ],
@@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv(
173175@pytest .mark .parametrize ("head_size" , HEAD_SIZES )
174176@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
175177@pytest .mark .parametrize ("dtype" , DTYPES )
176- @pytest .mark .parametrize ("soft_cap" , [ None , 30.0 , 50.0 ] )
177- @pytest .mark .parametrize ("sliding_window" , [ None , 64 ] )
178+ @pytest .mark .parametrize ("soft_cap" , SOFT_CAPS )
179+ @pytest .mark .parametrize ("sliding_window" , SLIDING_WINDOWS )
178180@torch .inference_mode
179181def test_flashinfer_prefill_with_paged_kv (
180182 seq_lens : list [tuple [int , int ]],
@@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv(
278280
279281
280282@pytest .mark .parametrize ("seq_lens" , [[(1 , 132 ), (5 , 18 )]])
281- @pytest .mark .parametrize ("num_heads" , [( 32 , 8 ), ( 6 , 1 )] )
283+ @pytest .mark .parametrize ("num_heads" , NUM_HEADS )
282284@pytest .mark .parametrize ("head_size" , HEAD_SIZES )
283285@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
284286@pytest .mark .parametrize ("dtype" , DTYPES )
285- @pytest .mark .parametrize ("soft_cap" , [ None , 30.0 , 50.0 ] )
287+ @pytest .mark .parametrize ("soft_cap" , SOFT_CAPS )
286288def test_flashinfer_prefill_with_paged_fp8_kv (
287289 seq_lens : list [tuple [int , int ]], num_heads : tuple [int , int ],
288290 head_size : int , dtype : torch .dtype , block_size : int ,
@@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
385387
386388
387389@pytest .mark .parametrize ("kv_lens" , [[1328 , 18 , 463 ], [1 , 54 , 293 , 70 ]])
388- @pytest .mark .parametrize ("num_heads" , [( 32 , 8 ), ( 64 , 8 ), ( 6 , 1 )] )
390+ @pytest .mark .parametrize ("num_heads" , NUM_HEADS )
389391@pytest .mark .parametrize ("head_size" , HEAD_SIZES )
390392@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
391393@pytest .mark .parametrize ("dtype" , DTYPES )
392- @pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
394+ @pytest .mark .parametrize ("soft_cap" , SOFT_CAPS )
395+ @pytest .mark .skip (reason = "TODO: fix the accuracy issue" )
393396@torch .inference_mode
394397def test_flashinfer_decode_with_paged_fp8_kv (
395398 kv_lens : list [int ],
@@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv(
399402 block_size : int ,
400403 soft_cap : Optional [float ],
401404) -> None :
402- pytest .skip ("TODO: fix the accuracy issue" )
403405 # test doesn't work for num_heads = (16,16)
404406 torch .set_default_device ("cuda" )
405407 current_platform .seed_everything (0 )
0 commit comments