Skip to content

Commit e6a07f6

Browse files
mgoin and amd-xiaoyu12
authored and committed
[CI Perf] Prune tests in tests/kernels/quantization/ (vllm-project#22942)
Signed-off-by: mgoin <[email protected]> Signed-off-by: Xiao Yu <[email protected]>
1 parent 73cd18f commit e6a07f6

File tree

6 files changed

+66
-33
lines changed

6 files changed

+66
-33
lines changed

tests/kernels/quantization/test_fp8_quant.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
from tests.kernels.utils import opcheck
1212
from vllm.platforms import current_platform
1313

14-
DTYPES = [torch.half, torch.bfloat16, torch.float]
15-
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
16-
8193] # Arbitrary values for testing
17-
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
18-
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
14+
DTYPES = [torch.bfloat16, torch.float]
15+
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
16+
NUM_TOKENS = [1, 7, 4096]
1917
SCALE_UBS = [True, False]
2018
SEEDS = [0]
2119

tests/kernels/quantization/test_int8_quant.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
from vllm._custom_ops import scaled_int8_quant
1010
from vllm.platforms import current_platform
1111

12-
DTYPES = [torch.half, torch.bfloat16, torch.float]
13-
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
14-
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
15-
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
12+
DTYPES = [torch.bfloat16, torch.float]
13+
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
14+
NUM_TOKENS = [1, 7, 4096]
1615
SEEDS = [0]
1716
SCALE = [0.1, 2.1]
1817

tests/kernels/quantization/test_machete_mm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,13 @@
3434

3535
MNK_SHAPES = [
3636
(1, 128, 128),
37-
(1, 512, 1024),
38-
(1, 4096, 4096),
3937
(1, 8192, 28672),
4038
(13, 8192, 4096),
4139
(26, 4096, 8192),
4240
(64, 4096, 4096),
4341
(64, 8192, 28672),
4442
(257, 128, 4096),
4543
(257, 4224, 4160),
46-
(257, 4096, 4096),
47-
(1024, 4096, 8192),
4844
(1024, 8192, 4096),
4945
]
5046

tests/kernels/quantization/test_marlin_gemm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,8 @@
5353
MNK_FACTORS = [
5454
(1, 1, 1),
5555
(1, 4, 8),
56-
(1, 7, 5),
57-
(13, 17, 67),
5856
(26, 37, 13),
59-
(67, 13, 11),
6057
(257, 13, 11),
61-
(658, 13, 11),
6258
]
6359

6460
DTYPES = [torch.float16, torch.bfloat16]

tests/kernels/quantization/test_rocm_skinny_gemms.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,55 @@
88
from vllm.platforms import current_platform
99

1010
DTYPES = [torch.bfloat16, torch.float16]
11-
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
12-
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
13-
N = [1, 2, 3, 4]
11+
# Specific (N, K, M) combinations for targeted testing
12+
NKM_FACTORS_LLMM1 = [
13+
# Small, medium, large cases
14+
(1, 8, 16),
15+
(1, 32, 64),
16+
(1, 128, 256),
17+
(1, 512, 1024),
18+
(1, 2048, 4096),
19+
# Edge cases with specific K sizes
20+
(1, 6144, 1024),
21+
(1, 8192, 2048),
22+
# Very large case
23+
(1, 4096, 8192),
24+
]
25+
26+
NKM_FACTORS_WVSPLITK = [
27+
# Different batch sizes with key dimensions
28+
(1, 16, 16),
29+
(1, 64, 64),
30+
(2, 256, 256),
31+
(3, 1024, 1024),
32+
(4, 4096, 4096),
33+
# Extended K values
34+
(1, 9216, 512),
35+
(2, 10240, 1024),
36+
(4, 16384, 8192),
37+
# Minimum M constraint validation (m >= 8)
38+
(1, 64, 8),
39+
(2, 128, 8),
40+
(4, 256, 8),
41+
]
42+
43+
NKM_FACTORS_WVSPLITK_FP8 = [
44+
# FP8-specific cases with K % 16 == 0
45+
(1, 16, 16),
46+
(1, 64, 64),
47+
(2, 512, 512),
48+
(3, 2048, 2048),
49+
(4, 4096, 4096),
50+
# Extended FP8 dimensions not covered by WVSPLITK
51+
(1, 14336, 1024),
52+
(2, 24576, 2048),
53+
(4, 32768, 28672),
54+
]
55+
1456
SEEDS = [0]
1557

1658

17-
@pytest.mark.parametrize("n", [1]) # only test for batch size 1
18-
@pytest.mark.parametrize("k", K)
19-
@pytest.mark.parametrize("m", M)
59+
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
2060
@pytest.mark.parametrize("dtype", DTYPES)
2161
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
2262
@pytest.mark.parametrize("seed", SEEDS)
@@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
3474
assert torch.allclose(out, ref_out, rtol=0.01)
3575

3676

37-
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
38-
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
39-
@pytest.mark.parametrize("m", [8] + M) # m >= 8
77+
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
4078
@pytest.mark.parametrize("dtype", DTYPES)
4179
@pytest.mark.parametrize("seed", SEEDS)
4280
@pytest.mark.skipif(not current_platform.is_rocm(),
@@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
5492
assert torch.allclose(out, ref_out, rtol=0.01)
5593

5694

57-
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
58-
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
59-
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
95+
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
6096
@pytest.mark.parametrize("dtype", DTYPES)
6197
@pytest.mark.parametrize("seed", SEEDS)
6298
@pytest.mark.skipif(

tests/kernels/quantization/test_triton_scaled_mm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
6060
num_logprobs)
6161

6262

63-
@pytest.mark.parametrize("M", [1, 33, 64, 512])
64-
@pytest.mark.parametrize("N", [256, 971, 20486])
65-
@pytest.mark.parametrize("K", [128, 496, 1024])
66-
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
63+
MNK_FACTORS = [
64+
(1, 256, 128),
65+
(33, 256, 496),
66+
(64, 971, 1024),
67+
(64, 20486, 128),
68+
(512, 256, 496),
69+
(512, 20486, 1024),
70+
]
71+
72+
73+
@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
74+
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
6775
@pytest.mark.parametrize("in_dtype", get_8bit_types())
6876
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
6977
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])

0 commit comments

Comments (0)