diff --git a/tests/test_primitive_and_donation.py b/tests/test_primitive_and_donation.py new file mode 100644 index 00000000..90c85e93 --- /dev/null +++ b/tests/test_primitive_and_donation.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the paged attention primitive and scatter-based cache donation. + +Covers two of Eric's review concerns on PR #225: + 1. No test coverage for the primitive path (paged_attention_primitive). + 2. Buffer donation may silently fail, making scatter-based cache writes + O(entire_cache) instead of O(new_tokens). + +Run with: + python -m pytest tests/test_primitive_and_donation.py -v -s +""" + +from __future__ import annotations + +import mlx.core as mx +import numpy as np +import pytest + +from tools.attention_bench_utils import ref_paged_attn +from vllm_metal.metal import get_ops +from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache + +# ── Shared fixtures ────────────────────────────────────────────────────────── + +NUM_KV_HEADS_CASES = [2, 4] +NUM_QUERY_HEADS_CASES = [4, 8] # must be divisible by corresponding kv heads +HEAD_SIZE = 128 +BLOCK_SIZE = 16 +DTYPE = mx.float16 + + +def _make_cache_and_inputs( + num_blocks: int, + num_kv_heads: int, + num_query_heads: int, + seq_lens: list[tuple[int, int]], + *, + dtype: mx.Dtype = DTYPE, +): + """Build a populated cache and matching query/metadata tensors.""" + block_size = BLOCK_SIZE + head_size = HEAD_SIZE + num_seqs = len(seq_lens) + query_lens = [s[0] for s in seq_lens] + kv_lens = [s[1] for s in seq_lens] + total_q = sum(query_lens) + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + key_cache = mx.random.normal( + shape=(num_blocks, block_size, num_kv_heads, head_size) + ).astype(dtype) + value_cache = mx.random.normal( + shape=(num_blocks, block_size, num_kv_heads, head_size) + ).astype(dtype) + query = mx.random.normal(shape=(total_q, num_query_heads, head_size)).astype(dtype) + + max_blocks_per_seq = (max_kv_len + block_size - 1) // 
block_size + block_tables = mx.random.randint( + 0, num_blocks, shape=(num_seqs, max_blocks_per_seq) + ).astype(mx.int32) + + kv_lens_arr = mx.array(kv_lens, dtype=mx.int32) + cu_seqlens_q = mx.cumsum(mx.array([0] + query_lens, dtype=mx.int32)) + + mx.eval(key_cache, value_cache, query, block_tables, kv_lens_arr, cu_seqlens_q) + + return { + "query": query, + "key_cache": key_cache, + "value_cache": value_cache, + "num_kv_heads": num_kv_heads, + "scale": scale, + "block_tables": block_tables, + "kv_lens_arr": kv_lens_arr, + "cu_seqlens_q": cu_seqlens_q, + "query_lens": query_lens, + "kv_lens": kv_lens, + "max_kv_len": max_kv_len, + } + + +# ── 1. Primitive correctness tests ────────────────────────────────────────── + + +@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 523), (1, 37), (1, 2011)], + [(1, 1), (1, 128), (1, 2048)], + ], +) +@pytest.mark.parametrize( + "num_heads", + [(4, 4), (8, 2)], +) +@pytest.mark.parametrize("num_blocks", [256]) +def test_primitive_vs_reference_decode( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + num_blocks: int, +) -> None: + """paged_attention_primitive matches the pure-MLX reference (decode).""" + mx.random.seed(0) + num_query_heads, num_kv_heads = num_heads + d = _make_cache_and_inputs(num_blocks, num_kv_heads, num_query_heads, seq_lens) + + ops = get_ops() + out = mx.array(0) + ops.paged_attention_primitive( + d["query"], + d["key_cache"], + d["value_cache"], + d["num_kv_heads"], + d["scale"], + 0.0, # softcap + d["block_tables"], + d["kv_lens_arr"], + d["cu_seqlens_q"], + BLOCK_SIZE, + d["max_kv_len"], + -1, # sliding_window + out, + ) + mx.eval(out) + + ref = ref_paged_attn( + query=d["query"], + key_cache=d["key_cache"], + value_cache=d["value_cache"], + query_lens=d["query_lens"], + kv_lens=d["kv_lens"], + block_tables=np.array(d["block_tables"]), + scale=d["scale"], + ) + mx.eval(ref) + + np.testing.assert_allclose( + np.array(out), + np.array(ref), + atol=1.5e-2, + rtol=1e-2, + ) + + 
+@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 1328), (5, 18), (129, 463)], + [(1, 523), (1, 37), (1, 2011)], + ], +) +@pytest.mark.parametrize( + "num_heads", + [(4, 4), (8, 2)], +) +@pytest.mark.parametrize("num_blocks", [256]) +def test_primitive_vs_reference_varlen( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + num_blocks: int, +) -> None: + """paged_attention_primitive matches reference for mixed prefill+decode.""" + mx.random.seed(0) + num_query_heads, num_kv_heads = num_heads + d = _make_cache_and_inputs(num_blocks, num_kv_heads, num_query_heads, seq_lens) + + ops = get_ops() + out = mx.array(0) + ops.paged_attention_primitive( + d["query"], + d["key_cache"], + d["value_cache"], + d["num_kv_heads"], + d["scale"], + 0.0, # softcap + d["block_tables"], + d["kv_lens_arr"], + d["cu_seqlens_q"], + BLOCK_SIZE, + d["max_kv_len"], + -1, # sliding_window + out, + ) + mx.eval(out) + + ref = ref_paged_attn( + query=d["query"], + key_cache=d["key_cache"], + value_cache=d["value_cache"], + query_lens=d["query_lens"], + kv_lens=d["kv_lens"], + block_tables=np.array(d["block_tables"]), + scale=d["scale"], + ) + mx.eval(ref) + + np.testing.assert_allclose( + np.array(out), + np.array(ref), + atol=1.5e-2, + rtol=1e-2, + ) + + +# ── 2. Buffer donation test ───────────────────────────────────────────────── + + +@pytest.mark.parametrize("num_blocks", [128, 256]) +def test_scatter_cache_donation(num_blocks: int) -> None: + """Verify that scatter-based cache write reuses the buffer (donation). + + MLX's buffer donation is an optimisation, not a contract. This test + acts as an early-warning: if donation stops happening (due to MLX + changes or unexpected reference leaks), the memory delta will spike + and the assertion will fail. 
+ """ + num_kv_heads = 4 + head_dim = HEAD_SIZE + num_layers = 1 + block_size = BLOCK_SIZE + dtype = DTYPE + + cache = MetalPagedKVCache( + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + num_blocks=num_blocks, + block_size=block_size, + dtype=dtype, + ) + # Already eval'd by MetalPagedKVCache.__init__ + + cache_nbytes = cache.key_caches[0].nbytes # size of one K or V cache + + # Simulate a decode step: scatter 3 tokens into random slots + num_tokens = 3 + slot_indices = mx.array( + np.random.choice(num_blocks * block_size, size=num_tokens, replace=False), + dtype=mx.int64, + ) + new_k = mx.random.normal(shape=(num_tokens, num_kv_heads, head_dim)).astype(dtype) + new_v = mx.random.normal(shape=(num_tokens, num_kv_heads, head_dim)).astype(dtype) + mx.eval(slot_indices, new_k, new_v) + + # Warm up: run multiple rounds so the MLX memory pool stabilises. + # IMPORTANT: delete all locals afterwards — any stray reference to the + # old cache array bumps use_count and defeats buffer donation. 
+ for _ in range(5): + _fk = cache.key_caches[0].reshape(-1, num_kv_heads, head_dim) + _fk[slot_indices] = new_k + _wk = _fk.reshape(cache.key_caches[0].shape) + cache.key_caches[0] = _wk + _fv = cache.value_caches[0].reshape(-1, num_kv_heads, head_dim) + _fv[slot_indices] = new_v + _wv = _fv.reshape(cache.value_caches[0].shape) + cache.value_caches[0] = _wv + mx.eval(cache.key_caches[0], cache.value_caches[0]) + del _fk, _wk, _fv, _wv + + # ── Measure at steady state (average of several rounds) ── + total_delta = 0 + num_rounds = 5 + for _ in range(num_rounds): + mem_before = mx.get_active_memory() + + # K cache scatter + rebind + flat_k = cache.key_caches[0].reshape(-1, num_kv_heads, head_dim) + flat_k[slot_indices] = new_k + new_k_cache = flat_k.reshape(cache.key_caches[0].shape) + cache.key_caches[0] = new_k_cache + + # V cache scatter + rebind + flat_v = cache.value_caches[0].reshape(-1, num_kv_heads, head_dim) + flat_v[slot_indices] = new_v + new_v_cache = flat_v.reshape(cache.value_caches[0].shape) + cache.value_caches[0] = new_v_cache + + mx.eval(cache.key_caches[0], cache.value_caches[0]) + + mem_after = mx.get_active_memory() + total_delta += mem_after - mem_before + del flat_k, new_k_cache, flat_v, new_v_cache + + avg_delta = total_delta / num_rounds + + # If donation works: avg_delta ≈ 0 (buffers reused in-place). + # If donation fails: avg_delta ≈ 2 * cache_nbytes (full copy for K + V). + # Allow generous headroom (one full cache) for pool fluctuations. + threshold = cache_nbytes + assert avg_delta < threshold, ( + f"Buffer donation likely failed: avg memory growth {avg_delta:,.0f} " + f"bytes/round over {num_rounds} rounds, " + f"but each cache is only {cache_nbytes:,} bytes. " + f"Expected near-zero growth with donation." 
+ ) diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 78d1451a..33cb3e5d 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -16,6 +16,7 @@ #include "mlx/mlx.h" #include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" namespace nb = nanobind; using namespace mlx::core; @@ -72,62 +73,50 @@ static std::string dtype_to_metal(Dtype dt) { } // --------------------------------------------------------------------------- -// reshape_and_cache +// reshape_and_cache — dispatch helper + eager binding // --------------------------------------------------------------------------- -void reshape_and_cache_impl( - nb::handle key_h, - nb::handle value_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - nb::handle slot_mapping_h -) { - // Extract C++ arrays from Python handles - auto& key = *nb::inst_ptr(key_h); - auto& value = *nb::inst_ptr(value_h); - auto& key_cache = *nb::inst_ptr(key_cache_h); - auto& value_cache = *nb::inst_ptr(value_cache_h); - auto& slot_mapping = *nb::inst_ptr(slot_mapping_h); - - auto s = default_stream(Device::gpu); - auto& d = metal::device(Device::gpu); +// When called from a primitive's eval_gpu, from_primitive should be true +// to skip ALL add_temporary calls. add_temporary removes buffer pointers +// from the encoder's input/output tracking, defeating fence-based +// synchronisation across command buffer boundaries. Inside a primitive, +// MLX's evaluator already manages array lifetimes via the completion +// handler. In the eager path, add_temporary is needed to keep +// Python-owned arrays alive until the command buffer completes. 
+static void dispatch_reshape_and_cache( + const array& key, const array& value, + array& key_cache, array& value_cache, + const array& slot_mapping, Stream s, + bool from_primitive = false) { + auto& d = metal::device(s.device); int num_tokens = static_cast(key.shape(0)); int num_heads = static_cast(key.shape(1)); int head_size = static_cast(key.shape(2)); - // Cache layout: [num_blocks, block_size, num_kv_heads, head_size] int block_size = static_cast(key_cache.shape(1)); - // Contiguous strides (arrays must be row-major after mx.eval) int32_t key_stride = static_cast(num_heads * head_size); int32_t value_stride = static_cast(num_heads * head_size); int32_t num_heads_i = static_cast(num_heads); int32_t head_size_i = static_cast(head_size); int32_t block_size_i = static_cast(block_size); - // Kernel name: same kv and cache dtype (no FP8) auto dt = dtype_to_metal(key.dtype()); - std::string kname = - "reshape_and_cache_kv_" + dt + "_cache_" + dt; + std::string kname = "reshape_and_cache_kv_" + dt + "_cache_" + dt; - // Get library & specialise kernel with function constants auto* lib = d.get_library("paged_reshape_cache"); bool use_fp8 = false; auto* kernel = d.get_kernel( kname, lib, kname, {{&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(10)}}); - // Dispatch on the current MLX command encoder auto& enc = d.get_command_encoder(s.index); enc.set_compute_pipeline_state(kernel); - - // Buffer bindings (match reshape_and_cache.metal signature) enc.set_input_array(key, 0); enc.set_input_array(value, 1); enc.set_output_array(key_cache, 2); enc.set_output_array(value_cache, 3); enc.set_input_array(slot_mapping, 4); - // 5, 6: k_scale / v_scale — unused (use_fp8_scales=false) enc.set_bytes(key_stride, 7); enc.set_bytes(value_stride, 8); enc.set_bytes(num_heads_i, 9); @@ -139,12 +128,23 @@ void reshape_and_cache_impl( MTL::Size::Make(num_tokens, 1, 1), MTL::Size::Make(tpg, 1, 1)); - // Keep ALL referenced arrays alive until the command buffer completes - 
d.add_temporary(key, s.index); - d.add_temporary(value, s.index); - d.add_temporary(key_cache, s.index); - d.add_temporary(value_cache, s.index); - d.add_temporary(slot_mapping, s.index); + if (!from_primitive) { + d.add_temporary(key, s.index); + d.add_temporary(value, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(slot_mapping, s.index); + } +} + +void reshape_and_cache_impl( + nb::handle key_h, nb::handle value_h, + nb::handle key_cache_h, nb::handle value_cache_h, + nb::handle slot_mapping_h) { + dispatch_reshape_and_cache( + *nb::inst_ptr(key_h), *nb::inst_ptr(value_h), + *nb::inst_ptr(key_cache_h), *nb::inst_ptr(value_cache_h), + *nb::inst_ptr(slot_mapping_h), default_stream(Device::gpu)); } // --------------------------------------------------------------------------- @@ -260,9 +260,106 @@ void paged_attention_v1_impl( } // --------------------------------------------------------------------------- -// paged_attention_v2_online — dispatches the online-softmax v2 kernel +// paged_attention_v2_online — dispatch helper + eager wrappers // --------------------------------------------------------------------------- +static void dispatch_paged_attention_v2_online( + array& out, const array& query, + const array& key_cache, const array& value_cache, + int num_kv_heads, float scale, float softcap, + const array& block_tables, const array& seq_lens, + const array& cu_seqlens_q, + int block_size, int max_seq_len, int sliding_window, Stream s, + bool from_primitive = false) { + auto& d = metal::device(s.device); + + int total_q_tokens = static_cast(query.shape(0)); + int num_heads = static_cast(query.shape(1)); + int head_size = static_cast(query.shape(2)); + int max_blocks = static_cast(block_tables.shape(1)); + int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; + + auto dt = dtype_to_metal(query.dtype()); + std::string kname = + "paged_attention_" + dt + "_cache_" + dt + + "_hs" + 
std::to_string(head_size) + + "_bs" + std::to_string(block_size) + + "_nt256_nsl32_ps0"; + + bool use_partitioning = false; + bool use_alibi = false; + bool use_fp8 = false; + bool use_sinks = false; + + auto* lib = d.get_library("paged_attention_v2_kern"); + auto* kernel = d.get_kernel( + kname, lib, kname + "_v2", + {{&use_partitioning, MTL::DataType::DataTypeBool, NS::UInteger(10)}, + {&use_alibi, MTL::DataType::DataTypeBool, NS::UInteger(20)}, + {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, + {&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + + constexpr int NUM_THREADS = 256; + constexpr int NUM_SIMD_LANES = 32; + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + int warp_scores_bytes = NUM_WARPS * block_size + * static_cast(sizeof(float)); + int merge_bytes = (2 * NUM_WARPS + NUM_WARPS * head_size) + * static_cast(sizeof(float)); + size_t shmem = static_cast(std::max(warp_scores_bytes, merge_bytes)); + + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_threadgroup_memory_length(shmem, 0); + + enc.set_output_array(out, 2); + enc.set_input_array(query, 3); + enc.set_input_array(key_cache, 4); + enc.set_input_array(value_cache, 5); + + int32_t nkv = static_cast(num_kv_heads); + enc.set_bytes(nkv, 8); + enc.set_bytes(scale, 9); + float softcapping = softcap; + enc.set_bytes(softcapping, 10); + + enc.set_input_array(block_tables, 11); + enc.set_input_array(seq_lens, 12); + + int32_t max_blocks_i = static_cast(max_blocks); + enc.set_bytes(max_blocks_i, 13); + + int32_t q_stride = static_cast(num_heads * head_size); + int32_t kv_block_stride = static_cast(key_cache.strides()[0]); + int32_t kv_head_stride = static_cast(key_cache.strides()[2]); + enc.set_bytes(q_stride, 15); + enc.set_bytes(kv_block_stride, 16); + enc.set_bytes(kv_head_stride, 17); + + enc.set_input_array(cu_seqlens_q, 19); + int32_t num_seqs_i = static_cast(num_seqs); + enc.set_bytes(num_seqs_i, 20); + int32_t 
sliding_window_i = static_cast(sliding_window); + enc.set_bytes(sliding_window_i, 21); + + enc.dispatch_threadgroups( + MTL::Size::Make(num_heads, total_q_tokens, 1), + MTL::Size::Make(NUM_THREADS, 1, 1)); + + if (!from_primitive) { + d.add_temporary(out, s.index); + d.add_temporary(query, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(block_tables, s.index); + d.add_temporary(seq_lens, s.index); + d.add_temporary(cu_seqlens_q, s.index); + } +} + +// Eager wrapper — keeps the old handle-based API for metal_unified_attention. +// Calls that supply no partition scratch buffers delegate to the dispatch +// helper above (which never binds sinks); all other calls run inline, as on main. void paged_attention_v2_online_impl_common( nb::handle out_h, nb::handle query_h, @@ -290,10 +387,23 @@ void paged_attention_v2_online_impl_common( auto& seq_lens = *nb::inst_ptr(seq_lens_h); auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); + // No scratch buffers supplied: delegate to the shared dispatch helper + bool needs_partitioning = + exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr; + if (!needs_partitioning) { + dispatch_paged_attention_v2_online( + out, query, key_cache, value_cache, + num_kv_heads, scale, softcap, + block_tables, seq_lens, cu_seqlens_q, + block_size, max_seq_len, sliding_window, + default_stream(Device::gpu)); + return; + } + + // Scratch buffers supplied: inline path; use_partitioning is decided below auto s = default_stream(Device::gpu); auto& d = metal::device(Device::gpu); - // Varlen: query shape is [total_q_tokens, num_heads, head_size] int total_q_tokens = static_cast(query.shape(0)); int num_heads = static_cast(query.shape(1)); int head_size = static_cast(query.shape(2)); @@ -302,10 +412,8 @@ void paged_attention_v2_online_impl_common( int max_num_partitions = std::max(1, (max_seq_len + kPartitionSize - 1) / kPartitionSize); bool use_partitioning = - exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr && kPartitionSize 
% block_size == 0 && max_num_partitions > 1; - // Same kernel name format as v1 — the template instantiation is identical. auto dt = dtype_to_metal(query.dtype()); std::string kname = "paged_attention_" + dt + "_cache_" + dt + @@ -318,9 +426,6 @@ void paged_attention_v2_online_impl_common( bool use_fp8 = false; bool use_sinks = sinks != nullptr; - // Use the v2 library (online softmax kernel). - // get_kernel(function_name, library, cache_key, constants) - // Cache key must differ from v1 to avoid collision in MLX's kernel cache. auto* lib = d.get_library("paged_attention_v2_kern"); auto* kernel = d.get_kernel( kname, lib, kname + "_v2", @@ -329,9 +434,6 @@ void paged_attention_v2_online_impl_common( {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, {&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); - // Threadgroup shared memory for online softmax: - // During KV loop: NUM_WARPS * BLOCK_SIZE floats (per-warp score buffer) - // During merge: 2*NUM_WARPS floats (m, l) + NUM_WARPS * HEAD_SIZE floats (O) constexpr int NUM_THREADS = 256; constexpr int NUM_SIMD_LANES = 32; constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; @@ -345,12 +447,7 @@ void paged_attention_v2_online_impl_common( enc.set_compute_pipeline_state(kernel); enc.set_threadgroup_memory_length(shmem, 0); - // Buffer bindings if (use_partitioning) { - if (exp_sums == nullptr || max_logits == nullptr || tmp_out == nullptr) { - throw std::runtime_error( - "Partitioned v2 attention requires scratch buffers"); - } enc.set_output_array(*exp_sums, 0); enc.set_output_array(*max_logits, 1); enc.set_output_array(*tmp_out, 2); @@ -364,7 +461,6 @@ void paged_attention_v2_online_impl_common( int32_t nkv = static_cast(num_kv_heads); enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - // softcap: 0.0 = disabled, >0 = enabled. Passed through to kernel as-is. 
float softcapping = softcap; enc.set_bytes(softcapping, 10); @@ -384,14 +480,12 @@ void paged_attention_v2_online_impl_common( enc.set_input_array(*sinks, 18); } - // Varlen buffers (new in v2) enc.set_input_array(cu_seqlens_q, 19); int32_t num_seqs_i = static_cast(num_seqs); enc.set_bytes(num_seqs_i, 20); int32_t sliding_window_i = static_cast(sliding_window); enc.set_bytes(sliding_window_i, 21); - // Grid: one threadgroup per (head, query_token) const int32_t grid_z = static_cast(use_partitioning ? max_num_partitions : 1); enc.dispatch_threadgroups( @@ -447,6 +541,80 @@ void paged_attention_v2_online_impl_common( } } +// --------------------------------------------------------------------------- +// Paged attention primitive (read-only): paged_attention_v2_online only. +// +// Single output: attention result. The KV cache is read-only — cache +// writes are handled upstream by MLX-native scatter (pure functional). +// This is a clean pure function: inputs → output, no side effects. +// --------------------------------------------------------------------------- + +class PagedAttentionPrimitive : public UnaryPrimitive { + public: + PagedAttentionPrimitive( + Stream stream, int num_kv_heads, float scale, float softcap, + int block_size, int max_seq_len, int sliding_window) + : UnaryPrimitive(stream), + num_kv_heads_(num_kv_heads), scale_(scale), softcap_(softcap), + block_size_(block_size), max_seq_len_(max_seq_len), + sliding_window_(sliding_window) {} + + void eval_cpu(const std::vector&, array&) override { + throw std::runtime_error( + "PagedAttentionPrimitive only supports GPU"); + } + + void eval_gpu(const std::vector& inputs, array& out) override { + // inputs: [query, key_cache, value_cache, block_tables, seq_lens, + // cu_seqlens_q] + out.set_data(allocator::malloc(out.nbytes())); + dispatch_paged_attention_v2_online( + out, + inputs[0], // query + inputs[1], inputs[2], // key_cache, value_cache + num_kv_heads_, scale_, softcap_, + inputs[3], inputs[4], 
inputs[5], // block_tables, seq_lens, cu_seqlens_q + block_size_, max_seq_len_, sliding_window_, + stream(), + /*from_primitive=*/true); + } + + const char* name() const override { return "PagedAttention"; } + + bool is_equivalent(const Primitive& other) const override { + auto* rhs = dynamic_cast(&other); + return rhs && rhs->num_kv_heads_ == num_kv_heads_ + && rhs->scale_ == scale_ && rhs->softcap_ == softcap_ + && rhs->block_size_ == block_size_ + && rhs->max_seq_len_ == max_seq_len_ + && rhs->sliding_window_ == sliding_window_; + } + + private: + int num_kv_heads_; + float scale_; + float softcap_; + int block_size_; + int max_seq_len_; + int sliding_window_; +}; + +static array paged_attention_primitive_fn( + const array& query, + const array& key_cache, const array& value_cache, + int num_kv_heads, float scale, float softcap, + const array& block_tables, const array& seq_lens, + const array& cu_seqlens_q, + int block_size, int max_seq_len, int sliding_window) { + auto prim = std::make_shared( + default_stream(Device::gpu), + num_kv_heads, scale, softcap, + block_size, max_seq_len, sliding_window); + return array( + query.shape(), query.dtype(), std::move(prim), + {query, key_cache, value_cache, block_tables, seq_lens, cu_seqlens_q}); +} + void paged_attention_v2_online_impl( nb::handle out_h, nb::handle query_h, @@ -662,6 +830,39 @@ NB_MODULE(_paged_ops, m) { "Online-softmax varlen paged attention (v2) with caller-provided " "partition scratch buffers."); + // Paged attention primitive (read-only): dispatches paged_attention_v2_online. + // Cache writes are handled by MLX-native scatter upstream. + // Uses overwrite_descriptor to bypass cross-module nanobind RTTI. 
+ m.def("paged_attention_primitive", + [](nb::handle query_h, + nb::handle key_cache_h, nb::handle value_cache_h, + int num_kv_heads, float scale, float softcap, + nb::handle block_tables_h, nb::handle seq_lens_h, + nb::handle cu_seqlens_q_h, + int block_size, int max_seq_len, int sliding_window, + nb::handle out_h) { + auto result = paged_attention_primitive_fn( + *nb::inst_ptr(query_h), + *nb::inst_ptr(key_cache_h), + *nb::inst_ptr(value_cache_h), + num_kv_heads, scale, softcap, + *nb::inst_ptr(block_tables_h), + *nb::inst_ptr(seq_lens_h), + *nb::inst_ptr(cu_seqlens_q_h), + block_size, max_seq_len, sliding_window); + nb::inst_ptr(out_h)->overwrite_descriptor(result); + }, + nb::arg("query"), + nb::arg("key_cache"), nb::arg("value_cache"), + nb::arg("num_kv_heads"), nb::arg("scale"), nb::arg("softcap"), + nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("cu_seqlens_q"), + nb::arg("block_size"), nb::arg("max_seq_len"), + nb::arg("sliding_window"), + nb::arg("out"), + "Paged attention primitive (read-only). 
Cache writes are handled " + "by MLX-native scatter upstream."); + m.def("gdn_linear_attention", &gdn_linear_attention_impl, nb::arg("q"), nb::arg("k"), nb::arg("v"), nb::arg("g"), nb::arg("beta"), diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index 2453911d..5bca908a 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -125,29 +125,45 @@ def sdpa_forward( seq_lens = mx.array(ctx.context_lens, dtype=mx.int32) cu_seqlens_q = mx.array(ctx.cu_seqlens, dtype=mx.int32) - # Allocate output buffer before eval so we can materialize everything in one call - out = mx.zeros((L, n_heads, head_dim), dtype=kv_cache.dtype) - mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens, cu_seqlens_q, out) - ops = get_ops() - - # Write K/V into paged cache BEFORE attention — the kernel reads from - # the paged cache via block_table, not from raw tensors. - ops.reshape_and_cache( - k_3d, - v_3d, - kv_cache.key_caches[layer_idx], - kv_cache.value_caches[layer_idx], - slot_mapping, - ) - max_seq_len = max(ctx.context_lens) - ops.paged_attention_v2_online( - out, + # --- Cache write: MLX-native scatter (pure functional, graph-tracked) --- + # Flatten cache to [num_slots, num_kv_heads, head_dim], scatter new K/V + # by slot_mapping, then reshape back. This creates proper graph nodes + # that MLX's evaluator can track for dependency ordering and buffer + # donation — no in-place mutation, no copy_shared_buffer, no const_cast. + # + # DONATION INVARIANT: the rebind (below) must drop the list's reference + # to the old cache *before* mx.eval runs. At eval time the old cache + # must have use_count == 1 (only the graph) for MLX to donate its + # buffer to the scatter output. Do NOT insert mx.eval between the + # scatter and the rebind, or hold extra references to the old cache. 
+ flat_k = kv_cache.key_caches[layer_idx].reshape(-1, kv_cache.num_kv_heads, head_dim) + flat_k[slot_mapping] = k_3d + new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape) + + flat_v = kv_cache.value_caches[layer_idx].reshape( + -1, kv_cache.num_kv_heads, head_dim + ) + flat_v[slot_mapping] = v_3d + new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape) + + # Rebind so next layer / decode step uses the updated cache + kv_cache.key_caches[layer_idx] = new_k_cache + kv_cache.value_caches[layer_idx] = new_v_cache + + # --- Attention: paged attention primitive (read-only, fully lazy) --- + # No per-layer eval or sync. The primitive participates in MLX's lazy + # graph and is evaluated by the model runner at the end of the forward + # pass. Fence-based synchronisation across command buffer boundaries + # works correctly because eval_gpu skips add_temporary (which would + # remove buffers from the encoder's fence tracking). + out = mx.array(0) + ops.paged_attention_primitive( q_3d, - kv_cache.key_caches[layer_idx], - kv_cache.value_caches[layer_idx], + new_k_cache, + new_v_cache, kv_cache.num_kv_heads, inner.scale, 0.0, # softcap (0 = disabled) @@ -157,10 +173,9 @@ def sdpa_forward( kv_cache.block_size, max_seq_len, -1, # sliding_window (-1 = disabled) + out, ) - mx.synchronize() - # output: (L, n_heads, head_dim) → (B, L, n_heads * head_dim) out = out.reshape(B, L, n_heads * head_dim) if gate is not None: