88from vllm .platforms import current_platform
99
1010DTYPES = [torch .bfloat16 , torch .float16 ]
11- M = [16 , 32 , 64 , 128 , 256 , 512 , 1024 , 4096 , 8192 ]
12- K = [8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 , 6144 , 8192 ] # k % 8 == 0
13- N = [1 , 2 , 3 , 4 ]
11+ # Specific (N, K, M) combinations for targeted testing
12+ NKM_FACTORS_LLMM1 = [
13+ # (n, k, m) tuples, all with n == 1: small, medium, large cases (each with m == 2 * k)
14+ (1 , 8 , 16 ),
15+ (1 , 32 , 64 ),
16+ (1 , 128 , 256 ),
17+ (1 , 512 , 1024 ),
18+ (1 , 2048 , 4096 ),
19+ # Edge cases for K: a non-power-of-two size (6144) and the largest K in this list (8192)
20+ (1 , 6144 , 1024 ),
21+ (1 , 8192 , 2048 ),
22+ # Largest M in this list (8192)
23+ (1 , 4096 , 8192 ),
24+ ]
25+
26+ NKM_FACTORS_WVSPLITK = [
27+ # (n, k, m) tuples: batch sizes n = 1..4 paired with square shapes (k == m)
28+ (1 , 16 , 16 ),
29+ (1 , 64 , 64 ),
30+ (2 , 256 , 256 ),
31+ (3 , 1024 , 1024 ),
32+ (4 , 4096 , 4096 ),
33+ # Extended K values (9216, 10240, 16384), larger than any K in the group above
34+ (1 , 9216 , 512 ),
35+ (2 , 10240 , 1024 ),
36+ (4 , 16384 , 8192 ),
37+ # Minimum M constraint validation (m >= 8): the smallest allowed m at n = 1, 2, 4
38+ (1 , 64 , 8 ),
39+ (2 , 128 , 8 ),
40+ (4 , 256 , 8 ),
41+ ]
42+
43+ NKM_FACTORS_WVSPLITK_FP8 = [
44+ # (n, k, m) tuples for FP8; every K in this list satisfies K % 16 == 0
45+ (1 , 16 , 16 ),
46+ (1 , 64 , 64 ),
47+ (2 , 512 , 512 ),
48+ (3 , 2048 , 2048 ),
49+ (4 , 4096 , 4096 ),
50+ # Extended FP8 dimensions: K (14336, 24576, 32768) and M (28672) values absent from NKM_FACTORS_WVSPLITK
51+ (1 , 14336 , 1024 ),
52+ (2 , 24576 , 2048 ),
53+ (4 , 32768 , 28672 ),
54+ ]
55+
1456SEEDS = [0 ]
1557
1658
17- @pytest .mark .parametrize ("n" , [1 ]) # only test for batch size 1
18- @pytest .mark .parametrize ("k" , K )
19- @pytest .mark .parametrize ("m" , M )
59+ @pytest .mark .parametrize ("n,k,m" , NKM_FACTORS_LLMM1 )
2060@pytest .mark .parametrize ("dtype" , DTYPES )
2161@pytest .mark .parametrize ("rows_per_block" , [2 , 4 , 8 , 16 ])
2262@pytest .mark .parametrize ("seed" , SEEDS )
@@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
3474 assert torch .allclose (out , ref_out , rtol = 0.01 )
3575
3676
37- @pytest .mark .parametrize ("n" , N ) # only test for batch size <= 4
38- @pytest .mark .parametrize ("k" , K + [9216 , 10240 , 16384 ])
39- @pytest .mark .parametrize ("m" , [8 ] + M ) # m >= 8
77+ @pytest .mark .parametrize ("n,k,m" , NKM_FACTORS_WVSPLITK )
4078@pytest .mark .parametrize ("dtype" , DTYPES )
4179@pytest .mark .parametrize ("seed" , SEEDS )
4280@pytest .mark .skipif (not current_platform .is_rocm (),
@@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
5492 assert torch .allclose (out , ref_out , rtol = 0.01 )
5593
5694
57- @pytest .mark .parametrize ("n" , N ) # only test for batch size <= 4
58- @pytest .mark .parametrize ("k" , K [1 :] + [14336 , 24576 , 32768 ]) # k % 16 == 0
59- @pytest .mark .parametrize ("m" , M + [28672 ]) # m >= 16
95+ @pytest .mark .parametrize ("n,k,m" , NKM_FACTORS_WVSPLITK_FP8 )
6096@pytest .mark .parametrize ("dtype" , DTYPES )
6197@pytest .mark .parametrize ("seed" , SEEDS )
6298@pytest .mark .skipif (
0 commit comments