Commit f8856f7

mgoin and amd-xiaoyu12 authored and committed
[CI Perf] Prune tests in tests/kernels/attention/ (vllm-project#22936)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 70e2815 commit f8856f7

File tree

8 files changed: +39 −38 lines

tests/kernels/attention/test_aiter_flash_attn.py

Lines changed: 3 additions & 3 deletions
@@ -9,10 +9,10 @@
 import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
 from vllm.platforms import current_platform

-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+BLOCK_SIZES = [16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
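Across these files the pruning works the same way: pytest generates one test case for every combination of the stacked parametrize lists, so dropping values from several dimensions shrinks the suite multiplicatively rather than linearly. A minimal sketch of that arithmetic, using the before/after lists from this file (the counts cover only these four dimensions and are illustrative, not CI timings):

from itertools import product

# Parametrize dimensions before and after this commit (test_aiter_flash_attn.py).
# Dtypes are given as strings here so the sketch runs without torch installed.
before = {
    "num_heads": [(4, 4), (8, 2), (16, 2)],
    "head_size": [128, 256],
    "block_size": [16, 32],
    "dtype": ["float16", "bfloat16"],
}
after = {
    "num_heads": [(4, 4), (8, 2)],
    "head_size": [128, 256],
    "block_size": [16],
    "dtype": ["bfloat16"],
}

# pytest builds one case per combination, so the case count is the product
# of the list lengths.
print(len(list(product(*before.values()))))  # 3 * 2 * 2 * 2 = 24
print(len(list(product(*after.values()))))   # 2 * 2 * 1 * 1 = 4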

tests/kernels/attention/test_attention.py

Lines changed: 2 additions & 5 deletions
@@ -29,17 +29,14 @@
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 PARTITION_SIZE_ROCM = 256
-# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
-DTYPES = [
-    torch.half, torch.bfloat16, torch.float
-] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

 # This should be sync with get_supported_head_sizes() in
 # vllm.attention.ops.paged_attn.PagedAttention
-HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
+HEAD_SIZES = [32, 80, 128, 256]

 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]

tests/kernels/attention/test_cache.py

Lines changed: 2 additions & 2 deletions
@@ -11,11 +11,11 @@
 from vllm.platforms import current_platform

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
-DTYPES = [torch.half, torch.bfloat16, torch.float]
+DTYPES = [torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 120, 256]
+HEAD_SIZES = [64, 80, 256]
 BLOCK_SIZES = [8, 16, 32]
 CACHE_LAYOUTS = ["NHD", "HND"]

tests/kernels/attention/test_flash_attn.py

Lines changed: 9 additions & 7 deletions
@@ -12,14 +12,16 @@
     flash_attn_with_kvcache,
     is_fa_version_supported)

-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+BLOCK_SIZES = [16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
+SOFT_CAPS = [None, 50.0]
+SLIDING_WINDOWS = [None, 256]


 def ref_paged_attn(
@@ -83,9 +85,9 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @pytest.mark.parametrize("fa_version", [2, 3])
 @pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
@@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("fa_version", [2, 3])
 @pytest.mark.parametrize("q_dtype", QDTYPES)

tests/kernels/attention/test_flashinfer.py

Lines changed: 13 additions & 11 deletions
@@ -9,11 +9,13 @@

 from vllm.platforms import current_platform

-NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
+NUM_HEADS = [(32, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
+SOFT_CAPS = [None, 30.0]
+SLIDING_WINDOWS = [None, 64]


 def ref_paged_attn(
@@ -76,8 +78,8 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
-@pytest.mark.parametrize("sliding_window", [None, 64])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_kv(
     kv_lens: list[int],
@@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
-@pytest.mark.parametrize("sliding_window", [None, 64])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @torch.inference_mode
 def test_flashinfer_prefill_with_paged_kv(
     seq_lens: list[tuple[int, int]],
@@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv(


 @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
-@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 def test_flashinfer_prefill_with_paged_fp8_kv(
     seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
     head_size: int, dtype: torch.dtype, block_size: int,
@@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(


 @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
-@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.skip(reason="TODO: fix the accuracy issue")
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_fp8_kv(
     kv_lens: list[int],
@@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     block_size: int,
     soft_cap: Optional[float],
 ) -> None:
-    pytest.skip("TODO: fix the accuracy issue")
     # test doesn't work for num_heads = (16,16)
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
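Note the skip change in the FP8 decode test: the imperative pytest.skip(...) call inside the body is replaced by a @pytest.mark.skip decorator. Both leave the test skipped, but the decorator marks every parametrized case as skipped at collection time, so pytest never sets up fixtures or enters the body. A hedged side-by-side illustration (toy tests, not the vLLM ones):

import pytest


@pytest.mark.parametrize("block_size", [16, 32])
def test_skipped_at_runtime(block_size):
    # Old style: the test function is entered, and any fixtures or setup
    # still run for every parametrization before the skip fires.
    pytest.skip("TODO: fix the accuracy issue")
    assert block_size > 0


@pytest.mark.skip(reason="TODO: fix the accuracy issue")
@pytest.mark.parametrize("block_size", [16, 32])
def test_skipped_at_collection(block_size):
    # New style: marked as skipped up front; the body is never executed.
    assert block_size > 0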

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 3 additions & 3 deletions
@@ -20,11 +20,11 @@
 MAX_Q_LEN = 1024
 MAX_KV_LEN = 4096
 BATCH_SIZES = [4, 12]
-NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
+NUM_HEADS = [(16, 16), (40, 8)]
 HEAD_SIZES = [128]
-BLOCK_SIZES = [16, 32]
+BLOCK_SIZES = [16]
 KV_LAYOUTS = ["HND"]
-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 SOFT_CAPS = [None, 50.0]

tests/kernels/attention/test_prefix_prefill.py

Lines changed: 3 additions & 3 deletions
@@ -19,13 +19,13 @@
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

 NUM_HEADS = [64]
-NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96, 24]
+NUM_QUERIES_PER_KV = [1, 64]
+HEAD_SIZES = [24, 128]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
-SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
+SLIDING_WINDOW = [0, 16, 2048]
 KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

 OPS = [chunked_prefill_paged_decode, context_attention_fwd]

tests/kernels/attention/test_triton_unified_attention.py

Lines changed: 4 additions & 4 deletions
@@ -9,11 +9,11 @@
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.platforms import current_platform

-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
+BLOCK_SIZES = [16]

-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
     None, torch.float8_e4m3fnuz
 ]
@@ -85,7 +85,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", [None, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
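A quick way to check a prune like this locally is to collect (not run) an affected module and compare case counts before and after the change. The module path is real; the resulting counts depend on your environment (for example, which flash-attn versions are supported), so treat this as a sketch:

import pytest

# Collect the pruned flash-attn tests without executing them; "-q" prints one
# line per collected case plus a summary count.
pytest.main(["--collect-only", "-q", "tests/kernels/attention/test_flash_attn.py"])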
