Commit 0bf29fa

[Test] Remove old non-varlen FA2 test (#28420)
Signed-off-by: Matthew Bonanni <[email protected]>
Parent: a5a790e

1 file changed: 0 additions, 119 deletions

tests/kernels/attention/test_flash_attn.py

Lines changed: 0 additions & 119 deletions
@@ -9,7 +9,6 @@
 from vllm.vllm_flash_attn import (
     fa_version_unsupported_reason,
     flash_attn_varlen_func,
-    flash_attn_with_kvcache,
     is_fa_version_supported,
 )

@@ -83,124 +82,6 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)


-@pytest.mark.parametrize("use_out", [True, False])
-@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
-@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
-@pytest.mark.parametrize("fa_version", [2, 3])
-@pytest.mark.parametrize("q_dtype", QDTYPES)
-@torch.inference_mode()
-def test_flash_attn_with_paged_kv(
-    use_out: bool,
-    kv_lens: list[int],
-    num_heads: tuple[int, int],
-    head_size: int,
-    dtype: torch.dtype,
-    block_size: int,
-    soft_cap: float | None,
-    num_blocks: int,
-    sliding_window: int | None,
-    fa_version: int,
-    q_dtype: torch.dtype | None,
-) -> None:
-    torch.set_default_device("cuda")
-    if not is_fa_version_supported(fa_version):
-        pytest.skip(
-            f"Flash attention version {fa_version} not supported due "
-            f'to: "{fa_version_unsupported_reason(fa_version)}"'
-        )
-    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
-        pytest.skip(
-            "Flash attention with quantized inputs is only "
-            "supported on version 3 with bfloat16 base type"
-        )
-
-    current_platform.seed_everything(0)
-    num_seqs = len(kv_lens)
-    num_query_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_query_heads % num_kv_heads == 0
-    max_kv_len = max(kv_lens)
-    scale = head_size**-0.5
-    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
-
-    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
-    key_cache = torch.randn(
-        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
-    )
-    value_cache = torch.randn_like(key_cache)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
-
-    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
-    block_tables = torch.randint(
-        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
-    )
-
-    q = query.unsqueeze(1)
-    out = torch.empty_like(q) if use_out else None
-
-    maybe_quantized_query = q
-    maybe_quantized_key_cache = key_cache
-    maybe_quantized_value_cache = value_cache
-    q_descale = None
-    k_descale = None
-    v_descale = None
-    if q_dtype is not None:
-        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
-        maybe_quantized_query = q.to(q_dtype)
-        maybe_quantized_key_cache = key_cache.to(q_dtype)
-        maybe_quantized_value_cache = value_cache.to(q_dtype)
-
-        scale_shape = (num_seqs, num_kv_heads)
-        q_descale = torch.ones(scale_shape, dtype=torch.float32)
-        k_descale = torch.ones(scale_shape, dtype=torch.float32)
-        v_descale = torch.ones(scale_shape, dtype=torch.float32)
-
-    output = flash_attn_with_kvcache(
-        q=maybe_quantized_query,
-        k_cache=maybe_quantized_key_cache,
-        v_cache=maybe_quantized_value_cache,
-        out=out,
-        softmax_scale=scale,
-        causal=True,
-        block_table=block_tables,
-        cache_seqlens=kv_lens_tensor,
-        softcap=soft_cap if soft_cap is not None else 0,
-        window_size=window_size,
-        fa_version=fa_version,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-    )
-    output = output if not use_out else out
-    output = output.squeeze(1)
-
-    atol, rtol = 1.5e-2, 1e-2
-    if q_dtype is not None:
-        atol, rtol = 1.5e-1, 1.5e-1
-
-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-        sliding_window=sliding_window,
-    )
-    (
-        torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
-        f"{torch.max(torch.abs(output - ref_output))}",
-    )
-
-
 @pytest.mark.parametrize("use_out", [True, False])
 @pytest.mark.parametrize(
     "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
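For context, the deleted test exercised the non-varlen decode path: one new query token per sequence attending to a paged KV cache via flash_attn_with_kvcache, checked against ref_paged_attn with query_lens=[1] * num_seqs. The sketch below is a minimal, self-contained illustration of that reference computation in plain PyTorch (no sliding window, softcap, or FP8 descale handling); it is not code from this repository, only the tensor layouts follow the removed test.

import torch


def decode_paged_attention_reference(
    query: torch.Tensor,        # [num_seqs, num_query_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    kv_lens: list[int],         # valid KV length per sequence
    block_tables: torch.Tensor, # [num_seqs, max_blocks_per_seq] block indices
    scale: float,               # typically head_size ** -0.5
) -> torch.Tensor:
    num_seqs, num_query_heads, head_size = query.shape
    block_size, num_kv_heads = key_cache.shape[1], key_cache.shape[2]
    group = num_query_heads // num_kv_heads  # grouped-query attention factor
    outputs = []
    for i in range(num_seqs):
        kv_len = kv_lens[i]
        blocks_needed = (kv_len + block_size - 1) // block_size
        block_ids = block_tables[i, :blocks_needed].long()
        # Gather this sequence's KV from the paged cache and drop block padding.
        k = key_cache[block_ids].reshape(-1, num_kv_heads, head_size)[:kv_len]
        v = value_cache[block_ids].reshape(-1, num_kv_heads, head_size)[:kv_len]
        # Expand KV heads so each query head has a matching KV head.
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        q = query[i]  # single decode token: [num_query_heads, head_size]
        # Causal masking is a no-op here: the lone query token is the last
        # position, so it attends to the full kv_len prefix.
        logits = torch.einsum("hd,khd->hk", q.float(), k.float()) * scale
        probs = logits.softmax(dim=-1)
        out = torch.einsum("hk,khd->hd", probs, v.float())
        outputs.append(out.to(query.dtype))
    return torch.stack(outputs, dim=0)  # [num_seqs, num_query_heads, head_size]

Coverage of this single-token decode case appears to remain in the varlen tests kept in this file: the retained parametrization of flash_attn_varlen_func includes seq_lens entries such as (1, 1328), where a query length of 1 corresponds to the same decode scenario.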