From 7a2ec9338cc5b12ef5f8dc6a6210fd0b89c0f360 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 2 Apr 2026 23:26:41 -0500 Subject: [PATCH 1/7] poc primitive by copy_shared_buffer Signed-off-by: ran --- vllm_metal/metal/paged_ops.cpp | 408 ++++++++++++------ .../metal_kernel_backend/attention_sdpa.py | 42 +- 2 files changed, 311 insertions(+), 139 deletions(-) diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 1932ad87..4afaf850 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -16,6 +16,7 @@ #include "mlx/mlx.h" #include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" namespace nb = nanobind; using namespace mlx::core; @@ -72,62 +73,42 @@ static std::string dtype_to_metal(Dtype dt) { } // --------------------------------------------------------------------------- -// reshape_and_cache +// reshape_and_cache — dispatch helper + eager binding // --------------------------------------------------------------------------- -void reshape_and_cache_impl( - nb::handle key_h, - nb::handle value_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - nb::handle slot_mapping_h -) { - // Extract C++ arrays from Python handles - auto& key = *nb::inst_ptr(key_h); - auto& value = *nb::inst_ptr(value_h); - auto& key_cache = *nb::inst_ptr(key_cache_h); - auto& value_cache = *nb::inst_ptr(value_cache_h); - auto& slot_mapping = *nb::inst_ptr(slot_mapping_h); - - auto s = default_stream(Device::gpu); - auto& d = metal::device(Device::gpu); +static void dispatch_reshape_and_cache( + const array& key, const array& value, + array& key_cache, array& value_cache, + const array& slot_mapping, Stream s) { + auto& d = metal::device(s.device); int num_tokens = static_cast(key.shape(0)); int num_heads = static_cast(key.shape(1)); int head_size = static_cast(key.shape(2)); - // Cache layout: [num_blocks, block_size, num_kv_heads, head_size] int block_size = static_cast(key_cache.shape(1)); - // Contiguous strides (arrays 
must be row-major after mx.eval) int32_t key_stride = static_cast(num_heads * head_size); int32_t value_stride = static_cast(num_heads * head_size); int32_t num_heads_i = static_cast(num_heads); int32_t head_size_i = static_cast(head_size); int32_t block_size_i = static_cast(block_size); - // Kernel name: same kv and cache dtype (no FP8) auto dt = dtype_to_metal(key.dtype()); - std::string kname = - "reshape_and_cache_kv_" + dt + "_cache_" + dt; + std::string kname = "reshape_and_cache_kv_" + dt + "_cache_" + dt; - // Get library & specialise kernel with function constants auto* lib = d.get_library("paged_reshape_cache"); bool use_fp8 = false; auto* kernel = d.get_kernel( kname, lib, kname, {{&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(10)}}); - // Dispatch on the current MLX command encoder auto& enc = d.get_command_encoder(s.index); enc.set_compute_pipeline_state(kernel); - - // Buffer bindings (match reshape_and_cache.metal signature) enc.set_input_array(key, 0); enc.set_input_array(value, 1); enc.set_output_array(key_cache, 2); enc.set_output_array(value_cache, 3); enc.set_input_array(slot_mapping, 4); - // 5, 6: k_scale / v_scale — unused (use_fp8_scales=false) enc.set_bytes(key_stride, 7); enc.set_bytes(value_stride, 8); enc.set_bytes(num_heads_i, 9); @@ -139,7 +120,6 @@ void reshape_and_cache_impl( MTL::Size::Make(num_tokens, 1, 1), MTL::Size::Make(tpg, 1, 1)); - // Keep ALL referenced arrays alive until the command buffer completes d.add_temporary(key, s.index); d.add_temporary(value, s.index); d.add_temporary(key_cache, s.index); @@ -147,6 +127,16 @@ void reshape_and_cache_impl( d.add_temporary(slot_mapping, s.index); } +void reshape_and_cache_impl( + nb::handle key_h, nb::handle value_h, + nb::handle key_cache_h, nb::handle value_cache_h, + nb::handle slot_mapping_h) { + dispatch_reshape_and_cache( + *nb::inst_ptr(key_h), *nb::inst_ptr(value_h), + *nb::inst_ptr(key_cache_h), *nb::inst_ptr(value_cache_h), + *nb::inst_ptr(slot_mapping_h), 
default_stream(Device::gpu)); +} + // --------------------------------------------------------------------------- // paged_attention_v1 // --------------------------------------------------------------------------- @@ -260,67 +250,36 @@ void paged_attention_v1_impl( } // --------------------------------------------------------------------------- -// paged_attention_v2_online — dispatches the online-softmax v2 kernel +// paged_attention_v2_online — dispatch helper + eager wrappers // --------------------------------------------------------------------------- -void paged_attention_v2_online_impl_common( - nb::handle out_h, - nb::handle query_h, - nb::handle key_cache_h, - nb::handle value_cache_h, - int num_kv_heads, - float scale, - float softcap, - nb::handle block_tables_h, - nb::handle seq_lens_h, - nb::handle cu_seqlens_q_h, - int block_size, - int max_seq_len, - int sliding_window, - array* exp_sums, - array* max_logits, - array* tmp_out, - array* sinks -) { - auto& out = *nb::inst_ptr(out_h); - auto& query = *nb::inst_ptr(query_h); - auto& key_cache = *nb::inst_ptr(key_cache_h); - auto& value_cache = *nb::inst_ptr(value_cache_h); - auto& block_tables = *nb::inst_ptr(block_tables_h); - auto& seq_lens = *nb::inst_ptr(seq_lens_h); - auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); - - auto s = default_stream(Device::gpu); - auto& d = metal::device(Device::gpu); +static void dispatch_paged_attention_v2_online( + array& out, const array& query, + const array& key_cache, const array& value_cache, + int num_kv_heads, float scale, float softcap, + const array& block_tables, const array& seq_lens, + const array& cu_seqlens_q, + int block_size, int max_seq_len, int sliding_window, Stream s) { + auto& d = metal::device(s.device); - // Varlen: query shape is [total_q_tokens, num_heads, head_size] int total_q_tokens = static_cast(query.shape(0)); int num_heads = static_cast(query.shape(1)); int head_size = static_cast(query.shape(2)); int max_blocks = 
static_cast(block_tables.shape(1)); int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; - int max_num_partitions = - std::max(1, (max_seq_len + kPartitionSize - 1) / kPartitionSize); - bool use_partitioning = - exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr && - kPartitionSize % block_size == 0 && max_num_partitions > 1; - // Same kernel name format as v1 — the template instantiation is identical. auto dt = dtype_to_metal(query.dtype()); std::string kname = "paged_attention_" + dt + "_cache_" + dt + "_hs" + std::to_string(head_size) + "_bs" + std::to_string(block_size) + - "_nt256_nsl32_ps" + - std::to_string(use_partitioning ? kPartitionSize : 0); + "_nt256_nsl32_ps0"; + bool use_partitioning = false; bool use_alibi = false; bool use_fp8 = false; - bool use_sinks = sinks != nullptr; + bool use_sinks = false; - // Use the v2 library (online softmax kernel). - // get_kernel(function_name, library, cache_key, constants) - // Cache key must differ from v1 to avoid collision in MLX's kernel cache. 
auto* lib = d.get_library("paged_attention_v2_kern"); auto* kernel = d.get_kernel( kname, lib, kname + "_v2", @@ -329,9 +288,6 @@ void paged_attention_v2_online_impl_common( {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, {&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); - // Threadgroup shared memory for online softmax: - // During KV loop: NUM_WARPS * BLOCK_SIZE floats (per-warp score buffer) - // During merge: 2*NUM_WARPS floats (m, l) + NUM_WARPS * HEAD_SIZE floats (O) constexpr int NUM_THREADS = 256; constexpr int NUM_SIMD_LANES = 32; constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; @@ -345,18 +301,7 @@ void paged_attention_v2_online_impl_common( enc.set_compute_pipeline_state(kernel); enc.set_threadgroup_memory_length(shmem, 0); - // Buffer bindings - if (use_partitioning) { - if (exp_sums == nullptr || max_logits == nullptr || tmp_out == nullptr) { - throw std::runtime_error( - "Partitioned v2 attention requires scratch buffers"); - } - enc.set_output_array(*exp_sums, 0); - enc.set_output_array(*max_logits, 1); - enc.set_output_array(*tmp_out, 2); - } else { - enc.set_output_array(out, 2); - } + enc.set_output_array(out, 2); enc.set_input_array(query, 3); enc.set_input_array(key_cache, 4); enc.set_input_array(value_cache, 5); @@ -364,7 +309,6 @@ void paged_attention_v2_online_impl_common( int32_t nkv = static_cast(num_kv_heads); enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - // softcap: 0.0 = disabled, >0 = enabled. Passed through to kernel as-is. 
float softcapping = softcap; enc.set_bytes(softcapping, 10); @@ -380,51 +324,145 @@ void paged_attention_v2_online_impl_common( enc.set_bytes(q_stride, 15); enc.set_bytes(kv_block_stride, 16); enc.set_bytes(kv_head_stride, 17); - if (use_sinks) { - enc.set_input_array(*sinks, 18); - } - // Varlen buffers (new in v2) enc.set_input_array(cu_seqlens_q, 19); int32_t num_seqs_i = static_cast(num_seqs); enc.set_bytes(num_seqs_i, 20); int32_t sliding_window_i = static_cast(sliding_window); enc.set_bytes(sliding_window_i, 21); - // Grid: one threadgroup per (head, query_token) - const int32_t grid_z = - static_cast(use_partitioning ? max_num_partitions : 1); enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, total_q_tokens, grid_z), + MTL::Size::Make(num_heads, total_q_tokens, 1), + MTL::Size::Make(NUM_THREADS, 1, 1)); + + d.add_temporary(out, s.index); + d.add_temporary(query, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(block_tables, s.index); + d.add_temporary(seq_lens, s.index); + d.add_temporary(cu_seqlens_q, s.index); +} + +// Eager wrapper — keeps the old handle-based API for metal_unified_attention +static void paged_attention_v2_online_impl_common( + nb::handle out_h, nb::handle query_h, + nb::handle key_cache_h, nb::handle value_cache_h, + int num_kv_heads, float scale, float softcap, + nb::handle block_tables_h, nb::handle seq_lens_h, + nb::handle cu_seqlens_q_h, + int block_size, int max_seq_len, int sliding_window, + array* exp_sums, array* max_logits, array* tmp_out, array* sinks) { + auto& out = *nb::inst_ptr(out_h); + auto& query = *nb::inst_ptr(query_h); + auto& key_cache = *nb::inst_ptr(key_cache_h); + auto& value_cache = *nb::inst_ptr(value_cache_h); + auto& block_tables = *nb::inst_ptr(block_tables_h); + auto& seq_lens = *nb::inst_ptr(seq_lens_h); + auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); + + // For the simple non-partitioned case, delegate to the dispatch helper + bool 
needs_partitioning = + exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr; + if (!needs_partitioning) { + dispatch_paged_attention_v2_online( + out, query, key_cache, value_cache, + num_kv_heads, scale, softcap, + block_tables, seq_lens, cu_seqlens_q, + block_size, max_seq_len, sliding_window, + default_stream(Device::gpu)); + return; + } + + // Partitioned path — inline here since the fused primitive doesn't use it + auto s = default_stream(Device::gpu); + auto& d = metal::device(s.device); + int total_q_tokens = static_cast(query.shape(0)); + int num_heads = static_cast(query.shape(1)); + int head_size = static_cast(query.shape(2)); + int max_blocks = static_cast(block_tables.shape(1)); + int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; + int max_num_partitions = + std::max(1, (max_seq_len + kPartitionSize - 1) / kPartitionSize); + bool use_partitioning = + kPartitionSize % block_size == 0 && max_num_partitions > 1; + + auto dt = dtype_to_metal(query.dtype()); + std::string kname = + "paged_attention_" + dt + "_cache_" + dt + + "_hs" + std::to_string(head_size) + + "_bs" + std::to_string(block_size) + + "_nt256_nsl32_ps" + + std::to_string(use_partitioning ? 
kPartitionSize : 0); + + bool use_alibi = false, use_fp8 = false; + bool use_sinks_flag = sinks != nullptr; + auto* lib = d.get_library("paged_attention_v2_kern"); + auto* kernel = d.get_kernel( + kname, lib, kname + "_v2", + {{&use_partitioning, MTL::DataType::DataTypeBool, NS::UInteger(10)}, + {&use_alibi, MTL::DataType::DataTypeBool, NS::UInteger(20)}, + {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, + {&use_sinks_flag, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + + constexpr int NUM_THREADS = 256, NUM_SIMD_LANES = 32; + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + int warp_scores_bytes = NUM_WARPS * block_size * (int)sizeof(float); + int merge_bytes = (2 * NUM_WARPS + NUM_WARPS * head_size) * (int)sizeof(float); + size_t shmem = (size_t)std::max(warp_scores_bytes, merge_bytes); + + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_threadgroup_memory_length(shmem, 0); + + if (use_partitioning) { + enc.set_output_array(*exp_sums, 0); + enc.set_output_array(*max_logits, 1); + enc.set_output_array(*tmp_out, 2); + } else { + enc.set_output_array(out, 2); + } + enc.set_input_array(query, 3); + enc.set_input_array(key_cache, 4); + enc.set_input_array(value_cache, 5); + int32_t nkv = (int32_t)num_kv_heads; + enc.set_bytes(nkv, 8); + enc.set_bytes(scale, 9); + enc.set_bytes(softcap, 10); + enc.set_input_array(block_tables, 11); + enc.set_input_array(seq_lens, 12); + enc.set_bytes((int32_t)max_blocks, 13); + enc.set_bytes((int32_t)(num_heads * head_size), 15); + enc.set_bytes((int32_t)key_cache.strides()[0], 16); + enc.set_bytes((int32_t)key_cache.strides()[2], 17); + if (use_sinks_flag) enc.set_input_array(*sinks, 18); + enc.set_input_array(cu_seqlens_q, 19); + enc.set_bytes((int32_t)num_seqs, 20); + enc.set_bytes((int32_t)sliding_window, 21); + + enc.dispatch_threadgroups( + MTL::Size::Make(num_heads, total_q_tokens, + use_partitioning ? 
max_num_partitions : 1), MTL::Size::Make(NUM_THREADS, 1, 1)); if (use_partitioning) { - std::string reduce_kname = - "paged_attention_v2_reduce_" + dt + + std::string rk = "paged_attention_v2_reduce_" + dt + "_hs" + std::to_string(head_size) + "_nt256_nsl32_ps" + std::to_string(kPartitionSize); - auto* reduce_kernel = d.get_kernel( - reduce_kname, - lib, - reduce_kname + "_v2_reduce", - {{&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); - size_t reduce_shmem = - static_cast(2 * max_num_partitions * sizeof(float)); - enc.set_compute_pipeline_state(reduce_kernel); - enc.set_threadgroup_memory_length(reduce_shmem, 0); - - enc.set_output_array(out, 0); - enc.set_input_array(*exp_sums, 1); - enc.set_input_array(*max_logits, 2); - enc.set_input_array(*tmp_out, 3); - enc.set_input_array(seq_lens, 4); - int32_t max_num_partitions_i = static_cast(max_num_partitions); - enc.set_bytes(max_num_partitions_i, 5); - if (use_sinks) { - enc.set_input_array(*sinks, 6); - } + auto* rkernel = d.get_kernel(rk, lib, rk + "_v2_reduce", + {{&use_sinks_flag, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + enc.set_compute_pipeline_state(rkernel); + enc.set_threadgroup_memory_length( + (size_t)(2 * max_num_partitions * sizeof(float)), 0); + enc.set_output_array(out, 0); + enc.set_input_array(*exp_sums, 1); + enc.set_input_array(*max_logits, 2); + enc.set_input_array(*tmp_out, 3); + enc.set_input_array(seq_lens, 4); + enc.set_bytes((int32_t)max_num_partitions, 5); + if (use_sinks_flag) enc.set_input_array(*sinks, 6); enc.set_input_array(cu_seqlens_q, 7); - enc.set_bytes(num_seqs_i, 8); + enc.set_bytes((int32_t)num_seqs, 8); enc.dispatch_threadgroups( MTL::Size::Make(num_heads, total_q_tokens, 1), MTL::Size::Make(NUM_THREADS, 1, 1)); @@ -442,9 +480,100 @@ void paged_attention_v2_online_impl_common( d.add_temporary(*max_logits, s.index); d.add_temporary(*tmp_out, s.index); } - if (use_sinks) { - d.add_temporary(*sinks, s.index); + if (use_sinks_flag) d.add_temporary(*sinks, 
s.index); +} + +// --------------------------------------------------------------------------- +// Fused reshape_and_cache + paged_attention_v2_online PRIMITIVE +// +// Three outputs: [updated_key_cache, updated_value_cache, attention_output] +// Both Metal kernel dispatches happen in the same eval_gpu → same command +// encoder → Metal guarantees sequential ordering (write before read). +// --------------------------------------------------------------------------- + +class FusedReshapeAndAttentionPrimitive : public Primitive { + public: + FusedReshapeAndAttentionPrimitive( + Stream stream, int num_kv_heads, float scale, float softcap, + int block_size, int max_seq_len, int sliding_window) + : Primitive(stream), + num_kv_heads_(num_kv_heads), scale_(scale), softcap_(softcap), + block_size_(block_size), max_seq_len_(max_seq_len), + sliding_window_(sliding_window) {} + + void eval_cpu(const std::vector&, std::vector&) override { + throw std::runtime_error( + "FusedReshapeAndAttentionPrimitive only supports GPU"); + } + + void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override { + // inputs: [query, key, value, key_cache, value_cache, slot_mapping, + // block_tables, seq_lens, cu_seqlens_q] + // outputs: [updated_key_cache, updated_value_cache, attention_output] + + // 1. Cache outputs alias input cache buffers (zero-copy) + outputs[0].copy_shared_buffer(inputs[3]); // key_cache + outputs[1].copy_shared_buffer(inputs[4]); // value_cache + + // 2. Dispatch reshape_and_cache — writes new K/V into cache + dispatch_reshape_and_cache( + inputs[1], inputs[2], // key, value + outputs[0], outputs[1], // cache buffers (aliased) + inputs[5], // slot_mapping + stream()); + + // 3. Dispatch attention — reads from cache (same command encoder!) 
+ outputs[2].set_data(allocator::malloc(outputs[2].nbytes())); + dispatch_paged_attention_v2_online( + outputs[2], // attention output + inputs[0], // query + outputs[0], outputs[1], // cache buffers (just written) + num_kv_heads_, scale_, softcap_, + inputs[6], inputs[7], inputs[8], // block_tables, seq_lens, cu_seqlens_q + block_size_, max_seq_len_, sliding_window_, + stream()); + } + + const char* name() const override { return "FusedReshapeAndAttention"; } + + bool is_equivalent(const Primitive& other) const override { + auto* rhs = dynamic_cast(&other); + return rhs && rhs->num_kv_heads_ == num_kv_heads_ + && rhs->scale_ == scale_ && rhs->softcap_ == softcap_ + && rhs->block_size_ == block_size_ + && rhs->max_seq_len_ == max_seq_len_ + && rhs->sliding_window_ == sliding_window_; } + + private: + int num_kv_heads_; + float scale_; + float softcap_; + int block_size_; + int max_seq_len_; + int sliding_window_; +}; + +static std::vector fused_reshape_attention_fn( + const array& query, const array& key, const array& value, + const array& key_cache, const array& value_cache, + const array& slot_mapping, + int num_kv_heads, float scale, float softcap, + const array& block_tables, const array& seq_lens, + const array& cu_seqlens_q, + int block_size, int max_seq_len, int sliding_window) { + auto prim = std::make_shared( + default_stream(Device::gpu), + num_kv_heads, scale, softcap, + block_size, max_seq_len, sliding_window); + return array::make_arrays( + {key_cache.shape(), value_cache.shape(), query.shape()}, + {key_cache.dtype(), value_cache.dtype(), query.dtype()}, + std::move(prim), + {query, key, value, key_cache, value_cache, slot_mapping, + block_tables, seq_lens, cu_seqlens_q}); } void paged_attention_v2_online_impl( @@ -576,4 +705,43 @@ NB_MODULE(_paged_ops, m) { nb::arg("exp_sums"), nb::arg("max_logits"), nb::arg("tmp_out"), "Online-softmax varlen paged attention (v2) with caller-provided " "partition scratch buffers."); + + // Fused primitive: 
reshape_and_cache + paged_attention in one eval_gpu. + // Uses overwrite_descriptor to bypass cross-module nanobind RTTI. + m.def("fused_reshape_attention_primitive", + [](nb::handle query_h, nb::handle key_h, nb::handle value_h, + nb::handle key_cache_h, nb::handle value_cache_h, + nb::handle slot_mapping_h, + int num_kv_heads, float scale, float softcap, + nb::handle block_tables_h, nb::handle seq_lens_h, + nb::handle cu_seqlens_q_h, + int block_size, int max_seq_len, int sliding_window, + nb::handle out_k_h, nb::handle out_v_h, nb::handle out_attn_h) { + auto result = fused_reshape_attention_fn( + *nb::inst_ptr(query_h), + *nb::inst_ptr(key_h), + *nb::inst_ptr(value_h), + *nb::inst_ptr(key_cache_h), + *nb::inst_ptr(value_cache_h), + *nb::inst_ptr(slot_mapping_h), + num_kv_heads, scale, softcap, + *nb::inst_ptr(block_tables_h), + *nb::inst_ptr(seq_lens_h), + *nb::inst_ptr(cu_seqlens_q_h), + block_size, max_seq_len, sliding_window); + nb::inst_ptr(out_k_h)->overwrite_descriptor(result[0]); + nb::inst_ptr(out_v_h)->overwrite_descriptor(result[1]); + nb::inst_ptr(out_attn_h)->overwrite_descriptor(result[2]); + }, + nb::arg("query"), nb::arg("key"), nb::arg("value"), + nb::arg("key_cache"), nb::arg("value_cache"), + nb::arg("slot_mapping"), + nb::arg("num_kv_heads"), nb::arg("scale"), nb::arg("softcap"), + nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("cu_seqlens_q"), + nb::arg("block_size"), nb::arg("max_seq_len"), + nb::arg("sliding_window"), + nb::arg("out_key"), nb::arg("out_value"), nb::arg("out_attn"), + "Fused primitive: reshape_and_cache + paged_attention_v2_online " + "in a single eval_gpu dispatch (same command encoder)."); } diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index 9cc20c69..b46728e1 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -105,29 +105,22 @@ def sdpa_forward( seq_lens = 
mx.array(ctx.context_lens, dtype=mx.int32) cu_seqlens_q = mx.array(ctx.cu_seqlens, dtype=mx.int32) - # Allocate output buffer before eval so we can materialize everything in one call - out = mx.zeros((L, n_heads, head_dim), dtype=kv_cache.dtype) - mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens, cu_seqlens_q, out) - ops = get_ops() - - # Write K/V into paged cache BEFORE attention — the kernel reads from - # the paged cache via block_table, not from raw tensors. - ops.reshape_and_cache( - k_3d, - v_3d, - kv_cache.key_caches[layer_idx], - kv_cache.value_caches[layer_idx], - slot_mapping, - ) - max_seq_len = max(ctx.context_lens) - ops.paged_attention_v2_online( - out, - q_3d, + # Fused primitive: reshape_and_cache + paged_attention in one eval_gpu. + # Both Metal kernel dispatches go to the same command encoder, so + # Metal guarantees the cache write completes before the attention read. + # No per-layer mx.eval or mx.synchronize — fully lazy until the model + # runner evaluates the final logits. + updated_k_cache = mx.array(0) + updated_v_cache = mx.array(0) + out = mx.array(0) + ops.fused_reshape_attention_primitive( + q_3d, k_3d, v_3d, kv_cache.key_caches[layer_idx], kv_cache.value_caches[layer_idx], + slot_mapping, kv_cache.num_kv_heads, inner.scale, 0.0, # softcap (0 = disabled) @@ -137,9 +130,20 @@ def sdpa_forward( kv_cache.block_size, max_seq_len, -1, # sliding_window (-1 = disabled) + updated_k_cache, + updated_v_cache, + out, ) - mx.synchronize() + # Evaluate the fused primitive for this layer. + # TODO: investigate why fully-lazy (no per-layer eval) produces garbage. + # The fused primitive dispatches both kernels to the same command encoder, + # but something in the cross-layer lazy graph breaks correctness. 
+ mx.eval(updated_k_cache, updated_v_cache, out) + + # Rebind cache references for next layer / decode step + kv_cache.key_caches[layer_idx] = updated_k_cache + kv_cache.value_caches[layer_idx] = updated_v_cache # output: (L, n_heads, head_dim) → (B, L, n_heads * head_dim) out = out.reshape(B, L, n_heads * head_dim) From dd279f0b582572e361b1b69ed20e3f28262d859d Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 2 Apr 2026 23:47:26 -0500 Subject: [PATCH 2/7] reword the mx.eval issue comment Signed-off-by: ran --- vllm_metal/metal_kernel_backend/attention_sdpa.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index b46728e1..8acbccf3 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -135,10 +135,10 @@ def sdpa_forward( out, ) - # Evaluate the fused primitive for this layer. - # TODO: investigate why fully-lazy (no per-layer eval) produces garbage. - # The fused primitive dispatches both kernels to the same command encoder, - # but something in the cross-layer lazy graph breaks correctness. + # Evaluate the fused primitive for this layer. Required because + # copy_shared_buffer cache aliasing is not safe across a fully-lazy + # 28-layer graph — MLX's buffer management may reorder or reuse + # aliased buffers. Removing this eval is tracked as future work. 
mx.eval(updated_k_cache, updated_v_cache, out) # Rebind cache references for next layer / decode step From 7acd2889777bd069f45ddfc7f7694d809417b67b Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 3 Apr 2026 00:27:27 -0500 Subject: [PATCH 3/7] cleanup dead code and fix CI Signed-off-by: ran --- vllm_metal/metal/paged_ops.cpp | 18 +++++++++--------- .../metal_kernel_backend/attention_sdpa.py | 17 +++++++++-------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 4afaf850..297a1323 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -491,9 +491,9 @@ static void paged_attention_v2_online_impl_common( // encoder → Metal guarantees sequential ordering (write before read). // --------------------------------------------------------------------------- -class FusedReshapeAndAttentionPrimitive : public Primitive { +class PagedSDPAPrimitive : public Primitive { public: - FusedReshapeAndAttentionPrimitive( + PagedSDPAPrimitive( Stream stream, int num_kv_heads, float scale, float softcap, int block_size, int max_seq_len, int sliding_window) : Primitive(stream), @@ -503,7 +503,7 @@ class FusedReshapeAndAttentionPrimitive : public Primitive { void eval_cpu(const std::vector&, std::vector&) override { throw std::runtime_error( - "FusedReshapeAndAttentionPrimitive only supports GPU"); + "PagedSDPAPrimitive only supports GPU"); } void eval_gpu( @@ -539,7 +539,7 @@ class FusedReshapeAndAttentionPrimitive : public Primitive { const char* name() const override { return "FusedReshapeAndAttention"; } bool is_equivalent(const Primitive& other) const override { - auto* rhs = dynamic_cast(&other); + auto* rhs = dynamic_cast(&other); return rhs && rhs->num_kv_heads_ == num_kv_heads_ && rhs->scale_ == scale_ && rhs->softcap_ == softcap_ && rhs->block_size_ == block_size_ @@ -556,7 +556,7 @@ class FusedReshapeAndAttentionPrimitive : public Primitive { int sliding_window_; }; 
-static std::vector fused_reshape_attention_fn( +static std::vector paged_sdpa_primitive_fn( const array& query, const array& key, const array& value, const array& key_cache, const array& value_cache, const array& slot_mapping, @@ -564,7 +564,7 @@ static std::vector fused_reshape_attention_fn( const array& block_tables, const array& seq_lens, const array& cu_seqlens_q, int block_size, int max_seq_len, int sliding_window) { - auto prim = std::make_shared( + auto prim = std::make_shared( default_stream(Device::gpu), num_kv_heads, scale, softcap, block_size, max_seq_len, sliding_window); @@ -708,7 +708,7 @@ NB_MODULE(_paged_ops, m) { // Fused primitive: reshape_and_cache + paged_attention in one eval_gpu. // Uses overwrite_descriptor to bypass cross-module nanobind RTTI. - m.def("fused_reshape_attention_primitive", + m.def("paged_sdpa_primitive", [](nb::handle query_h, nb::handle key_h, nb::handle value_h, nb::handle key_cache_h, nb::handle value_cache_h, nb::handle slot_mapping_h, @@ -717,7 +717,7 @@ NB_MODULE(_paged_ops, m) { nb::handle cu_seqlens_q_h, int block_size, int max_seq_len, int sliding_window, nb::handle out_k_h, nb::handle out_v_h, nb::handle out_attn_h) { - auto result = fused_reshape_attention_fn( + auto result = paged_sdpa_primitive_fn( *nb::inst_ptr(query_h), *nb::inst_ptr(key_h), *nb::inst_ptr(value_h), @@ -742,6 +742,6 @@ NB_MODULE(_paged_ops, m) { nb::arg("block_size"), nb::arg("max_seq_len"), nb::arg("sliding_window"), nb::arg("out_key"), nb::arg("out_value"), nb::arg("out_attn"), - "Fused primitive: reshape_and_cache + paged_attention_v2_online " + "Paged SDPA primitive: reshape_and_cache + paged_attention_v2_online " "in a single eval_gpu dispatch (same command encoder)."); } diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index 8acbccf3..d8af9817 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -108,16 
+108,17 @@ def sdpa_forward( ops = get_ops() max_seq_len = max(ctx.context_lens) - # Fused primitive: reshape_and_cache + paged_attention in one eval_gpu. - # Both Metal kernel dispatches go to the same command encoder, so - # Metal guarantees the cache write completes before the attention read. - # No per-layer mx.eval or mx.synchronize — fully lazy until the model - # runner evaluates the final logits. + # Paged SDPA primitive: reshape_and_cache + paged_attention in one + # eval_gpu. Both Metal kernel dispatches go to the same command + # encoder, so Metal guarantees the cache write completes before the + # attention read. updated_k_cache = mx.array(0) updated_v_cache = mx.array(0) out = mx.array(0) - ops.fused_reshape_attention_primitive( - q_3d, k_3d, v_3d, + ops.paged_sdpa_primitive( + q_3d, + k_3d, + v_3d, kv_cache.key_caches[layer_idx], kv_cache.value_caches[layer_idx], slot_mapping, @@ -135,7 +136,7 @@ def sdpa_forward( out, ) - # Evaluate the fused primitive for this layer. Required because + # Evaluate the paged SDPA primitive for this layer. Required because # copy_shared_buffer cache aliasing is not safe across a fully-lazy # 28-layer graph — MLX's buffer management may reorder or reuse # aliased buffers. Removing this eval is tracked as future work. 
From 9800d378877f74aa7f2e9bc4b59c7924d64f7097 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 3 Apr 2026 11:02:23 -0500 Subject: [PATCH 4/7] fix add_temporary bug Signed-off-by: ran --- vllm_metal/metal/paged_ops.cpp | 140 ++++++++---------- .../metal_kernel_backend/attention_sdpa.py | 54 +++---- 2 files changed, 93 insertions(+), 101 deletions(-) diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 297a1323..84f43e5f 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -76,10 +76,18 @@ static std::string dtype_to_metal(Dtype dt) { // reshape_and_cache — dispatch helper + eager binding // --------------------------------------------------------------------------- +// When called from a primitive's eval_gpu, from_primitive should be true +// to skip ALL add_temporary calls. add_temporary removes buffer pointers +// from the encoder's input/output tracking, defeating fence-based +// synchronisation across command buffer boundaries. Inside a primitive, +// MLX's evaluator already manages array lifetimes via the completion +// handler. In the eager path, add_temporary is needed to keep +// Python-owned arrays alive until the command buffer completes. 
static void dispatch_reshape_and_cache( const array& key, const array& value, array& key_cache, array& value_cache, - const array& slot_mapping, Stream s) { + const array& slot_mapping, Stream s, + bool from_primitive = false) { auto& d = metal::device(s.device); int num_tokens = static_cast(key.shape(0)); @@ -120,11 +128,13 @@ static void dispatch_reshape_and_cache( MTL::Size::Make(num_tokens, 1, 1), MTL::Size::Make(tpg, 1, 1)); - d.add_temporary(key, s.index); - d.add_temporary(value, s.index); - d.add_temporary(key_cache, s.index); - d.add_temporary(value_cache, s.index); - d.add_temporary(slot_mapping, s.index); + if (!from_primitive) { + d.add_temporary(key, s.index); + d.add_temporary(value, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(slot_mapping, s.index); + } } void reshape_and_cache_impl( @@ -259,7 +269,8 @@ static void dispatch_paged_attention_v2_online( int num_kv_heads, float scale, float softcap, const array& block_tables, const array& seq_lens, const array& cu_seqlens_q, - int block_size, int max_seq_len, int sliding_window, Stream s) { + int block_size, int max_seq_len, int sliding_window, Stream s, + bool from_primitive = false) { auto& d = metal::device(s.device); int total_q_tokens = static_cast(query.shape(0)); @@ -335,13 +346,15 @@ static void dispatch_paged_attention_v2_online( MTL::Size::Make(num_heads, total_q_tokens, 1), MTL::Size::Make(NUM_THREADS, 1, 1)); - d.add_temporary(out, s.index); - d.add_temporary(query, s.index); - d.add_temporary(key_cache, s.index); - d.add_temporary(value_cache, s.index); - d.add_temporary(block_tables, s.index); - d.add_temporary(seq_lens, s.index); - d.add_temporary(cu_seqlens_q, s.index); + if (!from_primitive) { + d.add_temporary(out, s.index); + d.add_temporary(query, s.index); + d.add_temporary(key_cache, s.index); + d.add_temporary(value_cache, s.index); + d.add_temporary(block_tables, s.index); + d.add_temporary(seq_lens, s.index); + 
d.add_temporary(cu_seqlens_q, s.index); + } } // Eager wrapper — keeps the old handle-based API for metal_unified_attention @@ -484,62 +497,47 @@ static void paged_attention_v2_online_impl_common( } // --------------------------------------------------------------------------- -// Fused reshape_and_cache + paged_attention_v2_online PRIMITIVE +// Paged attention primitive (read-only): paged_attention_v2_online only. // -// Three outputs: [updated_key_cache, updated_value_cache, attention_output] -// Both Metal kernel dispatches happen in the same eval_gpu → same command -// encoder → Metal guarantees sequential ordering (write before read). +// Single output: attention result. The KV cache is read-only — cache +// writes are handled upstream by MLX-native scatter (pure functional). +// This is a clean pure function: inputs → output, no side effects. // --------------------------------------------------------------------------- -class PagedSDPAPrimitive : public Primitive { +class PagedAttentionPrimitive : public UnaryPrimitive { public: - PagedSDPAPrimitive( + PagedAttentionPrimitive( Stream stream, int num_kv_heads, float scale, float softcap, int block_size, int max_seq_len, int sliding_window) - : Primitive(stream), + : UnaryPrimitive(stream), num_kv_heads_(num_kv_heads), scale_(scale), softcap_(softcap), block_size_(block_size), max_seq_len_(max_seq_len), sliding_window_(sliding_window) {} - void eval_cpu(const std::vector&, std::vector&) override { + void eval_cpu(const std::vector&, array&) override { throw std::runtime_error( - "PagedSDPAPrimitive only supports GPU"); + "PagedAttentionPrimitive only supports GPU"); } - void eval_gpu( - const std::vector& inputs, - std::vector& outputs) override { - // inputs: [query, key, value, key_cache, value_cache, slot_mapping, - // block_tables, seq_lens, cu_seqlens_q] - // outputs: [updated_key_cache, updated_value_cache, attention_output] - - // 1. 
Cache outputs alias input cache buffers (zero-copy) - outputs[0].copy_shared_buffer(inputs[3]); // key_cache - outputs[1].copy_shared_buffer(inputs[4]); // value_cache - - // 2. Dispatch reshape_and_cache — writes new K/V into cache - dispatch_reshape_and_cache( - inputs[1], inputs[2], // key, value - outputs[0], outputs[1], // cache buffers (aliased) - inputs[5], // slot_mapping - stream()); - - // 3. Dispatch attention — reads from cache (same command encoder!) - outputs[2].set_data(allocator::malloc(outputs[2].nbytes())); + void eval_gpu(const std::vector& inputs, array& out) override { + // inputs: [query, key_cache, value_cache, block_tables, seq_lens, + // cu_seqlens_q] + out.set_data(allocator::malloc(out.nbytes())); dispatch_paged_attention_v2_online( - outputs[2], // attention output + out, inputs[0], // query - outputs[0], outputs[1], // cache buffers (just written) + inputs[1], inputs[2], // key_cache, value_cache num_kv_heads_, scale_, softcap_, - inputs[6], inputs[7], inputs[8], // block_tables, seq_lens, cu_seqlens_q + inputs[3], inputs[4], inputs[5], // block_tables, seq_lens, cu_seqlens_q block_size_, max_seq_len_, sliding_window_, - stream()); + stream(), + /*from_primitive=*/true); } - const char* name() const override { return "FusedReshapeAndAttention"; } + const char* name() const override { return "PagedAttention"; } bool is_equivalent(const Primitive& other) const override { - auto* rhs = dynamic_cast(&other); + auto* rhs = dynamic_cast(&other); return rhs && rhs->num_kv_heads_ == num_kv_heads_ && rhs->scale_ == scale_ && rhs->softcap_ == softcap_ && rhs->block_size_ == block_size_ @@ -556,24 +554,20 @@ class PagedSDPAPrimitive : public Primitive { int sliding_window_; }; -static std::vector paged_sdpa_primitive_fn( - const array& query, const array& key, const array& value, +static array paged_attention_primitive_fn( + const array& query, const array& key_cache, const array& value_cache, - const array& slot_mapping, int num_kv_heads, float 
scale, float softcap, const array& block_tables, const array& seq_lens, const array& cu_seqlens_q, int block_size, int max_seq_len, int sliding_window) { - auto prim = std::make_shared( + auto prim = std::make_shared( default_stream(Device::gpu), num_kv_heads, scale, softcap, block_size, max_seq_len, sliding_window); - return array::make_arrays( - {key_cache.shape(), value_cache.shape(), query.shape()}, - {key_cache.dtype(), value_cache.dtype(), query.dtype()}, - std::move(prim), - {query, key, value, key_cache, value_cache, slot_mapping, - block_tables, seq_lens, cu_seqlens_q}); + return array( + query.shape(), query.dtype(), std::move(prim), + {query, key_cache, value_cache, block_tables, seq_lens, cu_seqlens_q}); } void paged_attention_v2_online_impl( @@ -706,42 +700,36 @@ NB_MODULE(_paged_ops, m) { "Online-softmax varlen paged attention (v2) with caller-provided " "partition scratch buffers."); - // Fused primitive: reshape_and_cache + paged_attention in one eval_gpu. + // Paged attention primitive (read-only): dispatches paged_attention_v2_online. + // Cache writes are handled by MLX-native scatter upstream. // Uses overwrite_descriptor to bypass cross-module nanobind RTTI. 
- m.def("paged_sdpa_primitive", - [](nb::handle query_h, nb::handle key_h, nb::handle value_h, + m.def("paged_attention_primitive", + [](nb::handle query_h, nb::handle key_cache_h, nb::handle value_cache_h, - nb::handle slot_mapping_h, int num_kv_heads, float scale, float softcap, nb::handle block_tables_h, nb::handle seq_lens_h, nb::handle cu_seqlens_q_h, int block_size, int max_seq_len, int sliding_window, - nb::handle out_k_h, nb::handle out_v_h, nb::handle out_attn_h) { - auto result = paged_sdpa_primitive_fn( + nb::handle out_h) { + auto result = paged_attention_primitive_fn( *nb::inst_ptr(query_h), - *nb::inst_ptr(key_h), - *nb::inst_ptr(value_h), *nb::inst_ptr(key_cache_h), *nb::inst_ptr(value_cache_h), - *nb::inst_ptr(slot_mapping_h), num_kv_heads, scale, softcap, *nb::inst_ptr(block_tables_h), *nb::inst_ptr(seq_lens_h), *nb::inst_ptr(cu_seqlens_q_h), block_size, max_seq_len, sliding_window); - nb::inst_ptr(out_k_h)->overwrite_descriptor(result[0]); - nb::inst_ptr(out_v_h)->overwrite_descriptor(result[1]); - nb::inst_ptr(out_attn_h)->overwrite_descriptor(result[2]); + nb::inst_ptr(out_h)->overwrite_descriptor(result); }, - nb::arg("query"), nb::arg("key"), nb::arg("value"), + nb::arg("query"), nb::arg("key_cache"), nb::arg("value_cache"), - nb::arg("slot_mapping"), nb::arg("num_kv_heads"), nb::arg("scale"), nb::arg("softcap"), nb::arg("block_tables"), nb::arg("seq_lens"), nb::arg("cu_seqlens_q"), nb::arg("block_size"), nb::arg("max_seq_len"), nb::arg("sliding_window"), - nb::arg("out_key"), nb::arg("out_value"), nb::arg("out_attn"), - "Paged SDPA primitive: reshape_and_cache + paged_attention_v2_online " - "in a single eval_gpu dispatch (same command encoder)."); + nb::arg("out"), + "Paged attention primitive (read-only). 
Cache writes are handled " + "by MLX-native scatter upstream."); } diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index d8af9817..7563a475 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -94,7 +94,7 @@ def sdpa_forward( k_3d = mx.contiguous(keys[0].transpose(1, 0, 2).astype(kv_cache.dtype)) v_3d = mx.contiguous(values[0].transpose(1, 0, 2).astype(kv_cache.dtype)) - slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) + slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int32) # Build block_tables and seq_lens from context max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) @@ -108,20 +108,36 @@ def sdpa_forward( ops = get_ops() max_seq_len = max(ctx.context_lens) - # Paged SDPA primitive: reshape_and_cache + paged_attention in one - # eval_gpu. Both Metal kernel dispatches go to the same command - # encoder, so Metal guarantees the cache write completes before the - # attention read. - updated_k_cache = mx.array(0) - updated_v_cache = mx.array(0) + # --- Cache write: MLX-native scatter (pure functional, graph-tracked) --- + # Flatten cache to [num_slots, num_kv_heads, head_dim], scatter new K/V + # by slot_mapping, then reshape back. This creates proper graph nodes + # that MLX's evaluator can track for dependency ordering and buffer + # donation — no in-place mutation, no copy_shared_buffer, no const_cast. 
+ flat_k = kv_cache.key_caches[layer_idx].reshape(-1, kv_cache.num_kv_heads, head_dim) + flat_k[slot_mapping] = k_3d + new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape) + + flat_v = kv_cache.value_caches[layer_idx].reshape( + -1, kv_cache.num_kv_heads, head_dim + ) + flat_v[slot_mapping] = v_3d + new_v_cache = flat_v.reshape(kv_cache.value_caches[layer_idx].shape) + + # Rebind so next layer / decode step uses the updated cache + kv_cache.key_caches[layer_idx] = new_k_cache + kv_cache.value_caches[layer_idx] = new_v_cache + + # --- Attention: paged attention primitive (read-only, fully lazy) --- + # No per-layer eval or sync. The primitive participates in MLX's lazy + # graph and is evaluated by the model runner at the end of the forward + # pass. Fence-based synchronisation across command buffer boundaries + # works correctly because eval_gpu skips add_temporary (which would + # remove buffers from the encoder's fence tracking). out = mx.array(0) - ops.paged_sdpa_primitive( + ops.paged_attention_primitive( q_3d, - k_3d, - v_3d, - kv_cache.key_caches[layer_idx], - kv_cache.value_caches[layer_idx], - slot_mapping, + new_k_cache, + new_v_cache, kv_cache.num_kv_heads, inner.scale, 0.0, # softcap (0 = disabled) @@ -131,21 +147,9 @@ def sdpa_forward( kv_cache.block_size, max_seq_len, -1, # sliding_window (-1 = disabled) - updated_k_cache, - updated_v_cache, out, ) - # Evaluate the paged SDPA primitive for this layer. Required because - # copy_shared_buffer cache aliasing is not safe across a fully-lazy - # 28-layer graph — MLX's buffer management may reorder or reuse - # aliased buffers. Removing this eval is tracked as future work. 
- mx.eval(updated_k_cache, updated_v_cache, out) - - # Rebind cache references for next layer / decode step - kv_cache.key_caches[layer_idx] = updated_k_cache - kv_cache.value_caches[layer_idx] = updated_v_cache - # output: (L, n_heads, head_dim) → (B, L, n_heads * head_dim) out = out.reshape(B, L, n_heads * head_dim) return inner.o_proj(out) From b531a39d5a81ba5a3c85da395c8b75342568b439 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 3 Apr 2026 12:39:18 -0500 Subject: [PATCH 5/7] cleanup unnecessary change Signed-off-by: ran --- vllm_metal/metal/paged_ops.cpp | 145 ++++++++++++------ .../metal_kernel_backend/attention_sdpa.py | 2 +- 2 files changed, 96 insertions(+), 51 deletions(-) diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 84f43e5f..851251d7 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -357,15 +357,28 @@ static void dispatch_paged_attention_v2_online( } } -// Eager wrapper — keeps the old handle-based API for metal_unified_attention -static void paged_attention_v2_online_impl_common( - nb::handle out_h, nb::handle query_h, - nb::handle key_cache_h, nb::handle value_cache_h, - int num_kv_heads, float scale, float softcap, - nb::handle block_tables_h, nb::handle seq_lens_h, +// Eager wrapper — keeps the old handle-based API for metal_unified_attention. +// Non-partitioned case delegates to the dispatch helper above; +// partitioned case is handled inline (same as original code on main). 
+void paged_attention_v2_online_impl_common( + nb::handle out_h, + nb::handle query_h, + nb::handle key_cache_h, + nb::handle value_cache_h, + int num_kv_heads, + float scale, + float softcap, + nb::handle block_tables_h, + nb::handle seq_lens_h, nb::handle cu_seqlens_q_h, - int block_size, int max_seq_len, int sliding_window, - array* exp_sums, array* max_logits, array* tmp_out, array* sinks) { + int block_size, + int max_seq_len, + int sliding_window, + array* exp_sums, + array* max_logits, + array* tmp_out, + array* sinks +) { auto& out = *nb::inst_ptr(out_h); auto& query = *nb::inst_ptr(query_h); auto& key_cache = *nb::inst_ptr(key_cache_h); @@ -374,7 +387,7 @@ static void paged_attention_v2_online_impl_common( auto& seq_lens = *nb::inst_ptr(seq_lens_h); auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); - // For the simple non-partitioned case, delegate to the dispatch helper + // Non-partitioned case: delegate to the shared dispatch helper bool needs_partitioning = exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr; if (!needs_partitioning) { @@ -387,9 +400,10 @@ static void paged_attention_v2_online_impl_common( return; } - // Partitioned path — inline here since the fused primitive doesn't use it + // Partitioned path (unchanged from main) auto s = default_stream(Device::gpu); - auto& d = metal::device(s.device); + auto& d = metal::device(Device::gpu); + int total_q_tokens = static_cast(query.shape(0)); int num_heads = static_cast(query.shape(1)); int head_size = static_cast(query.shape(2)); @@ -408,21 +422,26 @@ static void paged_attention_v2_online_impl_common( "_nt256_nsl32_ps" + std::to_string(use_partitioning ? 
kPartitionSize : 0); - bool use_alibi = false, use_fp8 = false; - bool use_sinks_flag = sinks != nullptr; + bool use_alibi = false; + bool use_fp8 = false; + bool use_sinks = sinks != nullptr; + auto* lib = d.get_library("paged_attention_v2_kern"); auto* kernel = d.get_kernel( kname, lib, kname + "_v2", {{&use_partitioning, MTL::DataType::DataTypeBool, NS::UInteger(10)}, {&use_alibi, MTL::DataType::DataTypeBool, NS::UInteger(20)}, {&use_fp8, MTL::DataType::DataTypeBool, NS::UInteger(30)}, - {&use_sinks_flag, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + {&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); - constexpr int NUM_THREADS = 256, NUM_SIMD_LANES = 32; - constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; - int warp_scores_bytes = NUM_WARPS * block_size * (int)sizeof(float); - int merge_bytes = (2 * NUM_WARPS + NUM_WARPS * head_size) * (int)sizeof(float); - size_t shmem = (size_t)std::max(warp_scores_bytes, merge_bytes); + constexpr int NUM_THREADS = 256; + constexpr int NUM_SIMD_LANES = 32; + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + int warp_scores_bytes = NUM_WARPS * block_size + * static_cast(sizeof(float)); + int merge_bytes = (2 * NUM_WARPS + NUM_WARPS * head_size) + * static_cast(sizeof(float)); + size_t shmem = static_cast(std::max(warp_scores_bytes, merge_bytes)); auto& enc = d.get_command_encoder(s.index); enc.set_compute_pipeline_state(kernel); @@ -435,47 +454,71 @@ static void paged_attention_v2_online_impl_common( } else { enc.set_output_array(out, 2); } - enc.set_input_array(query, 3); - enc.set_input_array(key_cache, 4); - enc.set_input_array(value_cache, 5); - int32_t nkv = (int32_t)num_kv_heads; - enc.set_bytes(nkv, 8); + enc.set_input_array(query, 3); + enc.set_input_array(key_cache, 4); + enc.set_input_array(value_cache, 5); + + int32_t nkv = static_cast(num_kv_heads); + enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - enc.set_bytes(softcap, 10); + float softcapping = softcap; + 
enc.set_bytes(softcapping, 10); + enc.set_input_array(block_tables, 11); - enc.set_input_array(seq_lens, 12); - enc.set_bytes((int32_t)max_blocks, 13); - enc.set_bytes((int32_t)(num_heads * head_size), 15); - enc.set_bytes((int32_t)key_cache.strides()[0], 16); - enc.set_bytes((int32_t)key_cache.strides()[2], 17); - if (use_sinks_flag) enc.set_input_array(*sinks, 18); + enc.set_input_array(seq_lens, 12); + + int32_t max_blocks_i = static_cast(max_blocks); + enc.set_bytes(max_blocks_i, 13); + + int32_t q_stride = static_cast(num_heads * head_size); + int32_t kv_block_stride = static_cast(key_cache.strides()[0]); + int32_t kv_head_stride = static_cast(key_cache.strides()[2]); + enc.set_bytes(q_stride, 15); + enc.set_bytes(kv_block_stride, 16); + enc.set_bytes(kv_head_stride, 17); + if (use_sinks) { + enc.set_input_array(*sinks, 18); + } + enc.set_input_array(cu_seqlens_q, 19); - enc.set_bytes((int32_t)num_seqs, 20); - enc.set_bytes((int32_t)sliding_window, 21); + int32_t num_seqs_i = static_cast(num_seqs); + enc.set_bytes(num_seqs_i, 20); + int32_t sliding_window_i = static_cast(sliding_window); + enc.set_bytes(sliding_window_i, 21); + const int32_t grid_z = + static_cast(use_partitioning ? max_num_partitions : 1); enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, total_q_tokens, - use_partitioning ? 
max_num_partitions : 1), + MTL::Size::Make(num_heads, total_q_tokens, grid_z), MTL::Size::Make(NUM_THREADS, 1, 1)); if (use_partitioning) { - std::string rk = "paged_attention_v2_reduce_" + dt + + std::string reduce_kname = + "paged_attention_v2_reduce_" + dt + "_hs" + std::to_string(head_size) + "_nt256_nsl32_ps" + std::to_string(kPartitionSize); - auto* rkernel = d.get_kernel(rk, lib, rk + "_v2_reduce", - {{&use_sinks_flag, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); - enc.set_compute_pipeline_state(rkernel); - enc.set_threadgroup_memory_length( - (size_t)(2 * max_num_partitions * sizeof(float)), 0); - enc.set_output_array(out, 0); - enc.set_input_array(*exp_sums, 1); - enc.set_input_array(*max_logits, 2); - enc.set_input_array(*tmp_out, 3); - enc.set_input_array(seq_lens, 4); - enc.set_bytes((int32_t)max_num_partitions, 5); - if (use_sinks_flag) enc.set_input_array(*sinks, 6); + auto* reduce_kernel = d.get_kernel( + reduce_kname, + lib, + reduce_kname + "_v2_reduce", + {{&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + size_t reduce_shmem = + static_cast(2 * max_num_partitions * sizeof(float)); + enc.set_compute_pipeline_state(reduce_kernel); + enc.set_threadgroup_memory_length(reduce_shmem, 0); + + enc.set_output_array(out, 0); + enc.set_input_array(*exp_sums, 1); + enc.set_input_array(*max_logits, 2); + enc.set_input_array(*tmp_out, 3); + enc.set_input_array(seq_lens, 4); + int32_t max_num_partitions_i = static_cast(max_num_partitions); + enc.set_bytes(max_num_partitions_i, 5); + if (use_sinks) { + enc.set_input_array(*sinks, 6); + } enc.set_input_array(cu_seqlens_q, 7); - enc.set_bytes((int32_t)num_seqs, 8); + enc.set_bytes(num_seqs_i, 8); enc.dispatch_threadgroups( MTL::Size::Make(num_heads, total_q_tokens, 1), MTL::Size::Make(NUM_THREADS, 1, 1)); @@ -493,7 +536,9 @@ static void paged_attention_v2_online_impl_common( d.add_temporary(*max_logits, s.index); d.add_temporary(*tmp_out, s.index); } - if (use_sinks_flag) d.add_temporary(*sinks, 
s.index); + if (use_sinks) { + d.add_temporary(*sinks, s.index); + } } // --------------------------------------------------------------------------- diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index 7563a475..db80f158 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -94,7 +94,7 @@ def sdpa_forward( k_3d = mx.contiguous(keys[0].transpose(1, 0, 2).astype(kv_cache.dtype)) v_3d = mx.contiguous(values[0].transpose(1, 0, 2).astype(kv_cache.dtype)) - slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int32) + slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) # Build block_tables and seq_lens from context max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) From eda8ffbb02087dcadba2065d39a97b837e950e20 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 6 Apr 2026 01:24:31 -0500 Subject: [PATCH 6/7] add donation test Signed-off-by: ran --- tests/test_primitive_and_donation.py | 296 +++++++++++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 tests/test_primitive_and_donation.py diff --git a/tests/test_primitive_and_donation.py b/tests/test_primitive_and_donation.py new file mode 100644 index 00000000..90c85e93 --- /dev/null +++ b/tests/test_primitive_and_donation.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the paged attention primitive and scatter-based cache donation. + +Covers two of Eric's review concerns on PR #225: + 1. No test coverage for the primitive path (paged_attention_primitive). + 2. Buffer donation may silently fail, making scatter-based cache writes + O(entire_cache) instead of O(new_tokens). 
+ +Run with: + python -m pytest tests/test_primitive_and_donation.py -v -s +""" + +from __future__ import annotations + +import mlx.core as mx +import numpy as np +import pytest + +from tools.attention_bench_utils import ref_paged_attn +from vllm_metal.metal import get_ops +from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache + +# ── Shared fixtures ────────────────────────────────────────────────────────── + +NUM_KV_HEADS_CASES = [2, 4] +NUM_QUERY_HEADS_CASES = [4, 8] # must be divisible by corresponding kv heads +HEAD_SIZE = 128 +BLOCK_SIZE = 16 +DTYPE = mx.float16 + + +def _make_cache_and_inputs( + num_blocks: int, + num_kv_heads: int, + num_query_heads: int, + seq_lens: list[tuple[int, int]], + *, + dtype: mx.Dtype = DTYPE, +): + """Build a populated cache and matching query/metadata tensors.""" + block_size = BLOCK_SIZE + head_size = HEAD_SIZE + num_seqs = len(seq_lens) + query_lens = [s[0] for s in seq_lens] + kv_lens = [s[1] for s in seq_lens] + total_q = sum(query_lens) + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + key_cache = mx.random.normal( + shape=(num_blocks, block_size, num_kv_heads, head_size) + ).astype(dtype) + value_cache = mx.random.normal( + shape=(num_blocks, block_size, num_kv_heads, head_size) + ).astype(dtype) + query = mx.random.normal(shape=(total_q, num_query_heads, head_size)).astype(dtype) + + max_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = mx.random.randint( + 0, num_blocks, shape=(num_seqs, max_blocks_per_seq) + ).astype(mx.int32) + + kv_lens_arr = mx.array(kv_lens, dtype=mx.int32) + cu_seqlens_q = mx.cumsum(mx.array([0] + query_lens, dtype=mx.int32)) + + mx.eval(key_cache, value_cache, query, block_tables, kv_lens_arr, cu_seqlens_q) + + return { + "query": query, + "key_cache": key_cache, + "value_cache": value_cache, + "num_kv_heads": num_kv_heads, + "scale": scale, + "block_tables": block_tables, + "kv_lens_arr": kv_lens_arr, + "cu_seqlens_q": cu_seqlens_q, + 
"query_lens": query_lens, + "kv_lens": kv_lens, + "max_kv_len": max_kv_len, + } + + +# ── 1. Primitive correctness tests ────────────────────────────────────────── + + +@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 523), (1, 37), (1, 2011)], + [(1, 1), (1, 128), (1, 2048)], + ], +) +@pytest.mark.parametrize( + "num_heads", + [(4, 4), (8, 2)], +) +@pytest.mark.parametrize("num_blocks", [256]) +def test_primitive_vs_reference_decode( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + num_blocks: int, +) -> None: + """paged_attention_primitive matches the pure-MLX reference (decode).""" + mx.random.seed(0) + num_query_heads, num_kv_heads = num_heads + d = _make_cache_and_inputs(num_blocks, num_kv_heads, num_query_heads, seq_lens) + + ops = get_ops() + out = mx.array(0) + ops.paged_attention_primitive( + d["query"], + d["key_cache"], + d["value_cache"], + d["num_kv_heads"], + d["scale"], + 0.0, # softcap + d["block_tables"], + d["kv_lens_arr"], + d["cu_seqlens_q"], + BLOCK_SIZE, + d["max_kv_len"], + -1, # sliding_window + out, + ) + mx.eval(out) + + ref = ref_paged_attn( + query=d["query"], + key_cache=d["key_cache"], + value_cache=d["value_cache"], + query_lens=d["query_lens"], + kv_lens=d["kv_lens"], + block_tables=np.array(d["block_tables"]), + scale=d["scale"], + ) + mx.eval(ref) + + np.testing.assert_allclose( + np.array(out), + np.array(ref), + atol=1.5e-2, + rtol=1e-2, + ) + + +@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 1328), (5, 18), (129, 463)], + [(1, 523), (1, 37), (1, 2011)], + ], +) +@pytest.mark.parametrize( + "num_heads", + [(4, 4), (8, 2)], +) +@pytest.mark.parametrize("num_blocks", [256]) +def test_primitive_vs_reference_varlen( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + num_blocks: int, +) -> None: + """paged_attention_primitive matches reference for mixed prefill+decode.""" + mx.random.seed(0) + num_query_heads, num_kv_heads = num_heads + d = _make_cache_and_inputs(num_blocks, num_kv_heads, 
num_query_heads, seq_lens) + + ops = get_ops() + out = mx.array(0) + ops.paged_attention_primitive( + d["query"], + d["key_cache"], + d["value_cache"], + d["num_kv_heads"], + d["scale"], + 0.0, # softcap + d["block_tables"], + d["kv_lens_arr"], + d["cu_seqlens_q"], + BLOCK_SIZE, + d["max_kv_len"], + -1, # sliding_window + out, + ) + mx.eval(out) + + ref = ref_paged_attn( + query=d["query"], + key_cache=d["key_cache"], + value_cache=d["value_cache"], + query_lens=d["query_lens"], + kv_lens=d["kv_lens"], + block_tables=np.array(d["block_tables"]), + scale=d["scale"], + ) + mx.eval(ref) + + np.testing.assert_allclose( + np.array(out), + np.array(ref), + atol=1.5e-2, + rtol=1e-2, + ) + + +# ── 2. Buffer donation test ───────────────────────────────────────────────── + + +@pytest.mark.parametrize("num_blocks", [128, 256]) +def test_scatter_cache_donation(num_blocks: int) -> None: + """Verify that scatter-based cache write reuses the buffer (donation). + + MLX's buffer donation is an optimisation, not a contract. This test + acts as an early-warning: if donation stops happening (due to MLX + changes or unexpected reference leaks), the memory delta will spike + and the assertion will fail. 
+ """ + num_kv_heads = 4 + head_dim = HEAD_SIZE + num_layers = 1 + block_size = BLOCK_SIZE + dtype = DTYPE + + cache = MetalPagedKVCache( + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + num_blocks=num_blocks, + block_size=block_size, + dtype=dtype, + ) + # Already eval'd by MetalPagedKVCache.__init__ + + cache_nbytes = cache.key_caches[0].nbytes # size of one K or V cache + + # Simulate a decode step: scatter 3 tokens into random slots + num_tokens = 3 + slot_indices = mx.array( + np.random.choice(num_blocks * block_size, size=num_tokens, replace=False), + dtype=mx.int64, + ) + new_k = mx.random.normal(shape=(num_tokens, num_kv_heads, head_dim)).astype(dtype) + new_v = mx.random.normal(shape=(num_tokens, num_kv_heads, head_dim)).astype(dtype) + mx.eval(slot_indices, new_k, new_v) + + # Warm up: run multiple rounds so the MLX memory pool stabilises. + # IMPORTANT: delete all locals afterwards — any stray reference to the + # old cache array bumps use_count and defeats buffer donation. 
+ for _ in range(5): + _fk = cache.key_caches[0].reshape(-1, num_kv_heads, head_dim) + _fk[slot_indices] = new_k + _wk = _fk.reshape(cache.key_caches[0].shape) + cache.key_caches[0] = _wk + _fv = cache.value_caches[0].reshape(-1, num_kv_heads, head_dim) + _fv[slot_indices] = new_v + _wv = _fv.reshape(cache.value_caches[0].shape) + cache.value_caches[0] = _wv + mx.eval(cache.key_caches[0], cache.value_caches[0]) + del _fk, _wk, _fv, _wv + + # ── Measure at steady state (average of several rounds) ── + total_delta = 0 + num_rounds = 5 + for _ in range(num_rounds): + mem_before = mx.get_active_memory() + + # K cache scatter + rebind + flat_k = cache.key_caches[0].reshape(-1, num_kv_heads, head_dim) + flat_k[slot_indices] = new_k + new_k_cache = flat_k.reshape(cache.key_caches[0].shape) + cache.key_caches[0] = new_k_cache + + # V cache scatter + rebind + flat_v = cache.value_caches[0].reshape(-1, num_kv_heads, head_dim) + flat_v[slot_indices] = new_v + new_v_cache = flat_v.reshape(cache.value_caches[0].shape) + cache.value_caches[0] = new_v_cache + + mx.eval(cache.key_caches[0], cache.value_caches[0]) + + mem_after = mx.get_active_memory() + total_delta += mem_after - mem_before + del flat_k, new_k_cache, flat_v, new_v_cache + + avg_delta = total_delta / num_rounds + + # If donation works: avg_delta ≈ 0 (buffers reused in-place). + # If donation fails: avg_delta ≈ 2 * cache_nbytes (full copy for K + V). + # Allow generous headroom (one full cache) for pool fluctuations. + threshold = cache_nbytes + assert avg_delta < threshold, ( + f"Buffer donation likely failed: avg memory growth {avg_delta:,.0f} " + f"bytes/round over {num_rounds} rounds, " + f"but each cache is only {cache_nbytes:,} bytes. " + f"Expected near-zero growth with donation." 
+ ) From 9848cf65de97cd118db5591ee201a815daaddd97 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 6 Apr 2026 01:29:26 -0500 Subject: [PATCH 7/7] add donation comments Signed-off-by: ran --- vllm_metal/metal_kernel_backend/attention_sdpa.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py index db80f158..a9f2bf05 100644 --- a/vllm_metal/metal_kernel_backend/attention_sdpa.py +++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py @@ -113,6 +113,12 @@ def sdpa_forward( # by slot_mapping, then reshape back. This creates proper graph nodes # that MLX's evaluator can track for dependency ordering and buffer # donation — no in-place mutation, no copy_shared_buffer, no const_cast. + # + # DONATION INVARIANT: the rebind (below) must drop the list's reference + # to the old cache *before* mx.eval runs. At eval time the old cache + # must have use_count == 1 (only the graph) for MLX to donate its + # buffer to the scatter output. Do NOT insert mx.eval between the + # scatter and the rebind, or hold extra references to the old cache. flat_k = kv_cache.key_caches[layer_idx].reshape(-1, kv_cache.num_kv_heads, head_dim) flat_k[slot_mapping] = k_3d new_k_cache = flat_k.reshape(kv_cache.key_caches[layer_idx].shape)