133 changes: 16 additions & 117 deletions tests/test_metal_unified_attention.py
@@ -3,7 +3,7 @@
#
# Adapted from vLLM's test_triton_unified_attention.py for Metal/MLX.
#
# Compares metal_unified_attention (the Metal kernel under development)
# Compares metal_unified_attention
# against ref_paged_attn (a naive pure-MLX loop implementation that is
# trivially correct). Both receive the same paged KV cache and query
# inputs; the test asserts their outputs match within FP tolerance.
@@ -12,6 +12,7 @@
import numpy as np
import pytest

from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention
from vllm_metal.metal import metal_unified_attention

# Original upstream parameters (vLLM Triton/CUDA test_triton_unified_attention.py):
@@ -30,84 +31,21 @@


# ---------------------------------------------------------------------------
# Pure-MLX reference implementation
# Shared reference / decode helpers
# ---------------------------------------------------------------------------


def ref_paged_attn(
query: mx.array,
key_cache: mx.array,
value_cache: mx.array,
query_lens: list[int],
kv_lens: list[int],
block_tables: np.ndarray,
scale: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> mx.array:
"""Pure-MLX reference: gather K/V from paged cache, compute attention.

Processes each sequence independently with naive quadratic attention.
Supports GQA (num_q_heads != num_kv_heads), sliding window, and soft cap.
"""
num_seqs = len(query_lens)
_, block_size, num_kv_heads, head_size = key_cache.shape

outputs: list[mx.array] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx : start_idx + query_len]
q = q * scale

num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = mx.array(block_tables[i, :num_kv_blocks])

k = key_cache[block_indices].reshape(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].reshape(-1, num_kv_heads, head_size)
v = v[:kv_len]

# GQA: expand kv heads to match query heads
if q.shape[1] != k.shape[1]:
n_rep = q.shape[1] // k.shape[1]
k = mx.repeat(k, n_rep, axis=1)
v = mx.repeat(v, n_rep, axis=1)

attn = mx.einsum("qhd,khd->hqk", q, k).astype(mx.float32)

# Causal mask: True where attention should be masked out
empty_mask = mx.ones((query_len, kv_len))
mask = mx.triu(empty_mask, k=kv_len - query_len + 1).astype(mx.bool_)

if sliding_window is not None:
sliding_window_mask = mx.logical_not(
mx.triu(empty_mask, k=kv_len - (query_len + sliding_window) + 1).astype(
mx.bool_
)
)
mask = mx.logical_or(mask, sliding_window_mask)

if soft_cap is not None and soft_cap > 0:
attn = soft_cap * mx.tanh(attn / soft_cap)

attn = mx.where(mask, float("-inf"), attn)
attn = mx.softmax(attn, axis=-1).astype(v.dtype)
out = mx.einsum("hqk,khd->qhd", attn, v)

outputs.append(out)
start_idx += query_len

return mx.concatenate(outputs, axis=0)
#
# The pure-MLX textbook attention reference and the kernel_v1 wrapper are
# implemented in a shared module so the correctness tests and the benchmark
# script exercise the same logic.
#


# ---------------------------------------------------------------------------
# Triangle edge: v1 == ref (decode-only)
#
# Validates that the v1 kernel and the pure-MLX reference produce the same
# results for decode-only inputs. This test runs TODAY (no v2 needed) and
# also validates ref_paged_attn itself, so we can trust it as ground truth.
# results for decode-only inputs. This test also validates ref_paged_attn
# itself, so we can trust it as ground truth.
# ---------------------------------------------------------------------------


@@ -161,7 +99,7 @@ def test_v1_kernel_vs_reference(
0, num_blocks, shape=(num_seqs, max_num_blocks_per_seq)
).astype(mx.int32)

v1_output = _run_v1_paged_attention(
v1_output = run_v1_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
@@ -195,49 +133,10 @@ def test_v1_kernel_vs_reference(
#
# Freezes parameters to the subset that the existing paged_attention_v1
# already handles: every sequence has q_len=1, no sliding window, no
# soft_cap. Compares the v2 kernel output against the v1 kernel output
# to prove v2 is a drop-in replacement for decode. Get this green first
# when building the v2 kernel, then graduate to the full varlen test below.
#
# DELETE this test once test_metal_unified_attn passes.
# soft_cap. Compares the v2 kernel output against the v1 kernel output
# to prove v2 is a drop-in replacement for decode. This remains useful as
# a focused decode-only regression test alongside the full varlen test below.
# ---------------------------------------------------------------------------


def _run_v1_paged_attention(
query: mx.array,
key_cache: mx.array,
value_cache: mx.array,
num_kv_heads: int,
scale: float,
block_tables: mx.array,
seq_lens: mx.array,
block_size: int,
max_seq_len: int,
) -> mx.array:
"""Run the existing v1 paged_attention kernel and return the output."""
from vllm_metal.metal import get_ops

ops = get_ops()

out = mx.zeros_like(query)
mx.eval(out, query, key_cache, value_cache, block_tables, seq_lens)

ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
)
mx.synchronize()
return out


@pytest.mark.parametrize(
"seq_lens",
[
Expand Down Expand Up @@ -290,7 +189,7 @@ def test_metal_unified_attn_decode_only(
).astype(mx.int32)

# --- v1 kernel output (known-correct, production code) ---
v1_output = _run_v1_paged_attention(
v1_output = run_v1_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
@@ -302,7 +201,7 @@
max_seq_len=max_kv_len,
)

# --- v2 kernel output (under development) ---
# --- v2 kernel output ---
v2_output = mx.zeros_like(query)

metal_unified_attention(
64 changes: 64 additions & 0 deletions tools/README.md
@@ -0,0 +1,64 @@
# Tools

## Attention Benchmark

The repository includes a local benchmark utility for comparing Metal attention backends:

```bash
source .venv-vllm-metal/bin/activate
python -m tools.benchmark.attention_benchmark
```

Running with no arguments executes the built-in `all` preset group and prints one combined text table to stdout.
By default, presets run `v1`, `v2`, `textbook`, and `sdpa`. Use `--backend all` when you also want `sdpa-compute-only`.
`num_layers` is one of the shared benchmark settings; multi-layer runs repeat the same workload across layers and report per-layer latency.

Built-in groups:
- `all`: every built-in case
- `decode`: all decode cases
- `varlen`: all varlen cases
- `small`: `decode-small` + `varlen-light`
- `typical`: `decode-typical` + `varlen-typical`
- `long`: `decode-big-head` + `decode-long` + `varlen-single-long` + `varlen-ragged-longtail`

Built-in cases:
- `decode-small`
- `decode-typical`
- `decode-big-head`
- `decode-long`
- `varlen-light`
- `varlen-typical`
- `varlen-single-long`
- `varlen-ragged-longtail`

Useful examples:

```bash
# Run the default all group
python -m tools.benchmark.attention_benchmark

# Run a built-in group
python -m tools.benchmark.attention_benchmark --group decode
python -m tools.benchmark.attention_benchmark --group varlen
python -m tools.benchmark.attention_benchmark --group typical
python -m tools.benchmark.attention_benchmark --group long

# Run explicit cases
python -m tools.benchmark.attention_benchmark --cases decode-small,varlen-light

# Include sdpa-compute-only in addition to the default backends
python -m tools.benchmark.attention_benchmark --group all --backend all

# Write structured exports in addition to the stdout table
python -m tools.benchmark.attention_benchmark --group decode --output-json /tmp/attention.json
python -m tools.benchmark.attention_benchmark --group decode --output-csv /tmp/attention.csv

# Override shared benchmark settings on a built-in preset run
python -m tools.benchmark.attention_benchmark --group decode --num-layers 10 --iters 200

# Define a manual workload
python -m tools.benchmark.attention_benchmark --mode decode --batch-size 8 --kv-lens 2048

# Define a manual varlen workload
python -m tools.benchmark.attention_benchmark --mode varlen --q-lens 1,4,16,64 --kv-lens 128,256,512,1024
```
95 changes: 95 additions & 0 deletions tools/attention_bench_utils.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
"""Shared helpers for attention correctness tests and benchmarks."""

from __future__ import annotations

import mlx.core as mx
import numpy as np


def ref_paged_attn(
query: mx.array,
key_cache: mx.array,
value_cache: mx.array,
query_lens: list[int],
kv_lens: list[int],
block_tables: np.ndarray,
scale: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> mx.array:
"""Pure-MLX reference: gather K/V from paged cache, compute attention."""
_, block_size, num_kv_heads, head_size = key_cache.shape

outputs: list[mx.array] = []
start_idx = 0
for i, query_len in enumerate(query_lens):
kv_len = kv_lens[i]
q = query[start_idx : start_idx + query_len] * scale

num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = mx.array(block_tables[i, :num_kv_blocks])

k = key_cache[block_indices].reshape(-1, num_kv_heads, head_size)[:kv_len]
v = value_cache[block_indices].reshape(-1, num_kv_heads, head_size)[:kv_len]

if q.shape[1] != k.shape[1]:
n_rep = q.shape[1] // k.shape[1]
k = mx.repeat(k, n_rep, axis=1)
v = mx.repeat(v, n_rep, axis=1)

attn = mx.einsum("qhd,khd->hqk", q, k).astype(mx.float32)

empty_mask = mx.ones((query_len, kv_len))
mask = mx.triu(empty_mask, k=kv_len - query_len + 1).astype(mx.bool_)

if sliding_window is not None:
sliding_window_mask = mx.logical_not(
mx.triu(empty_mask, k=kv_len - (query_len + sliding_window) + 1).astype(
mx.bool_
)
)
mask = mx.logical_or(mask, sliding_window_mask)

if soft_cap is not None and soft_cap > 0:
attn = soft_cap * mx.tanh(attn / soft_cap)

attn = mx.where(mask, float("-inf"), attn)
attn = mx.softmax(attn, axis=-1).astype(v.dtype)
outputs.append(mx.einsum("hqk,khd->qhd", attn, v))
start_idx += query_len

return mx.concatenate(outputs, axis=0)


def run_v1_paged_attention(
query: mx.array,
key_cache: mx.array,
value_cache: mx.array,
num_kv_heads: int,
scale: float,
block_tables: mx.array,
seq_lens: mx.array,
block_size: int,
max_seq_len: int,
) -> mx.array:
"""Run kernel_v1 paged attention."""
from vllm_metal.metal import get_ops

ops = get_ops()
out = mx.zeros_like(query)
mx.eval(out, query, key_cache, value_cache, block_tables, seq_lens)
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
)
mx.synchronize()
return out
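
For orientation, a minimal sketch of how the two shared helpers can be exercised together outside of pytest; the shapes, dtypes, and block-table setup below are illustrative assumptions, not values taken from the tests or the benchmark presets:

```python
# Sketch: compare the v1 kernel against the pure-MLX reference using the
# shared helpers from tools.attention_bench_utils (assumed decode-only shapes).
import mlx.core as mx
import numpy as np

from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention

num_seqs, num_q_heads, num_kv_heads, head_size = 4, 8, 4, 64
block_size, num_blocks, kv_len = 16, 128, 256
scale = head_size**-0.5

query_lens = [1] * num_seqs  # decode-only: one query token per sequence
kv_lens = [kv_len] * num_seqs

query = mx.random.normal((sum(query_lens), num_q_heads, head_size))
key_cache = mx.random.normal((num_blocks, block_size, num_kv_heads, head_size))
value_cache = mx.random.normal((num_blocks, block_size, num_kv_heads, head_size))

max_blocks_per_seq = (kv_len + block_size - 1) // block_size
block_tables = np.random.randint(0, num_blocks, size=(num_seqs, max_blocks_per_seq))

ref_out = ref_paged_attn(
    query, key_cache, value_cache, query_lens, kv_lens, block_tables, scale
)
v1_out = run_v1_paged_attention(
    query=query,
    key_cache=key_cache,
    value_cache=value_cache,
    num_kv_heads=num_kv_heads,
    scale=scale,
    block_tables=mx.array(block_tables).astype(mx.int32),
    seq_lens=mx.array(kv_lens).astype(mx.int32),
    block_size=block_size,
    max_seq_len=kv_len,
)
print("max abs diff:", mx.abs(ref_out - v1_out).max().item())
```

`ref_paged_attn` additionally accepts `sliding_window` and `soft_cap` arguments for the varlen cases; `run_v1_paged_attention` covers only the decode-style configurations the v1 kernel supports.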