Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 0 additions & 322 deletions tests/test_metal_unified_attention.py

This file was deleted.

10 changes: 8 additions & 2 deletions tests/test_primitive_and_donation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def _make_cache_and_inputs(
"num_heads",
[(4, 4), (8, 2)],
)
@pytest.mark.parametrize("sliding_window", [-1, 128])
@pytest.mark.parametrize("num_blocks", [256])
def test_primitive_vs_reference_decode(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
num_blocks: int,
sliding_window: int,
) -> None:
"""paged_attention_primitive matches the pure-MLX reference (decode)."""
mx.random.seed(0)
Expand All @@ -119,7 +121,7 @@ def test_primitive_vs_reference_decode(
d["cu_seqlens_q"],
BLOCK_SIZE,
d["max_kv_len"],
-1, # sliding_window
sliding_window,
out,
)
mx.eval(out)
Expand All @@ -132,6 +134,7 @@ def test_primitive_vs_reference_decode(
kv_lens=d["kv_lens"],
block_tables=np.array(d["block_tables"]),
scale=d["scale"],
sliding_window=sliding_window if sliding_window >= 0 else None,
)
mx.eval(ref)

Expand All @@ -154,11 +157,13 @@ def test_primitive_vs_reference_decode(
"num_heads",
[(4, 4), (8, 2)],
)
@pytest.mark.parametrize("sliding_window", [-1, 128])
@pytest.mark.parametrize("num_blocks", [256])
def test_primitive_vs_reference_varlen(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
num_blocks: int,
sliding_window: int,
) -> None:
"""paged_attention_primitive matches reference for mixed prefill+decode."""
mx.random.seed(0)
Expand All @@ -179,7 +184,7 @@ def test_primitive_vs_reference_varlen(
d["cu_seqlens_q"],
BLOCK_SIZE,
d["max_kv_len"],
-1, # sliding_window
sliding_window,
out,
)
mx.eval(out)
Expand All @@ -192,6 +197,7 @@ def test_primitive_vs_reference_varlen(
kv_lens=d["kv_lens"],
block_tables=np.array(d["block_tables"]),
scale=d["scale"],
sliding_window=sliding_window if sliding_window >= 0 else None,
)
mx.eval(ref)

Expand Down
8 changes: 3 additions & 5 deletions tests/test_sliding_window_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
-> MetalPagedKVCache.sliding_window_per_layer

Kernel-level correctness (that a ``sliding_window`` value actually
masks out-of-window tokens) is separately validated by
``test_metal_unified_attn`` in ``test_metal_unified_attention.py``;
both the production ``paged_attention_primitive`` and the test helper
``metal_unified_attention`` dispatch to the same
``paged_attention_v2_online`` kernel (see ``paged_ops.cpp``).
masks out-of-window tokens) is validated via the production
``paged_attention_primitive`` path which dispatches
``paged_attention_v2_online`` (see ``paged_ops.cpp``).
"""

from __future__ import annotations
Expand Down
Loading
Loading