diff --git a/tests/test_metal_unified_attention.py b/tests/test_metal_unified_attention.py deleted file mode 100644 index 8878903c..00000000 --- a/tests/test_metal_unified_attention.py +++ /dev/null @@ -1,322 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# -# Adapted from vLLM's test_triton_unified_attention.py for Metal/MLX. -# -# Compares metal_unified_attention -# against ref_paged_attn (a naive pure-MLX loop implementation that is -# trivially correct). Both receive the same paged KV cache and query -# inputs; the test asserts their outputs match within FP tolerance. - -import mlx.core as mx -import numpy as np -import pytest - -from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention -from vllm_metal.metal import PARTITION_THRESHOLD, metal_unified_attention - -# Original upstream parameters (vLLM Triton/CUDA test_triton_unified_attention.py): -# HEAD_SIZES = [128, 256] -# NUM_BLOCKS = [32768, 2048] -# sliding_window = [None, 64, 128, 256] -# DTYPES = [torch.bfloat16] -# Also tested: FP8 output quantization (QDTYPES), 3D decode kernel -# (SEQ_THRESHOLD_3D) — both CUDA-specific, omitted here. -# Current sizes are reduced for Apple Silicon unified memory. -# TODO: try head_size=256 and larger num_blocks as @pytest.mark.slow variants. -NUM_HEADS = [(4, 4), (8, 2), (5, 1)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16] -DTYPES = [mx.float16] - - -# --------------------------------------------------------------------------- -# Shared reference / decode helpers -# --------------------------------------------------------------------------- -# -# The pure-MLX textbook attention reference and the kernel_v1 wrapper are -# implemented in a shared module so the correctness tests and the benchmark -# script exercise the same logic. -# - - -# --------------------------------------------------------------------------- -# Triangle edge: v1 == ref (decode-only) -# -# Validates that the v1 kernel and the pure-MLX reference produce the same -# results for decode-only inputs. This test also validates ref_paged_attn -# itself, so we can trust it as ground truth. -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "seq_lens", - [ - [(1, 523), (1, 37), (1, 2011)], - [(1, 1), (1, 128), (1, 2048)], - ], -) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("num_blocks", [256]) -def test_v1_kernel_vs_reference( - seq_lens: list[tuple[int, int]], - num_heads: tuple[int, int], - head_size: int, - dtype: mx.Dtype, - block_size: int, - num_blocks: int, -) -> None: - """v1 kernel == reference for decode-only inputs. - - Completes the triangle: if v1 == ref and v2 == v1, then v2 == ref. - Also serves as a smoke test for ref_paged_attn correctness. - """ - mx.random.seed(0) - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - assert all(q == 1 for q in query_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = mx.random.normal(shape=(num_seqs, num_query_heads, head_size)).astype(dtype) - key_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - value_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - kv_lens_arr = mx.array(kv_lens, dtype=mx.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = mx.random.randint( - 0, num_blocks, shape=(num_seqs, max_num_blocks_per_seq) - ).astype(mx.int32) - - v1_output = run_v1_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - scale=scale, - block_tables=block_tables, - seq_lens=kv_lens_arr, - block_size=block_size, - max_seq_len=max_kv_len, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=np.array(block_tables), - scale=scale, - ) - - atol, rtol = 1.5e-2, 1e-2 - np.testing.assert_allclose( - np.array(v1_output), np.array(ref_output), atol=atol, rtol=rtol - ) - - -# --------------------------------------------------------------------------- -# Scaffolding -# Triangle edge: v2 == v1 (decode-only scaffolding) -# -# Freezes parameters to the subset that the existing paged_attention_v1 -# already handles: every sequence has q_len=1, no sliding window, no -# soft_cap. Compares the v2 kernel output against the v1 kernel output -# to prove v2 is a drop-in replacement for decode. This remains useful as -# a focused decode-only regression test alongside the full varlen test below. -# --------------------------------------------------------------------------- -@pytest.mark.parametrize( - "seq_lens", - [ - [(1, 523), (1, 37), (1, 2011)], - [(1, 1), (1, 128), (1, 2048)], - ], -) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("num_blocks", [256]) -def test_metal_unified_attn_decode_only( - seq_lens: list[tuple[int, int]], - num_heads: tuple[int, int], - head_size: int, - dtype: mx.Dtype, - block_size: int, - num_blocks: int, -) -> None: - """Decode-only: all q_len=1, no sliding window, no soft cap. - - Compares the v2 unified kernel against the existing v1 paged_attention - kernel to prove v2 is a drop-in replacement for the decode path. - """ - mx.random.seed(0) - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - assert all(q == 1 for q in query_lens), "Scaffolding test requires q_len=1" - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = mx.random.normal(shape=(num_seqs, num_query_heads, head_size)).astype(dtype) - key_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - value_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - cu_query_lens = mx.cumsum(mx.array([0] + query_lens, dtype=mx.int32)) - kv_lens_arr = mx.array(kv_lens, dtype=mx.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = mx.random.randint( - 0, num_blocks, shape=(num_seqs, max_num_blocks_per_seq) - ).astype(mx.int32) - - # --- v1 kernel output (known-correct, production code) --- - v1_output = run_v1_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - scale=scale, - block_tables=block_tables, - seq_lens=kv_lens_arr, - block_size=block_size, - max_seq_len=max_kv_len, - ) - - # --- v2 kernel output --- - v2_output = mx.zeros_like(query) - - metal_unified_attention( - q=query, - k=key_cache, - v=value_cache, - out=v2_output, - cu_seqlens_q=cu_query_lens, - seqused_k=kv_lens_arr, - max_seqlen_q=1, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=(-1, -1), - block_table=block_tables, - softcap=0, - ) - - # Partitioned decode changes the reduction order versus v1, so long-context - # cases need a slightly looser tolerance than the no-partition exact-match - # path. - atol = rtol = 1e-4 - if max_kv_len >= PARTITION_THRESHOLD: - atol = rtol = 3e-4 - np.testing.assert_allclose( - np.array(v2_output), np.array(v1_output), atol=atol, rtol=rtol - ) - - -# --------------------------------------------------------------------------- -# Triangle edge: v2 == ref (full varlen unified attention) -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] -) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 128]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 50.0]) -@pytest.mark.parametrize("num_blocks", [256]) -def test_metal_unified_attn( - seq_lens: list[tuple[int, int]], - num_heads: tuple[int, int], - head_size: int, - sliding_window: int | None, - dtype: mx.Dtype, - block_size: int, - soft_cap: float | None, - num_blocks: int, -) -> None: - mx.random.seed(0) - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) - scale = head_size**-0.5 - - query = mx.random.normal( - shape=(sum(query_lens), num_query_heads, head_size) - ).astype(dtype) - key_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - value_cache = mx.random.normal( - shape=(num_blocks, block_size, num_kv_heads, head_size) - ).astype(dtype) - cu_query_lens = mx.cumsum(mx.array([0] + query_lens, dtype=mx.int32)) - kv_lens_arr = mx.array(kv_lens, dtype=mx.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = mx.random.randint( - 0, num_blocks, shape=(num_seqs, max_num_blocks_per_seq) - ).astype(mx.int32) - - output = mx.zeros_like(query) - - metal_unified_attention( - q=query, - k=key_cache, - v=value_cache, - out=output, - cu_seqlens_q=cu_query_lens, - seqused_k=kv_lens_arr, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=np.array(block_tables), - scale=scale, - sliding_window=sliding_window, - soft_cap=soft_cap, - ) - - atol, rtol = 1.5e-2, 1e-2 - np.testing.assert_allclose( - np.array(output), np.array(ref_output), atol=atol, rtol=rtol - ) diff --git a/tests/test_primitive_and_donation.py b/tests/test_primitive_and_donation.py index 90c85e93..e435f381 100644 --- a/tests/test_primitive_and_donation.py +++ b/tests/test_primitive_and_donation.py @@ -94,11 +94,13 @@ def _make_cache_and_inputs( "num_heads", [(4, 4), (8, 2)], ) +@pytest.mark.parametrize("sliding_window", [-1, 128]) @pytest.mark.parametrize("num_blocks", [256]) def test_primitive_vs_reference_decode( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], num_blocks: int, + sliding_window: int, ) -> None: """paged_attention_primitive matches the pure-MLX reference (decode).""" mx.random.seed(0) @@ -119,7 +121,7 @@ def test_primitive_vs_reference_decode( d["cu_seqlens_q"], BLOCK_SIZE, d["max_kv_len"], - -1, # sliding_window + sliding_window, out, ) mx.eval(out) @@ -132,6 +134,7 @@ def test_primitive_vs_reference_decode( kv_lens=d["kv_lens"], block_tables=np.array(d["block_tables"]), scale=d["scale"], + sliding_window=sliding_window if sliding_window >= 0 else None, ) mx.eval(ref) @@ -154,11 +157,13 @@ def test_primitive_vs_reference_decode( "num_heads", [(4, 4), (8, 2)], ) +@pytest.mark.parametrize("sliding_window", [-1, 128]) @pytest.mark.parametrize("num_blocks", [256]) def test_primitive_vs_reference_varlen( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], num_blocks: int, + sliding_window: int, ) -> None: """paged_attention_primitive matches reference for mixed prefill+decode.""" mx.random.seed(0) @@ -179,7 +184,7 @@ def test_primitive_vs_reference_varlen( d["cu_seqlens_q"], BLOCK_SIZE, d["max_kv_len"], - -1, # sliding_window + sliding_window, out, ) mx.eval(out) @@ -192,6 +197,7 @@ def test_primitive_vs_reference_varlen( kv_lens=d["kv_lens"], block_tables=np.array(d["block_tables"]), scale=d["scale"], + sliding_window=sliding_window if sliding_window >= 0 else None, ) mx.eval(ref) diff --git a/tests/test_sliding_window_wiring.py b/tests/test_sliding_window_wiring.py index 44a96dae..1715c11f 100644 --- a/tests/test_sliding_window_wiring.py +++ b/tests/test_sliding_window_wiring.py @@ -11,11 +11,9 @@ -> MetalPagedKVCache.sliding_window_per_layer Kernel-level correctness (that a ``sliding_window`` value actually -masks out-of-window tokens) is separately validated by -``test_metal_unified_attn`` in ``test_metal_unified_attention.py``; -both the production ``paged_attention_primitive`` and the test helper -``metal_unified_attention`` dispatch to the same -``paged_attention_v2_online`` kernel (see ``paged_ops.cpp``). +masks out-of-window tokens) is validated via the production +``paged_attention_primitive`` path which dispatches +``paged_attention_v2_online`` (see ``paged_ops.cpp``). """ from __future__ import annotations diff --git a/tools/benchmark/attention_benchmark.py b/tools/benchmark/attention_benchmark.py deleted file mode 100644 index 129449fa..00000000 --- a/tools/benchmark/attention_benchmark.py +++ /dev/null @@ -1,1085 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Benchmark Metal attention backends on shared synthetic workloads. - -Benchmarked backends: -- `v1` (decode-only paged attention) -- `v2` (Metal unified attention) -- `textbook` (pure-MLX reference) -- `sdpa-compute-only` (dense MLX SDPA only) -- `sdpa` (paged gather + dense MLX SDPA) - -Running with no arguments executes the built-in `all` preset group. Built-in -presets run `v1`, `v2`, `textbook`, and `sdpa` by default. Use -`--backend all` when you also want to include `sdpa-compute-only`. - -Examples: - python -m tools.benchmark.attention_benchmark - python -m tools.benchmark.attention_benchmark --group decode - python -m tools.benchmark.attention_benchmark --group small - python -m tools.benchmark.attention_benchmark --cases decode-small,varlen-light - python -m tools.benchmark.attention_benchmark --group decode --num-layers 32 - python -m tools.benchmark.attention_benchmark --group all --backend all - python -m tools.benchmark.attention_benchmark --output-json /tmp/attention.json - python -m tools.benchmark.attention_benchmark --output-csv /tmp/attention.csv - python -m tools.benchmark.attention_benchmark --mode decode --batch-size 8 --kv-lens 2048 - python -m tools.benchmark.attention_benchmark --mode varlen --q-lens 1,4,16,64 --kv-lens 128,256,512,1024 -""" - -from __future__ import annotations - -import argparse -import csv -import json -import math -import statistics -import sys -import time -from collections.abc import Callable -from dataclasses import dataclass -from pathlib import Path - -import mlx.core as mx -import numpy as np - -if __package__ in (None, ""): - raise SystemExit( - "Run this benchmark as a module: python -m tools.benchmark.attention_benchmark" - ) - -from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention -from vllm_metal.metal import metal_unified_attention - -ALL_BACKENDS = ["v1", "v2", "textbook", "sdpa-compute-only", "sdpa"] -DTYPE_MAP = { - "float16": mx.float16, - "bfloat16": mx.bfloat16, - "float32": mx.float32, -} -DEFAULTS: dict[str, object] = { - "backend": "v1,v2,textbook,sdpa", - "warmup": 10, - "iters": 100, - "seed": 0, - "num_layers": 1, - "num_q_heads": 8, - "num_kv_heads": 8, - "head_dim": 128, - "block_size": 16, - "num_blocks": 256, - "dtype": "float16", -} -CASES: dict[str, dict[str, object]] = { - "decode-small": { - "mode": "decode", - "batch_size": 1, - "kv_lens": (128,), - }, - "decode-typical": { - "mode": "decode", - "batch_size": 8, - "kv_lens": (2048,), - }, - "decode-big-head": { - "mode": "decode", - "batch_size": 8, - "kv_lens": (2048,), - "num_q_heads": 32, - "num_kv_heads": 8, - "head_dim": 256, - }, - "decode-long": { - "mode": "decode", - "batch_size": 32, - "kv_lens": (8192,), - "num_blocks": 512, - }, - "varlen-light": { - "mode": "varlen", - "q_lens": (1, 4, 16, 64), - "kv_lens": (128, 256, 512, 1024), - }, - "varlen-typical": { - "mode": "varlen", - "q_lens": (32, 64, 128, 256), - "kv_lens": (512, 1024, 2048, 4096), - }, - "varlen-single-long": { - "mode": "varlen", - "q_lens": (256,), - "kv_lens": (4096,), - }, - "varlen-ragged-longtail": { - "mode": "varlen", - "q_lens": (1, 1, 8, 128), - "kv_lens": (4096, 8192, 512, 2048), - "num_blocks": 512, - }, -} -GROUPS: dict[str, tuple[str, ...]] = { - "all": tuple(CASES), - "decode": tuple(name for name in CASES if name.startswith("decode-")), - "varlen": tuple(name for name in CASES if name.startswith("varlen-")), - "small": ("decode-small", "varlen-light"), - "typical": ("decode-typical", "varlen-typical"), - "long": ( - "decode-big-head", - "decode-long", - "varlen-single-long", - "varlen-ragged-longtail", - ), -} -PRESET_FIELDS = ( - "backend", - "warmup", - "iters", - "seed", - "num_layers", - "num_q_heads", - "num_kv_heads", - "head_dim", - "block_size", - "num_blocks", - "dtype", - "mode", - "batch_size", - "q_lens", - "kv_lens", -) - - -def parse_int_list(value: str | tuple[int, ...] | list[int] | None) -> list[int] | None: - if value is None: - return None - if isinstance(value, str): - values = [chunk.strip() for chunk in value.split(",") if chunk.strip()] - if not values: - return None - return [int(v) for v in values] - return [int(v) for v in value] - - -def has_cli_override(flag: str) -> bool: - cli_args = tuple(sys.argv[1:]) - return any(arg == flag or arg.startswith(f"{flag}=") for arg in cli_args) - - -def parse_name_list(text: str, kind: str) -> list[str]: - values = [chunk.strip() for chunk in text.split(",") if chunk.strip()] - if not values: - raise ValueError(f"--{kind}s must include at least one {kind}") - return values - - -@dataclass(frozen=True) -class Workload: - mode: str - query_lens: list[int] - kv_lens: list[int] - num_layers: int - num_q_heads: int - num_kv_heads: int - head_dim: int - block_size: int - num_blocks: int - dtype_name: str - seed: int - - @property - def dtype(self) -> mx.Dtype: - return DTYPE_MAP[self.dtype_name] - - @property - def num_seqs(self) -> int: - return len(self.query_lens) - - @property - def total_q_tokens(self) -> int: - return sum(self.query_lens) - - @property - def max_q_len(self) -> int: - return max(self.query_lens) - - @property - def max_kv_len(self) -> int: - return max(self.kv_lens) - - @property - def scale(self) -> float: - return self.head_dim**-0.5 - - -@dataclass -class WorkloadData: - workload: Workload - queries: list[mx.array] - key_caches: list[mx.array] - value_caches: list[mx.array] - block_tables: mx.array - block_tables_np: np.ndarray - kv_lens_arr: mx.array - cu_query_lens: mx.array - - -@dataclass -class Result: - backend: str - mean_ms: float | None - p50_ms: float | None - p95_ms: float | None - tokens_per_s: float | None - notes: str = "" - - -@dataclass -class CaseRun: - case_name: str - workload: Workload - results: list[Result] - - -def apply_preset(args: argparse.Namespace, preset: dict[str, object]) -> None: - for attr in PRESET_FIELDS: - flag = f"--{attr.replace('_', '-')}" - if not has_cli_override(flag) and attr in preset: - setattr(args, attr, preset[attr]) - - -def manual_workload_requested(args: argparse.Namespace) -> bool: - return any( - value is not None - for value in (args.mode, args.batch_size, args.q_lens, args.kv_lens) - ) - - -def resolve_case_names(args: argparse.Namespace) -> list[str]: - if args.group is not None and args.cases is not None: - raise ValueError("Choose either --group or --cases, not both") - - if args.cases is not None: - case_names = parse_name_list(args.cases, "case") - unknown = [name for name in case_names if name not in CASES] - if unknown: - raise ValueError(f"Unknown case(s): {', '.join(unknown)}") - return case_names - - group_name = args.group or "all" - if group_name not in GROUPS: - raise ValueError(f"Unknown group: {group_name}") - return list(GROUPS[group_name]) - - -def build_case_invocations( - args: argparse.Namespace, -) -> list[tuple[str, argparse.Namespace]]: - if manual_workload_requested(args): - if args.group is not None or args.cases is not None: - raise ValueError( - "Cannot combine manual workload flags with --group or --cases" - ) - case_args = argparse.Namespace(**vars(args)) - apply_preset(case_args, DEFAULTS) - return [("custom", case_args)] - - case_names = resolve_case_names(args) - invocations: list[tuple[str, argparse.Namespace]] = [] - for case_name in case_names: - case_args = argparse.Namespace(**vars(args)) - apply_preset(case_args, DEFAULTS) - apply_preset(case_args, CASES[case_name]) - invocations.append((case_name, case_args)) - return invocations - - -def build_workload(args: argparse.Namespace) -> Workload: - q_lens = parse_int_list(args.q_lens) - kv_lens = parse_int_list(args.kv_lens) - - required_fields = ( - "num_layers", - "num_q_heads", - "num_kv_heads", - "head_dim", - "block_size", - "num_blocks", - "dtype", - "seed", - ) - missing = [field for field in required_fields if getattr(args, field) is None] - if missing: - raise ValueError(f"Missing required benchmark settings: {', '.join(missing)}") - - if args.mode == "decode": - if q_lens is None: - if args.batch_size is None: - raise ValueError("--batch-size is required for decode mode") - q_lens = [1] * args.batch_size - if any(q != 1 for q in q_lens): - raise ValueError("decode mode requires all q_lens to be 1") - - if kv_lens is None: - raise ValueError("--kv-lens is required") - if len(kv_lens) == 1: - kv_lens = kv_lens * len(q_lens) - elif len(kv_lens) != len(q_lens): - raise ValueError("decode mode requires kv_lens length to match batch size") - else: - if q_lens is None or kv_lens is None: - raise ValueError("varlen mode requires both --q-lens and --kv-lens") - if len(q_lens) != len(kv_lens): - raise ValueError("--q-lens and --kv-lens must have the same length") - - if args.num_q_heads % args.num_kv_heads != 0: - raise ValueError("num_q_heads must be divisible by num_kv_heads") - if args.num_layers < 1: - raise ValueError("num_layers must be at least 1") - if args.dtype not in DTYPE_MAP: - raise ValueError(f"Unsupported dtype: {args.dtype}") - return Workload( - mode=args.mode, - query_lens=q_lens, - kv_lens=kv_lens, - num_layers=args.num_layers, - num_q_heads=args.num_q_heads, - num_kv_heads=args.num_kv_heads, - head_dim=args.head_dim, - block_size=args.block_size, - num_blocks=args.num_blocks, - dtype_name=args.dtype, - seed=args.seed, - ) - - -def make_workload_data(workload: Workload) -> WorkloadData: - max_blocks_per_seq = math.ceil(workload.max_kv_len / workload.block_size) - if max_blocks_per_seq > workload.num_blocks: - raise ValueError( - f"num_blocks={workload.num_blocks} is too small for max_kv_len=" - f"{workload.max_kv_len} and block_size={workload.block_size}; need at least " - f"{max_blocks_per_seq}" - ) - - mx.random.seed(workload.seed) - block_tables = mx.random.randint( - 0, - workload.num_blocks, - shape=(workload.num_seqs, max_blocks_per_seq), - ).astype(mx.int32) - kv_lens_arr = mx.array(workload.kv_lens, dtype=mx.int32) - cu_query_lens = mx.cumsum(mx.array([0] + workload.query_lens, dtype=mx.int32)) - queries: list[mx.array] = [] - key_caches: list[mx.array] = [] - value_caches: list[mx.array] = [] - for layer_idx in range(workload.num_layers): - mx.random.seed(workload.seed + layer_idx) - queries.append( - mx.random.normal( - shape=(workload.total_q_tokens, workload.num_q_heads, workload.head_dim) - ).astype(workload.dtype) - ) - key_caches.append( - mx.random.normal( - shape=( - workload.num_blocks, - workload.block_size, - workload.num_kv_heads, - workload.head_dim, - ) - ).astype(workload.dtype) - ) - value_caches.append( - mx.random.normal( - shape=( - workload.num_blocks, - workload.block_size, - workload.num_kv_heads, - workload.head_dim, - ) - ).astype(workload.dtype) - ) - mx.eval( - *queries, - *key_caches, - *value_caches, - block_tables, - kv_lens_arr, - cu_query_lens, - ) - - return WorkloadData( - workload=workload, - queries=queries, - key_caches=key_caches, - value_caches=value_caches, - block_tables=block_tables, - block_tables_np=np.array(block_tables), - kv_lens_arr=kv_lens_arr, - cu_query_lens=cu_query_lens, - ) - - -def make_sdpa_mask( - query_len: int, - kv_len: int, -) -> mx.array: - empty_mask = mx.ones((query_len, kv_len)) - mask = mx.triu(empty_mask, k=kv_len - query_len + 1).astype(mx.bool_) - return mask[None, None, :, :] - - -def gather_dense_sdpa_inputs( - data: WorkloadData, - layer_idx: int, -) -> list[tuple[mx.array, mx.array, mx.array, mx.array]]: - workload = data.workload - query = data.queries[layer_idx] - key_cache = data.key_caches[layer_idx] - value_cache = data.value_caches[layer_idx] - prepared: list[tuple[mx.array, mx.array, mx.array, mx.array]] = [] - start = 0 - for i, query_len in enumerate(workload.query_lens): - kv_len = workload.kv_lens[i] - q = query[start : start + query_len].transpose(1, 0, 2)[None, ...] - - num_kv_blocks = math.ceil(kv_len / workload.block_size) - block_indices = data.block_tables[i, :num_kv_blocks] - k = key_cache[block_indices].reshape( - -1, workload.num_kv_heads, workload.head_dim - )[:kv_len] - v = value_cache[block_indices].reshape( - -1, workload.num_kv_heads, workload.head_dim - )[:kv_len] - k = k.transpose(1, 0, 2)[None, ...] - v = v.transpose(1, 0, 2)[None, ...] - mask = make_sdpa_mask(query_len, kv_len) - prepared.append((q, k, v, mask)) - start += query_len - - mx.eval(*(arr for item in prepared for arr in item)) - return prepared - - -def run_sdpa_from_prepared( - prepared: list[tuple[mx.array, mx.array, mx.array, mx.array]], - scale: float, -) -> mx.array: - outputs: list[mx.array] = [] - for q, k, v, mask in prepared: - out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) - outputs.append(out[0].transpose(1, 0, 2)) - return mx.concatenate(outputs, axis=0) - - -def time_backend( - fn: Callable[[], mx.array], - warmup: int, - iters: int, - total_q_tokens: int, - num_layers: int, -) -> tuple[float, float, float, float]: - for _ in range(warmup): - out = fn() - if out is not None: - mx.eval(out) - - timings_ms: list[float] = [] - for _ in range(iters): - mx.synchronize() - t0 = time.perf_counter() - out = fn() - if out is not None: - mx.eval(out) - mx.synchronize() - timings_ms.append((time.perf_counter() - t0) * 1000.0) - - mean_ms = statistics.fmean(timings_ms) / num_layers - p50_ms = float(np.percentile(timings_ms, 50)) / num_layers - p95_ms = float(np.percentile(timings_ms, 95)) / num_layers - tokens_per_s = total_q_tokens / (mean_ms / 1000.0) - return mean_ms, p50_ms, p95_ms, tokens_per_s - - -def benchmark_backend( - backend: str, - data: WorkloadData, - warmup: int, - iters: int, -) -> Result: - workload = data.workload - notes = "" - - if backend == "v1": - if workload.mode != "decode": - return Result( - backend=backend, - mean_ms=None, - p50_ms=None, - p95_ms=None, - tokens_per_s=None, - notes="unsupported in varlen mode", - ) - - def fn() -> mx.array: - out = None - for layer_idx in range(workload.num_layers): - out = run_v1_paged_attention( - query=data.queries[layer_idx], - key_cache=data.key_caches[layer_idx], - value_cache=data.value_caches[layer_idx], - num_kv_heads=workload.num_kv_heads, - scale=workload.scale, - block_tables=data.block_tables, - seq_lens=data.kv_lens_arr, - block_size=workload.block_size, - max_seq_len=workload.max_kv_len, - ) - assert out is not None - return out - - notes = "decode-only" - elif backend == "v2": - - def fn() -> mx.array: - out = None - for layer_idx in range(workload.num_layers): - out = _run_v2(data, layer_idx) - assert out is not None - return out - elif backend == "textbook": - - def fn() -> mx.array: - out = None - for layer_idx in range(workload.num_layers): - out = ref_paged_attn( - query=data.queries[layer_idx], - key_cache=data.key_caches[layer_idx], - value_cache=data.value_caches[layer_idx], - query_lens=workload.query_lens, - kv_lens=workload.kv_lens, - block_tables=data.block_tables_np, - scale=workload.scale, - sliding_window=None, - soft_cap=None, - ) - assert out is not None - return out - elif backend == "sdpa-compute-only": - prepared_per_layer = [ - gather_dense_sdpa_inputs(data, layer_idx) - for layer_idx in range(workload.num_layers) - ] - - def fn() -> mx.array: - out = None - for prepared in prepared_per_layer: - out = run_sdpa_from_prepared(prepared, workload.scale) - assert out is not None - return out - - notes = "dense compute only" - elif backend == "sdpa": - - def fn() -> mx.array: - out = None - for layer_idx in range(workload.num_layers): - out = run_sdpa_from_prepared( - gather_dense_sdpa_inputs(data, layer_idx), workload.scale - ) - assert out is not None - return out - - notes = "includes gather" - else: - raise ValueError(f"Unknown backend: {backend}") - - try: - mean_ms, p50_ms, p95_ms, tokens_per_s = time_backend( - fn, warmup, iters, workload.total_q_tokens, workload.num_layers - ) - return Result( - backend=backend, - mean_ms=mean_ms, - p50_ms=p50_ms, - p95_ms=p95_ms, - tokens_per_s=tokens_per_s, - notes=notes, - ) - except Exception as exc: - error_note = f"error: {type(exc).__name__}: {exc}" - notes = error_note if not notes else f"{notes}; {error_note}" - return Result( - backend=backend, - mean_ms=None, - p50_ms=None, - p95_ms=None, - tokens_per_s=None, - notes=notes, - ) - - -def _run_v2(data: WorkloadData, layer_idx: int) -> mx.array: - workload = data.workload - out = mx.zeros_like(data.queries[layer_idx]) - metal_unified_attention( - q=data.queries[layer_idx], - k=data.key_caches[layer_idx], - v=data.value_caches[layer_idx], - out=out, - cu_seqlens_q=data.cu_query_lens, - seqused_k=data.kv_lens_arr, - max_seqlen_q=workload.max_q_len, - max_seqlen_k=workload.max_kv_len, - softmax_scale=workload.scale, - causal=True, - window_size=(-1, -1), - block_table=data.block_tables, - softcap=0, - ) - return out - - -def format_query_spec(workload: Workload) -> str: - if workload.mode == "decode": - return f"batch={workload.num_seqs}, q_len=1, kv_len={workload.kv_lens}" - return "seq_lens=" + str( - list(zip(workload.query_lens, workload.kv_lens, strict=False)) - ) - - -def short_query_spec(workload: Workload) -> str: - if workload.mode == "decode": - kv = ( - workload.kv_lens[0] if len(set(workload.kv_lens)) == 1 else workload.kv_lens - ) - return f"B={workload.num_seqs}, q=1, kv={kv}" - pairs = list(zip(workload.query_lens, workload.kv_lens, strict=False)) - if len(pairs) <= 4: - return " ".join(f"{q}/{kv}" for q, kv in pairs) - return ( - f"{len(pairs)} seqs; max_q={workload.max_q_len}; max_kv={workload.max_kv_len}" - ) - - -def valid_results(results: list[Result]) -> list[Result]: - return [result for result in results if result.mean_ms is not None] - - -def mean_ms_key(result: Result) -> float: - assert result.mean_ms is not None - return result.mean_ms - - -def ordered_backends(case_runs: list[CaseRun]) -> list[str]: - present = {result.backend for case_run in case_runs for result in case_run.results} - return [backend for backend in ALL_BACKENDS if backend in present] - - -def case_kind(workload: Workload) -> str: - return "decode" if workload.mode == "decode" else "varlen" - - -def display_case_name(case_run: CaseRun) -> str: - prefix = f"{case_kind(case_run.workload)}-" - if case_run.case_name.startswith(prefix): - return case_run.case_name[len(prefix) :] - return case_run.case_name - - -def backend_label(backend: str) -> str: - return { - "v1": "v1", - "v2": "v2", - "textbook": "textbook", - "sdpa-compute-only": "sdpa-compute-only", - "sdpa": "sdpa", - }.get(backend, backend.replace("-", "_")) - - -def format_time_ms(result: Result | None) -> str: - if result is None: - return "-" - if result.mean_ms is None: - if result.notes.startswith("error:"): - return "ERROR" - return "N/A" - return f"{result.mean_ms:.3f}" - - -def format_vs_best(result: Result | None, best: Result | None) -> str: - if result is None or result.mean_ms is None or best is None or best.mean_ms is None: - return "-" - pct = result.mean_ms / best.mean_ms * 100.0 - if math.isclose(result.mean_ms, best.mean_ms, rel_tol=0.0, abs_tol=1e-9): - return f"{pct:.1f}% best" - return f"{pct:.1f}%" - - -def comparison_headers(backends: list[str], compare_to_fastest: bool) -> list[str]: - headers = ["case", "type", "batch", "shape"] - for backend in backends: - label = backend_label(backend) - headers.append(label) - if compare_to_fastest: - headers.append(f"{label}_vs_best") - return headers - - -def comparison_rows(case_runs: list[CaseRun], backends: list[str]) -> list[list[str]]: - compare_to_fastest = len(backends) > 1 - rows: list[list[str]] = [] - for case_run in case_runs: - results_by_backend = {result.backend: result for result in case_run.results} - best = min( - valid_results(case_run.results), - key=mean_ms_key, - default=None, - ) - row = [ - display_case_name(case_run), - case_kind(case_run.workload), - str(case_run.workload.num_seqs), - short_query_spec(case_run.workload), - ] - for backend in backends: - result = results_by_backend.get(backend) - row.append(format_time_ms(result)) - if compare_to_fastest: - row.append(format_vs_best(result, best)) - rows.append(row) - return rows - - -def print_text_table(headers: list[str], rows: list[list[str]]) -> None: - widths = [len(header) for header in headers] - for row in rows: - for i, cell in enumerate(row): - widths[i] = max(widths[i], len(cell)) - - print(" | ".join(header.ljust(widths[i]) for i, header in enumerate(headers))) - print("-+-".join("-" * width for width in widths)) - for row in rows: - print(" | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row))) - - -def summary_dict( - case_runs: list[CaseRun], args: argparse.Namespace -) -> dict[str, object]: - block_sizes = sorted({run.workload.block_size for run in case_runs}) - dtypes = sorted({run.workload.dtype_name for run in case_runs}) - num_layers = sorted({run.workload.num_layers for run in case_runs}) - seeds = sorted({run.workload.seed for run in case_runs}) - return { - "cases": [run.case_name for run in case_runs], - "num_layers": num_layers[0] if len(num_layers) == 1 else num_layers, - "block_size": block_sizes[0] if len(block_sizes) == 1 else block_sizes, - "dtype": dtypes[0] if len(dtypes) == 1 else dtypes, - "warmup": args.warmup, - "iters": args.iters, - "seed": seeds[0] if len(seeds) == 1 else seeds, - } - - -def comparison_rows_dict( - case_runs: list[CaseRun], backends: list[str] -) -> list[dict[str, object]]: - rows: list[dict[str, object]] = [] - for case_run in case_runs: - row: dict[str, object] = { - "case": display_case_name(case_run), - "case_name": case_run.case_name, - "type": case_kind(case_run.workload), - "batch": case_run.workload.num_seqs, - "shape": short_query_spec(case_run.workload), - } - results_by_backend = {result.backend: result for result in case_run.results} - best = min( - valid_results(case_run.results), - key=mean_ms_key, - default=None, - ) - for backend in backends: - result = results_by_backend.get(backend) - label = backend_label(backend) - row[label] = ( - None - if result is None or result.mean_ms is None - else round(result.mean_ms, 3) - ) - row[f"{label}_vs_best"] = ( - None - if result is None - or result.mean_ms is None - or best is None - or best.mean_ms is None - else round(result.mean_ms / best.mean_ms * 100.0, 1) - ) - rows.append(row) - return rows - - -def json_payload( - case_runs: list[CaseRun], args: argparse.Namespace -) -> dict[str, object]: - backends = ordered_backends(case_runs) - return { - "summary": summary_dict(case_runs, args), - "rows": comparison_rows_dict(case_runs, backends), - } - - -def write_json(path: Path, case_runs: list[CaseRun], args: argparse.Namespace) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(json_payload(case_runs, args), indent=2) + "\n") - - -def write_csv(path: Path, case_runs: list[CaseRun]) -> None: - backends = ordered_backends(case_runs) - rows = comparison_rows_dict(case_runs, backends) - fieldnames = ( - list(rows[0].keys()) - if rows - else ["case", "case_name", "type", "batch", "shape"] - ) - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", newline="") as handle: - writer = csv.DictWriter(handle, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(rows) - - -def write_exports(case_runs: list[CaseRun], args: argparse.Namespace) -> None: - if args.output_json: - write_json(Path(args.output_json).expanduser(), case_runs, args) - if args.output_csv: - write_csv(Path(args.output_csv).expanduser(), case_runs) - - -def print_summary(case_runs: list[CaseRun], args: argparse.Namespace) -> None: - summary = summary_dict(case_runs, args) - summary_parts = [ - f"num_layers: {summary['num_layers']}" - if isinstance(summary["num_layers"], int) - else "num_layers: mixed", - f"block_size: {summary['block_size']}" - if isinstance(summary["block_size"], int) - else "block_size: mixed", - f"dtype: {summary['dtype']}" - if isinstance(summary["dtype"], str) - else "dtype: mixed", - f"warmup: {args.warmup}", - f"iters: {args.iters}", - f"seed: {summary['seed']}" - if isinstance(summary["seed"], int) - else "seed: mixed", - ] - print(" ".join(summary_parts)) - - -def print_case_header(case_run: CaseRun, args: argparse.Namespace) -> None: - print("\nMetal Attention Benchmark") - print(f"case: {case_run.case_name}") - print(f"mode: {case_run.workload.mode}") - print(f"workload: {format_query_spec(case_run.workload)}") - print( - "heads(q/kv): " - f"{case_run.workload.num_q_heads}/{case_run.workload.num_kv_heads} " - f"head_dim: {case_run.workload.head_dim} " - f"block_size: {case_run.workload.block_size} " - f"num_blocks: {case_run.workload.num_blocks} " - f"num_layers: {case_run.workload.num_layers}" - ) - print( - f"dtype: {case_run.workload.dtype_name} warmup: {args.warmup} " - f"iters: {args.iters} seed: {case_run.workload.seed}" - ) - - -def print_results(case_run: CaseRun, args: argparse.Namespace) -> None: - print_case_header(case_run, args) - print() - backends = ordered_backends([case_run]) - headers = comparison_headers(backends, compare_to_fastest=len(backends) > 1) - rows = comparison_rows([case_run], backends) - print_text_table(headers, rows) - write_exports([case_run], args) - - -def print_combined_results(case_runs: list[CaseRun], args: argparse.Namespace) -> None: - print("\nMetal Attention Benchmark") - print(f"cases: {', '.join(run.case_name for run in case_runs)}") - print_summary(case_runs, args) - print() - - backends = ordered_backends(case_runs) - headers = comparison_headers(backends, compare_to_fastest=len(backends) > 1) - rows = comparison_rows(case_runs, backends) - print_text_table(headers, rows) - write_exports(case_runs, args) - - -def resolve_backends(text: str, mode: str) -> list[str]: - if text == "all": - backends = list(ALL_BACKENDS) - else: - backends = [chunk.strip() for chunk in text.split(",") if chunk.strip()] - - invalid = [backend for backend in backends if backend not in ALL_BACKENDS] - if invalid: - raise ValueError(f"Unknown backend(s): {', '.join(invalid)}") - return backends - - -def make_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "--group", - choices=sorted(GROUPS), - default=None, - help="Built-in preset group to run; defaults to all when no manual workload is given", - ) - parser.add_argument( - "--cases", - default=None, - help="Comma-separated explicit preset case names to run", - ) - parser.add_argument( - "--mode", - choices=["decode", "varlen"], - default=None, - help="Manual workload mode; when set, runs one custom case instead of preset cases", - ) - parser.add_argument( - "--backend", - default=None, - help="all|v1|v2|textbook|sdpa-compute-only|sdpa or a comma-separated subset", - ) - parser.add_argument( - "--batch-size", - type=int, - default=None, - help="Decode mode only: number of sequences; implies q_len=1 for each sequence", - ) - parser.add_argument( - "--q-lens", - default=None, - help="Comma-separated query lengths; required for manual varlen mode", - ) - parser.add_argument( - "--kv-lens", - default=None, - help="Comma-separated KV lengths; one value may be repeated across all decode sequences", - ) - parser.add_argument( - "--num-q-heads", - type=int, - default=None, - help="Number of query heads", - ) - parser.add_argument( - "--num-kv-heads", - type=int, - default=None, - help="Number of key/value heads; must divide num-q-heads", - ) - parser.add_argument( - "--head-dim", - type=int, - default=None, - help="Attention head dimension", - ) - parser.add_argument( - "--block-size", - type=int, - default=None, - help="Paged KV block size", - ) - parser.add_argument( - "--num-blocks", - type=int, - default=None, - help="Number of blocks in the synthetic paged KV cache", - ) - parser.add_argument( - "--dtype", - choices=sorted(DTYPE_MAP), - default=None, - help="Element dtype for synthetic inputs", - ) - parser.add_argument( - "--warmup", - type=int, - default=None, - help="Number of warmup iterations before timing", - ) - parser.add_argument( - "--iters", - type=int, - default=None, - help="Number of measured iterations", - ) - parser.add_argument( - "--seed", - type=int, - default=None, - help="Random seed for reproducible synthetic inputs", - ) - parser.add_argument( - "--num-layers", - type=int, - default=None, - help="Number of attention layers to benchmark; timings are reported per layer", - ) - parser.add_argument( - "--output-json", - default=None, - help="Write structured benchmark results to a JSON file", - ) - parser.add_argument( - "--output-csv", - default=None, - help="Write row-oriented benchmark results to a CSV file", - ) - return parser - - -def main() -> None: - parser = make_parser() - args = parser.parse_args() - try: - case_invocations = build_case_invocations(args) - except ValueError as exc: - parser.error(str(exc)) - - case_runs: list[CaseRun] = [] - for case_name, case_args in case_invocations: - if case_args.mode is None: - parser.error("--mode is required for manual workloads") - - try: - workload = build_workload(case_args) - backends = resolve_backends(case_args.backend, workload.mode) - except ValueError as exc: - parser.error(str(exc)) - if not backends: - raise ValueError("No backends selected") - - data = make_workload_data(workload) - results: list[Result] = [] - for backend in backends: - result = benchmark_backend(backend, data, case_args.warmup, case_args.iters) - results.append(result) - case_runs.append(CaseRun(case_name or "custom", workload, results)) - - if len(case_runs) == 1: - print_results(case_runs[0], case_args) - else: - display_args = argparse.Namespace(**vars(args)) - apply_preset(display_args, DEFAULTS) - print_combined_results(case_runs, display_args) - - -if __name__ == "__main__": - main() diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py index 2e84c7e8..cdc375eb 100644 --- a/vllm_metal/metal/__init__.py +++ b/vllm_metal/metal/__init__.py @@ -18,7 +18,7 @@ from pathlib import Path from types import ModuleType -from vllm_metal.metal.constants import PARTITION_SIZE, PARTITION_THRESHOLD +from vllm_metal.metal.constants import PARTITION_SIZE logger = logging.getLogger(__name__) @@ -89,108 +89,6 @@ def _build_gdn_source() -> str: return "\n".join(parts) -def metal_unified_attention( - q, # [total_q_tokens, num_q_heads, head_size] - k, # [num_blocks, block_size, num_kv_heads, head_size] - v, # [num_blocks, block_size, num_kv_heads, head_size] - out, # [total_q_tokens, num_q_heads, head_size] - cu_seqlens_q, # [num_seqs + 1], int32 - seqused_k, # [num_seqs], int32 - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - causal: bool, - window_size: tuple[int, int], - block_table, # [num_seqs, max_blocks_per_seq], int32 - softcap: float, -) -> None: - """Unified varlen paged attention for Metal. - - Supports variable-length queries (prefill + decode) with online softmax, - paged KV cache, causal masking, sliding window, and soft capping. - - Grid: one threadgroup per (head, query_token). Each threadgroup uses - binary search on cu_seqlens_q to find its sequence and computes causal - attention against the paged KV cache. - """ - assert causal, "Only causal attention is supported" - import mlx.core as mx - - # Extract dimensions from cache shape - # k shape: [num_blocks, block_size, num_kv_heads, head_size] - num_kv_heads = k.shape[2] - block_size = k.shape[1] - - # Convert window_size tuple to a single sliding_window int. - # window_size = (left, right) where left = sw-1, right = 0 for causal. - # sliding_window = left + 1 = total window size. -1 = disabled. - if window_size == (-1, -1): - sliding_window = -1 - else: - sliding_window = window_size[0] + 1 - - ops = get_ops() - - # Ensure all inputs are evaluated before raw Metal dispatch - mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q) - max_num_partitions = max(1, (max_seqlen_k + PARTITION_SIZE - 1) // PARTITION_SIZE) - use_partitioning = ( - PARTITION_SIZE % block_size == 0 - and max_seqlen_q == 1 - and max_seqlen_k >= PARTITION_THRESHOLD - and max_num_partitions > 1 - ) - - if use_partitioning: - exp_sums = mx.zeros( - (q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32 - ) - max_logits = mx.zeros( - (q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32 - ) - tmp_out = mx.zeros( - (q.shape[0], q.shape[1], max_num_partitions, q.shape[2]), - dtype=q.dtype, - ) - mx.eval(exp_sums, max_logits, tmp_out) - ops.paged_attention_v2_online_partitioned( - out, - q, - k, - v, - num_kv_heads, - softmax_scale, - softcap, - block_table, - seqused_k, - cu_seqlens_q, - block_size, - max_seqlen_k, - sliding_window, - exp_sums, - max_logits, - tmp_out, - ) - mx.synchronize() - else: - ops.paged_attention_v2_online( - out, - q, - k, - v, - num_kv_heads, - softmax_scale, - softcap, - block_table, - seqused_k, - cu_seqlens_q, - block_size, - max_seqlen_k, - sliding_window, - ) - mx.synchronize() - - def get_ops() -> ModuleType: """JIT-build and import the native paged_ops extension. diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 89e16c6b..c8dda124 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -315,7 +315,7 @@ void paged_attention_v1_impl( } // --------------------------------------------------------------------------- -// paged_attention_v2_online — dispatch helper + eager wrappers +// paged_attention_v2_online — dispatch helper (used by PagedAttentionPrimitive) // --------------------------------------------------------------------------- static void dispatch_paged_attention_v2_online( @@ -461,208 +461,6 @@ static void dispatch_paged_attention_v2_online( } } -// Eager wrapper — keeps the old handle-based API for metal_unified_attention. -// Non-partitioned case delegates to the dispatch helper above; -// partitioned case is handled inline (same as original code on main). -void paged_attention_v2_online_impl_common( - nb::handle out_h, - nb::handle query_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - int num_kv_heads, - float scale, - float softcap, - nb::handle block_tables_h, - nb::handle seq_lens_h, - nb::handle cu_seqlens_q_h, - int block_size, - int max_seq_len, - int sliding_window, - array* exp_sums, - array* max_logits, - array* tmp_out, - array* sinks -) { - auto& out = *nb::inst_ptr(out_h); - auto& query = *nb::inst_ptr(query_h); - auto& key_cache = *nb::inst_ptr(key_cache_h); - auto& value_cache = *nb::inst_ptr(value_cache_h); - auto& block_tables = *nb::inst_ptr(block_tables_h); - auto& seq_lens = *nb::inst_ptr(seq_lens_h); - auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); - - // Non-partitioned case: delegate to the shared dispatch helper - bool needs_partitioning = - exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr; - if (!needs_partitioning) { - dispatch_paged_attention_v2_online( - out, query, key_cache, value_cache, - num_kv_heads, scale, softcap, - block_tables, seq_lens, cu_seqlens_q, - block_size, max_seq_len, sliding_window, - default_stream(Device::gpu)); - return; - } - - // Partitioned path (unchanged from main) - auto s = default_stream(Device::gpu); - auto& d = metal::device(Device::gpu); - - int total_q_tokens = static_cast(query.shape(0)); - int num_heads = static_cast(query.shape(1)); - int head_size = static_cast(query.shape(2)); - int max_blocks = static_cast(block_tables.shape(1)); - int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; - int max_num_partitions = - std::max(1, (max_seq_len + kPartitionSize - 1) / kPartitionSize); - bool use_partitioning = - kPartitionSize % block_size == 0 && max_num_partitions > 1; - - auto dt = dtype_to_metal(query.dtype()); - auto k_cache_dt = dtype_to_metal(key_cache.dtype()); - auto v_cache_dt = dtype_to_metal(value_cache.dtype()); - std::string kname = - "paged_attention_" + dt + "_cache_" + k_cache_dt + "_" + v_cache_dt + - "_hs" + std::to_string(head_size) + - "_bs" + std::to_string(block_size) + - "_nt256_nsl32_ps" + - std::to_string(use_partitioning ? kPartitionSize : 0); - - bool use_alibi = false; - bool use_fp8 = false; - bool use_sinks = sinks != nullptr; - bool use_tq_fc = false; - int k_bits_i = 8; - - // Same hash-name discipline as v2: encode every varying function constant - // (use_sinks here) so MLX doesn't reuse a stale specialization. - std::string hash_name = kname + "_v2" - + "_sinks" + (use_sinks ? "1" : "0"); - - auto* lib = d.get_library("paged_attention_v2_kern"); - auto* kernel = d.get_kernel( - kname, lib, hash_name, - {{&use_partitioning, MTL::DataType::DataTypeBool, NS::UInteger(10)}, - {&use_alibi, MTL::DataType::DataTypeBool, NS::UInteger(20)}, - {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, - {&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}, - {&use_tq_fc, MTL::DataType::DataTypeBool, NS::UInteger(50)}, - {&k_bits_i, MTL::DataType::DataTypeInt, NS::UInteger(60)}}); - - constexpr int NUM_THREADS = 256; - constexpr int NUM_SIMD_LANES = 32; - constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; - int warp_scores_bytes = NUM_WARPS * block_size - * static_cast(sizeof(float)); - int merge_bytes = (2 * NUM_WARPS + NUM_WARPS * head_size) - * static_cast(sizeof(float)); - size_t shmem = static_cast(std::max(warp_scores_bytes, merge_bytes)); - - auto& enc = get_command_encoder_compat(d, s); - enc.set_compute_pipeline_state(kernel); - enc.set_threadgroup_memory_length(shmem, 0); - - if (use_partitioning) { - enc.set_output_array(*exp_sums, 0); - enc.set_output_array(*max_logits, 1); - enc.set_output_array(*tmp_out, 2); - } else { - enc.set_output_array(out, 2); - } - enc.set_input_array(query, 3); - enc.set_input_array(key_cache, 4); - enc.set_input_array(value_cache, 5); - - int32_t nkv = static_cast(num_kv_heads); - enc.set_bytes(nkv, 8); - enc.set_bytes(scale, 9); - float softcapping = softcap; - enc.set_bytes(softcapping, 10); - - enc.set_input_array(block_tables, 11); - enc.set_input_array(seq_lens, 12); - - int32_t max_blocks_i = static_cast(max_blocks); - enc.set_bytes(max_blocks_i, 13); - - int32_t q_stride = static_cast(num_heads * head_size); - int32_t kv_block_stride = static_cast(key_cache.strides()[0]); - int32_t kv_head_stride = static_cast(key_cache.strides()[2]); - enc.set_bytes(q_stride, 15); - enc.set_bytes(kv_block_stride, 16); - enc.set_bytes(kv_head_stride, 17); - if (use_sinks) { - enc.set_input_array(*sinks, 18); - } - - enc.set_input_array(cu_seqlens_q, 19); - int32_t num_seqs_i = static_cast(num_seqs); - enc.set_bytes(num_seqs_i, 20); - int32_t sliding_window_i = static_cast(sliding_window); - enc.set_bytes(sliding_window_i, 21); - - const int32_t grid_z = - static_cast(use_partitioning ? max_num_partitions : 1); - enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, total_q_tokens, grid_z), - MTL::Size::Make(NUM_THREADS, 1, 1)); - - if (use_partitioning) { - std::string reduce_kname = - "paged_attention_v2_reduce_" + dt + - "_hs" + std::to_string(head_size) + - "_nt256_nsl32_ps" + std::to_string(kPartitionSize); - // The reduce kernel now references function_constant(50) (use_turboquant) - // to gate the deferred-FWHT path. Reaching this dispatch from the - // non-TQ wrapper means use_turboquant is always false here; encode it - // in the hash name anyway so a future TQ-partitioned dispatch picks up - // its own specialisation instead of reusing the non-TQ one. - bool reduce_use_tq = false; - auto* reduce_kernel = d.get_kernel( - reduce_kname, - lib, - reduce_kname + "_v2_reduce_tq" + (reduce_use_tq ? "1" : "0"), - {{&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}, - {&reduce_use_tq, MTL::DataType::DataTypeBool, NS::UInteger(50)}}); - size_t reduce_shmem = - static_cast(2 * max_num_partitions * sizeof(float)); - enc.set_compute_pipeline_state(reduce_kernel); - enc.set_threadgroup_memory_length(reduce_shmem, 0); - - enc.set_output_array(out, 0); - enc.set_input_array(*exp_sums, 1); - enc.set_input_array(*max_logits, 2); - enc.set_input_array(*tmp_out, 3); - enc.set_input_array(seq_lens, 4); - int32_t max_num_partitions_i = static_cast(max_num_partitions); - enc.set_bytes(max_num_partitions_i, 5); - if (use_sinks) { - enc.set_input_array(*sinks, 6); - } - enc.set_input_array(cu_seqlens_q, 7); - enc.set_bytes(num_seqs_i, 8); - enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, total_q_tokens, 1), - MTL::Size::Make(NUM_THREADS, 1, 1)); - } - - add_temporary_compat(enc, out, d, s); - add_temporary_compat(enc, query, d, s); - add_temporary_compat(enc, key_cache, d, s); - add_temporary_compat(enc, value_cache, d, s); - add_temporary_compat(enc, block_tables, d, s); - add_temporary_compat(enc, seq_lens, d, s); - add_temporary_compat(enc, cu_seqlens_q, d, s); - if (use_partitioning) { - add_temporary_compat(enc, *exp_sums, d, s); - add_temporary_compat(enc, *max_logits, d, s); - add_temporary_compat(enc, *tmp_out, d, s); - } - if (use_sinks) { - add_temporary_compat(enc, *sinks, d, s); - } -} - // --------------------------------------------------------------------------- // Paged attention primitive (read-only): paged_attention_v2_online only. // @@ -765,82 +563,6 @@ static array paged_attention_primitive_fn( {query, key_cache, value_cache, block_tables, seq_lens, cu_seqlens_q}); } -void paged_attention_v2_online_impl( - nb::handle out_h, - nb::handle query_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - int num_kv_heads, - float scale, - float softcap, - nb::handle block_tables_h, - nb::handle seq_lens_h, - nb::handle cu_seqlens_q_h, - int block_size, - int max_seq_len, - int sliding_window -) { - paged_attention_v2_online_impl_common( - out_h, - query_h, - key_cache_h, - value_cache_h, - num_kv_heads, - scale, - softcap, - block_tables_h, - seq_lens_h, - cu_seqlens_q_h, - block_size, - max_seq_len, - sliding_window, - nullptr, - nullptr, - nullptr, - nullptr); -} - -void paged_attention_v2_online_partitioned_impl( - nb::handle out_h, - nb::handle query_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - int num_kv_heads, - float scale, - float softcap, - nb::handle block_tables_h, - nb::handle seq_lens_h, - nb::handle cu_seqlens_q_h, - int block_size, - int max_seq_len, - int sliding_window, - nb::handle exp_sums_h, - nb::handle max_logits_h, - nb::handle tmp_out_h -) { - auto& exp_sums = *nb::inst_ptr(exp_sums_h); - auto& max_logits = *nb::inst_ptr(max_logits_h); - auto& tmp_out = *nb::inst_ptr(tmp_out_h); - paged_attention_v2_online_impl_common( - out_h, - query_h, - key_cache_h, - value_cache_h, - num_kv_heads, - scale, - softcap, - block_tables_h, - seq_lens_h, - cu_seqlens_q_h, - block_size, - max_seq_len, - sliding_window, - &exp_sums, - &max_logits, - &tmp_out, - nullptr); -} - // --------------------------------------------------------------------------- // tq_encode — fused TurboQuant encode + paged scatter // @@ -1202,31 +924,6 @@ NB_MODULE(_paged_ops, m) { nb::arg("block_size"), nb::arg("max_seq_len"), "Zero-copy paged attention (v1, no partitioning)."); - m.def("paged_attention_v2_online", &paged_attention_v2_online_impl, - nb::arg("out"), nb::arg("query"), - nb::arg("key_cache"), nb::arg("value_cache"), - nb::arg("num_kv_heads"), nb::arg("scale"), - nb::arg("softcap"), - nb::arg("block_tables"), nb::arg("seq_lens"), - nb::arg("cu_seqlens_q"), - nb::arg("block_size"), nb::arg("max_seq_len"), - nb::arg("sliding_window"), - "Online-softmax varlen paged attention (v2, unified prefill+decode)."); - - m.def("paged_attention_v2_online_partitioned", - &paged_attention_v2_online_partitioned_impl, - nb::arg("out"), nb::arg("query"), - nb::arg("key_cache"), nb::arg("value_cache"), - nb::arg("num_kv_heads"), nb::arg("scale"), - nb::arg("softcap"), - nb::arg("block_tables"), nb::arg("seq_lens"), - nb::arg("cu_seqlens_q"), - nb::arg("block_size"), nb::arg("max_seq_len"), - nb::arg("sliding_window"), - nb::arg("exp_sums"), nb::arg("max_logits"), nb::arg("tmp_out"), - "Online-softmax varlen paged attention (v2) with caller-provided " - "partition scratch buffers."); - // Paged attention primitive (read-only): dispatches paged_attention_v2_online. // Cache writes are handled by MLX-native scatter upstream. // Uses overwrite_descriptor to bypass cross-module nanobind RTTI.