From 6e5a957a5e991547a570095cdf191b2a84f28c66 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 16 Mar 2026 15:53:46 -0500 Subject: [PATCH 1/5] varlen prototype Signed-off-by: ran --- tests/test_metal_unified_attention.py | 6 -- vllm_metal/metal/__init__.py | 36 +++++----- .../metal/kernels_v2/pagedattention.metal | 71 ++++++++++++++++--- vllm_metal/metal/paged_ops.cpp | 31 ++++++-- 4 files changed, 105 insertions(+), 39 deletions(-) diff --git a/tests/test_metal_unified_attention.py b/tests/test_metal_unified_attention.py index 4a3187c5..f40fcfe8 100644 --- a/tests/test_metal_unified_attention.py +++ b/tests/test_metal_unified_attention.py @@ -357,12 +357,6 @@ def test_metal_unified_attn( query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] - # xfail cases that need features not yet in the v2 kernel: - # varlen (q_len > 1), sliding window, or soft capping. - # Decode-only cases with no extras already work and should pass. - max_query_len_val = max(query_lens) - if max_query_len_val > 1 or sliding_window is not None or soft_cap is not None: - pytest.xfail("v2 varlen/sliding-window/soft-cap not yet implemented") num_query_heads = num_heads[0] num_kv_heads = num_heads[1] assert num_query_heads % num_kv_heads == 0 diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py index 92fdf222..7badfa8e 100644 --- a/vllm_metal/metal/__init__.py +++ b/vllm_metal/metal/__init__.py @@ -87,35 +87,32 @@ def metal_unified_attention( ) -> None: """Unified varlen paged attention for Metal. - Currently supports decode-only (max_seqlen_q=1). Sliding window and - soft capping are not yet supported. These will be enabled when the v2 - kernel is extended to handle variable-length queries (prefill + decode). + Supports variable-length queries (prefill + decode) with online softmax, + paged KV cache, causal masking, sliding window, and soft capping. + + Grid: one threadgroup per (head, query_token). Each threadgroup uses + binary search on cu_seqlens_q to find its sequence and computes causal + attention against the paged KV cache. """ import mlx.core as mx - if max_seqlen_q != 1: - raise NotImplementedError( - f"metal_unified_attention only supports decode (max_seqlen_q=1), " - f"got {max_seqlen_q}" - ) - if window_size != (-1, -1): - raise NotImplementedError( - f"Sliding window not yet supported, got window_size={window_size}" - ) - if softcap != 0: - raise NotImplementedError( - f"Soft capping not yet supported, got softcap={softcap}" - ) - # Extract dimensions from cache shape # k shape: [num_blocks, block_size, num_kv_heads, head_size] num_kv_heads = k.shape[2] block_size = k.shape[1] + # Convert window_size tuple to a single sliding_window int. + # window_size = (left, right) where left = sw-1, right = 0 for causal. + # sliding_window = left + 1 = total window size. -1 = disabled. + if window_size == (-1, -1): + sliding_window = -1 + else: + sliding_window = window_size[0] + 1 + ops = get_ops() # Ensure all inputs are evaluated before raw Metal dispatch - mx.eval(out, q, k, v, block_table, seqused_k) + mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q) ops.paged_attention_v2_online( out, @@ -124,10 +121,13 @@ def metal_unified_attention( v, num_kv_heads, softmax_scale, + softcap, block_table, seqused_k, + cu_seqlens_q, block_size, max_seqlen_k, + sliding_window, ) mx.synchronize() diff --git a/vllm_metal/metal/kernels_v2/pagedattention.metal b/vllm_metal/metal/kernels_v2/pagedattention.metal index 48eea53d..5d81d5b2 100644 --- a/vllm_metal/metal/kernels_v2/pagedattention.metal +++ b/vllm_metal/metal/kernels_v2/pagedattention.metal @@ -756,6 +756,34 @@ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +// Binary search to find which sequence a global query token belongs to. +// +// In varlen (ragged-batch) attention, queries from multiple sequences are +// packed contiguously into a flat array: +// q[0..q_len_0-1] → seq 0, q[q_len_0..q_len_0+q_len_1-1] → seq 1, ... +// The kernel launches one threadgroup per (head, query_token) in a flat grid. +// Each threadgroup needs to discover which sequence it belongs to so it can +// look up the correct block_table row, kv_len, and causal mask boundary. +// +// This is the same approach used by the upstream vLLM unified Triton kernel +// (triton_unified_attention.py:find_seq_idx) and FlashAttention's varlen API. +// +// cu_seqlens_q is sorted ascending: [0, q_len_0, q_len_0+q_len_1, ...]. +// Returns seq_idx such that cu_seqlens_q[seq_idx] <= q_token_idx < cu_seqlens_q[seq_idx+1]. +inline int find_seq_idx(const device int32_t *cu_seqlens_q, + int q_token_idx, int num_seqs) { + int lo = 0, hi = num_seqs; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (cu_seqlens_q[mid] <= q_token_idx) { + lo = mid; + } else { + hi = mid - 1; + } + } + return lo; +} + constant bool use_partitioning [[function_constant(10)]]; constant bool use_alibi [[function_constant(20)]]; constant bool use_fp8_scales [[function_constant(30)]]; @@ -795,24 +823,41 @@ template 0; - const uint32_t context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const uint32_t context_len = context_lens[seq_idx]; // total KV length for this seq + + // Causal: this query token can attend to KV positions [0, effective_context_len). + const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1; + if (effective_context_len <= 0) { + // No KV tokens to attend to — output zeros (handled naturally by loop not executing). + return; + } + + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= effective_context_len) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_context_blocks = DIVIDE_ROUND_UP(effective_context_len, BLOCK_SIZE); const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; @@ -867,7 +912,7 @@ template = context_len; + // Causal mask: only attend to KV positions < effective_context_len. + bool mask = token_idx >= effective_context_len; + // Sliding window mask: skip positions too far in the past. + if (sliding_window >= 0) { + mask = mask || (token_idx < effective_context_len - sliding_window); + } warp_scores[physical_block_offset] = mask ? -FLT_MAX : qk; } } @@ -981,7 +1031,7 @@ template (out_h); auto& query = *nb::inst_ptr(query_h); @@ -276,14 +279,17 @@ void paged_attention_v2_online_impl( auto& value_cache = *nb::inst_ptr(value_cache_h); auto& block_tables = *nb::inst_ptr(block_tables_h); auto& seq_lens = *nb::inst_ptr(seq_lens_h); + auto& cu_seqlens_q = *nb::inst_ptr(cu_seqlens_q_h); auto s = default_stream(Device::gpu); auto& d = metal::device(Device::gpu); - int num_seqs = static_cast(query.shape(0)); + // Varlen: query shape is [total_q_tokens, num_heads, head_size] + int total_q_tokens = static_cast(query.shape(0)); int num_heads = static_cast(query.shape(1)); int head_size = static_cast(query.shape(2)); int max_blocks = static_cast(block_tables.shape(1)); + int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; // Same kernel name format as v1 — the template instantiation is identical. auto dt = dtype_to_metal(query.dtype()); @@ -325,7 +331,7 @@ void paged_attention_v2_online_impl( enc.set_compute_pipeline_state(kernel); enc.set_threadgroup_memory_length(shmem, 0); - // Buffer bindings — identical to v1. + // Buffer bindings enc.set_output_array(out, 2); enc.set_input_array(query, 3); enc.set_input_array(key_cache, 4); @@ -334,7 +340,8 @@ void paged_attention_v2_online_impl( int32_t nkv = static_cast(num_kv_heads); enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - float softcapping = 1.0f; + // softcap=0 means disabled; the kernel uses 1.0 as the disabled sentinel. + float softcapping = (softcap == 0.f) ? 1.0f : softcap; enc.set_bytes(softcapping, 10); enc.set_input_array(block_tables, 11); @@ -350,8 +357,16 @@ void paged_attention_v2_online_impl( enc.set_bytes(kv_block_stride, 16); enc.set_bytes(kv_head_stride, 17); + // Varlen buffers (new in v2) + enc.set_input_array(cu_seqlens_q, 19); + int32_t num_seqs_i = static_cast(num_seqs); + enc.set_bytes(num_seqs_i, 20); + int32_t sliding_window_i = static_cast(sliding_window); + enc.set_bytes(sliding_window_i, 21); + + // Grid: one threadgroup per (head, query_token) enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, num_seqs, 1), + MTL::Size::Make(num_heads, total_q_tokens, 1), MTL::Size::Make(NUM_THREADS, 1, 1)); d.add_temporary(out, s.index); @@ -360,6 +375,7 @@ void paged_attention_v2_online_impl( d.add_temporary(value_cache, s.index); d.add_temporary(block_tables, s.index); d.add_temporary(seq_lens, s.index); + d.add_temporary(cu_seqlens_q, s.index); } // --------------------------------------------------------------------------- @@ -393,7 +409,10 @@ NB_MODULE(_paged_ops, m) { nb::arg("out"), nb::arg("query"), nb::arg("key_cache"), nb::arg("value_cache"), nb::arg("num_kv_heads"), nb::arg("scale"), + nb::arg("softcap"), nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("cu_seqlens_q"), nb::arg("block_size"), nb::arg("max_seq_len"), - "Online-softmax paged attention (v2, decode-only)."); + nb::arg("sliding_window"), + "Online-softmax varlen paged attention (v2, unified prefill+decode)."); } From 32d27c4decf92fbceb4ab54a1e298561d40c29e0 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 16 Mar 2026 16:07:45 -0500 Subject: [PATCH 2/5] update fun signature, to align with full unified attention Signed-off-by: ran --- vllm_metal/metal_kernel_backend/paged_attention.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index a590d495..0c040ffa 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -226,7 +226,11 @@ def _metal_kernel_decode_attention( max_seq_len = max(ctx.context_lens) scale = attn_module.scale - # Zero-copy paged attention (v2, online softmax) + # Build cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence. + cu_seqlens_q = mx.array(list(range(B + 1)), dtype=mx.int32) + mx.eval(cu_seqlens_q) + + # Zero-copy paged attention (v2, online softmax, varlen-capable) ops.paged_attention_v2_online( out, q_3d, @@ -234,10 +238,13 @@ def _metal_kernel_decode_attention( cache.value_caches[layer_idx], cache.num_kv_heads, scale, + 0.0, # softcap (0 = disabled) block_tables, seq_lens, + cu_seqlens_q, cache.block_size, max_seq_len, + -1, # sliding_window (-1 = disabled) ) # Synchronize GPU: paged_attention_v2_online wrote to out's buffer via a raw From 56fd66580e817d1a93474640bfedc58ea589dea9 Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 16 Mar 2026 23:28:46 -0500 Subject: [PATCH 3/5] fix two nits Signed-off-by: ran --- vllm_metal/metal/kernels_v2/pagedattention.metal | 2 +- vllm_metal/metal/paged_ops.cpp | 4 ++-- vllm_metal/metal_kernel_backend/paged_attention.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_metal/metal/kernels_v2/pagedattention.metal b/vllm_metal/metal/kernels_v2/pagedattention.metal index 5d81d5b2..bfb598cf 100644 --- a/vllm_metal/metal/kernels_v2/pagedattention.metal +++ b/vllm_metal/metal/kernels_v2/pagedattention.metal @@ -1000,7 +1000,7 @@ template ::dot( q_vecs[thread_group_offset], k_vecs); - if (softcapping != 1.0) { + if (softcapping > 0.0f) { qk = tanh(qk / softcapping) * softcapping; } diff --git a/vllm_metal/metal/paged_ops.cpp b/vllm_metal/metal/paged_ops.cpp index 21d330aa..07203f70 100644 --- a/vllm_metal/metal/paged_ops.cpp +++ b/vllm_metal/metal/paged_ops.cpp @@ -340,8 +340,8 @@ void paged_attention_v2_online_impl( int32_t nkv = static_cast(num_kv_heads); enc.set_bytes(nkv, 8); enc.set_bytes(scale, 9); - // softcap=0 means disabled; the kernel uses 1.0 as the disabled sentinel. - float softcapping = (softcap == 0.f) ? 1.0f : softcap; + // softcap: 0.0 = disabled, >0 = enabled. Passed through to kernel as-is. + float softcapping = softcap; enc.set_bytes(softcapping, 10); enc.set_input_array(block_tables, 11); diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 0c040ffa..fe0f6d09 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -227,7 +227,7 @@ def _metal_kernel_decode_attention( scale = attn_module.scale # Build cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence. - cu_seqlens_q = mx.array(list(range(B + 1)), dtype=mx.int32) + cu_seqlens_q = mx.arange(B + 1, dtype=mx.int32) mx.eval(cu_seqlens_q) # Zero-copy paged attention (v2, online softmax, varlen-capable) From b780e80f2445ed3976930cefc08d91d8b4c9c2ac Mon Sep 17 00:00:00 2001 From: ran Date: Mon, 16 Mar 2026 23:50:07 -0500 Subject: [PATCH 4/5] blindly fixed partition Signed-off-by: ran --- vllm_metal/metal/__init__.py | 1 + .../metal/kernels_v2/pagedattention.metal | 31 +++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py index 7badfa8e..29f4c18f 100644 --- a/vllm_metal/metal/__init__.py +++ b/vllm_metal/metal/__init__.py @@ -94,6 +94,7 @@ def metal_unified_attention( binary search on cu_seqlens_q to find its sequence and computes causal attention against the paged KV cache. """ + assert causal, "Only causal attention is supported" import mlx.core as mx # Extract dimensions from cache shape diff --git a/vllm_metal/metal/kernels_v2/pagedattention.metal b/vllm_metal/metal/kernels_v2/pagedattention.metal index bfb598cf..9954b31b 100644 --- a/vllm_metal/metal/kernels_v2/pagedattention.metal +++ b/vllm_metal/metal/kernels_v2/pagedattention.metal @@ -1108,13 +1108,14 @@ template (shared_mem); const device float *max_logits_ptr = - max_logits + seq_idx * num_heads * max_num_partitions + + max_logits + q_token_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = thread_position_in_threadgroup.x; i < num_partitions; @@ -1292,7 +1301,7 @@ template ( shared_mem + sizeof(float) * num_partitions); const device float *exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + + q_token_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = thread_position_in_threadgroup.x; i < num_partitions; @@ -1315,10 +1324,10 @@ template Date: Tue, 17 Mar 2026 00:14:15 -0500 Subject: [PATCH 5/5] add oneline comment Signed-off-by: ran --- vllm_metal/metal/kernels_v2/pagedattention.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_metal/metal/kernels_v2/pagedattention.metal b/vllm_metal/metal/kernels_v2/pagedattention.metal index 9954b31b..b0d76206 100644 --- a/vllm_metal/metal/kernels_v2/pagedattention.metal +++ b/vllm_metal/metal/kernels_v2/pagedattention.metal @@ -848,7 +848,7 @@ template