From 3d6d683f05a5b1415298e4413cf4c0f749255522 Mon Sep 17 00:00:00 2001 From: kingwl Date: Thu, 19 Mar 2026 22:08:15 +0800 Subject: [PATCH] Support partitioned Metal attention Signed-off-by: kingwl --- tests/test_metal_unified_attention.py | 11 +- vllm_metal/metal/__init__.py | 75 ++++++-- vllm_metal/metal/build.py | 13 +- vllm_metal/metal/constants.py | 5 + .../metal/kernels_v1/pagedattention.metal | 10 +- .../metal/kernels_v2/pagedattention.metal | 36 ++-- vllm_metal/metal/paged_ops.cpp | 175 +++++++++++++++++- 7 files changed, 276 insertions(+), 49 deletions(-) create mode 100644 vllm_metal/metal/constants.py diff --git a/tests/test_metal_unified_attention.py b/tests/test_metal_unified_attention.py index e34d8115..8878903c 100644 --- a/tests/test_metal_unified_attention.py +++ b/tests/test_metal_unified_attention.py @@ -13,7 +13,7 @@ import pytest from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention -from vllm_metal.metal import metal_unified_attention +from vllm_metal.metal import PARTITION_THRESHOLD, metal_unified_attention # Original upstream parameters (vLLM Triton/CUDA test_triton_unified_attention.py): # HEAD_SIZES = [128, 256] @@ -220,9 +220,14 @@ def test_metal_unified_attn_decode_only( softcap=0, ) - # v2 must match v1 exactly (same algorithm, same precision) + # Partitioned decode changes the reduction order versus v1, so long-context + # cases need a slightly looser tolerance than the no-partition exact-match + # path. + atol = rtol = 1e-4 + if max_kv_len >= PARTITION_THRESHOLD: + atol = rtol = 3e-4 np.testing.assert_allclose( - np.array(v2_output), np.array(v1_output), atol=1e-4, rtol=1e-4 + np.array(v2_output), np.array(v1_output), atol=atol, rtol=rtol ) diff --git a/vllm_metal/metal/__init__.py b/vllm_metal/metal/__init__.py index 29f4c18f..9fe29b93 100644 --- a/vllm_metal/metal/__init__.py +++ b/vllm_metal/metal/__init__.py @@ -18,6 +18,8 @@ from pathlib import Path from types import ModuleType +from vllm_metal.metal.constants import PARTITION_SIZE, PARTITION_THRESHOLD + logger = logging.getLogger(__name__) _THIS_DIR = Path(__file__).resolve().parent @@ -53,6 +55,7 @@ def _build_reshape_cache_source() -> str: def _build_paged_attention_source() -> str: """Concatenate float8 + utils + paged_attention into a single source.""" parts = [ + f"#define VLLM_METAL_PARTITION_SIZE {PARTITION_SIZE}", _read_metal_source(_KERNELS_DIR / "float8.metal"), _read_metal_source(_KERNELS_DIR / "utils.metal"), _read_metal_source(_KERNELS_DIR / "pagedattention.metal"), @@ -63,6 +66,7 @@ def _build_paged_attention_source() -> str: def _build_v2_paged_attention_source() -> str: """Concatenate float8 + utils + v2 paged_attention (online softmax).""" parts = [ + f"#define VLLM_METAL_PARTITION_SIZE {PARTITION_SIZE}", _read_metal_source(_KERNELS_V2_DIR / "float8.metal"), _read_metal_source(_KERNELS_V2_DIR / "utils.metal"), _read_metal_source(_KERNELS_V2_DIR / "pagedattention.metal"), @@ -114,23 +118,62 @@ def metal_unified_attention( # Ensure all inputs are evaluated before raw Metal dispatch mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q) - - ops.paged_attention_v2_online( - out, - q, - k, - v, - num_kv_heads, - softmax_scale, - softcap, - block_table, - seqused_k, - cu_seqlens_q, - block_size, - max_seqlen_k, - sliding_window, + max_num_partitions = max(1, (max_seqlen_k + PARTITION_SIZE - 1) // PARTITION_SIZE) + use_partitioning = ( + PARTITION_SIZE % block_size == 0 + and max_seqlen_q == 1 + and max_seqlen_k >= PARTITION_THRESHOLD + and max_num_partitions > 1 ) - mx.synchronize() + + if use_partitioning: + exp_sums = mx.zeros( + (q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32 + ) + max_logits = mx.zeros( + (q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32 + ) + tmp_out = mx.zeros( + (q.shape[0], q.shape[1], max_num_partitions, q.shape[2]), + dtype=q.dtype, + ) + mx.eval(exp_sums, max_logits, tmp_out) + ops.paged_attention_v2_online_partitioned( + out, + q, + k, + v, + num_kv_heads, + softmax_scale, + softcap, + block_table, + seqused_k, + cu_seqlens_q, + block_size, + max_seqlen_k, + sliding_window, + exp_sums, + max_logits, + tmp_out, + ) + mx.synchronize() + else: + ops.paged_attention_v2_online( + out, + q, + k, + v, + num_kv_heads, + softmax_scale, + softcap, + block_table, + seqused_k, + cu_seqlens_q, + block_size, + max_seqlen_k, + sliding_window, + ) + mx.synchronize() def get_ops() -> ModuleType: diff --git a/vllm_metal/metal/build.py b/vllm_metal/metal/build.py index 137e5198..17aa2cdb 100644 --- a/vllm_metal/metal/build.py +++ b/vllm_metal/metal/build.py @@ -12,10 +12,14 @@ import sysconfig from pathlib import Path +from vllm_metal.metal.constants import PARTITION_SIZE + logger = logging.getLogger(__name__) _THIS_DIR = Path(__file__).resolve().parent _SRC = _THIS_DIR / "paged_ops.cpp" +_BUILD = _THIS_DIR / "build.py" +_CONSTANTS = _THIS_DIR / "constants.py" _EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX") or ".so" _CACHE_DIR = Path.home() / ".cache" / "vllm-metal" _CACHE_DIR.mkdir(parents=True, exist_ok=True) @@ -40,8 +44,12 @@ def needs_rebuild() -> bool: """Return True if the .so is missing or older than the source.""" if not _OUT.exists(): return True - src_mtime = _SRC.stat().st_mtime - return _OUT.stat().st_mtime < src_mtime + latest_input_mtime = max( + _SRC.stat().st_mtime, + _BUILD.stat().st_mtime, + _CONSTANTS.stat().st_mtime, + ) + return _OUT.stat().st_mtime < latest_input_mtime def build() -> Path: @@ -94,6 +102,7 @@ def build() -> Path: f"-Wl,-rpath,{mlx_lib}", "-D_METAL_", "-DACCELERATE_NEW_LAPACK", + f"-DVLLM_METAL_PARTITION_SIZE={PARTITION_SIZE}", "-undefined", "dynamic_lookup", str(nb_src), diff --git a/vllm_metal/metal/constants.py b/vllm_metal/metal/constants.py new file mode 100644 index 00000000..b8a8aea2 --- /dev/null +++ b/vllm_metal/metal/constants.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared constants for Metal paged-attention partitioning.""" + +PARTITION_SIZE = 512 +PARTITION_THRESHOLD = 4096 diff --git a/vllm_metal/metal/kernels_v1/pagedattention.metal b/vllm_metal/metal/kernels_v1/pagedattention.metal index 8cf3ad84..e2c54d67 100644 --- a/vllm_metal/metal/kernels_v1/pagedattention.metal +++ b/vllm_metal/metal/kernels_v1/pagedattention.metal @@ -1367,15 +1367,17 @@ template (out_h); auto& query = *nb::inst_ptr(query_h); @@ -290,6 +299,11 @@ void paged_attention_v2_online_impl( int head_size = static_cast(query.shape(2)); int max_blocks = static_cast(block_tables.shape(1)); int num_seqs = static_cast(cu_seqlens_q.shape(0)) - 1; + int max_num_partitions = + std::max(1, (max_seq_len + kPartitionSize - 1) / kPartitionSize); + bool use_partitioning = + exp_sums != nullptr && max_logits != nullptr && tmp_out != nullptr && + kPartitionSize % block_size == 0 && max_num_partitions > 1; // Same kernel name format as v1 — the template instantiation is identical. auto dt = dtype_to_metal(query.dtype()); @@ -297,12 +311,12 @@ void paged_attention_v2_online_impl( "paged_attention_" + dt + "_cache_" + dt + "_hs" + std::to_string(head_size) + "_bs" + std::to_string(block_size) + - "_nt256_nsl32_ps0"; + "_nt256_nsl32_ps" + + std::to_string(use_partitioning ? kPartitionSize : 0); - bool use_partitioning = false; bool use_alibi = false; bool use_fp8 = false; - bool use_sinks = false; + bool use_sinks = sinks != nullptr; // Use the v2 library (online softmax kernel). // get_kernel(function_name, library, cache_key, constants) @@ -332,7 +346,17 @@ void paged_attention_v2_online_impl( enc.set_threadgroup_memory_length(shmem, 0); // Buffer bindings - enc.set_output_array(out, 2); + if (use_partitioning) { + if (exp_sums == nullptr || max_logits == nullptr || tmp_out == nullptr) { + throw std::runtime_error( + "Partitioned v2 attention requires scratch buffers"); + } + enc.set_output_array(*exp_sums, 0); + enc.set_output_array(*max_logits, 1); + enc.set_output_array(*tmp_out, 2); + } else { + enc.set_output_array(out, 2); + } enc.set_input_array(query, 3); enc.set_input_array(key_cache, 4); enc.set_input_array(value_cache, 5); @@ -356,6 +380,9 @@ void paged_attention_v2_online_impl( enc.set_bytes(q_stride, 15); enc.set_bytes(kv_block_stride, 16); enc.set_bytes(kv_head_stride, 17); + if (use_sinks) { + enc.set_input_array(*sinks, 18); + } // Varlen buffers (new in v2) enc.set_input_array(cu_seqlens_q, 19); @@ -365,10 +392,44 @@ void paged_attention_v2_online_impl( enc.set_bytes(sliding_window_i, 21); // Grid: one threadgroup per (head, query_token) + const int32_t grid_z = + static_cast(use_partitioning ? max_num_partitions : 1); enc.dispatch_threadgroups( - MTL::Size::Make(num_heads, total_q_tokens, 1), + MTL::Size::Make(num_heads, total_q_tokens, grid_z), MTL::Size::Make(NUM_THREADS, 1, 1)); + if (use_partitioning) { + std::string reduce_kname = + "paged_attention_v2_reduce_" + dt + + "_hs" + std::to_string(head_size) + + "_nt256_nsl32_ps" + std::to_string(kPartitionSize); + auto* reduce_kernel = d.get_kernel( + reduce_kname, + lib, + reduce_kname + "_v2_reduce", + {{&use_sinks, MTL::DataType::DataTypeBool, NS::UInteger(40)}}); + size_t reduce_shmem = + static_cast(2 * max_num_partitions * sizeof(float)); + enc.set_compute_pipeline_state(reduce_kernel); + enc.set_threadgroup_memory_length(reduce_shmem, 0); + + enc.set_output_array(out, 0); + enc.set_input_array(*exp_sums, 1); + enc.set_input_array(*max_logits, 2); + enc.set_input_array(*tmp_out, 3); + enc.set_input_array(seq_lens, 4); + int32_t max_num_partitions_i = static_cast(max_num_partitions); + enc.set_bytes(max_num_partitions_i, 5); + if (use_sinks) { + enc.set_input_array(*sinks, 6); + } + enc.set_input_array(cu_seqlens_q, 7); + enc.set_bytes(num_seqs_i, 8); + enc.dispatch_threadgroups( + MTL::Size::Make(num_heads, total_q_tokens, 1), + MTL::Size::Make(NUM_THREADS, 1, 1)); + } + d.add_temporary(out, s.index); d.add_temporary(query, s.index); d.add_temporary(key_cache, s.index); @@ -376,6 +437,90 @@ void paged_attention_v2_online_impl( d.add_temporary(block_tables, s.index); d.add_temporary(seq_lens, s.index); d.add_temporary(cu_seqlens_q, s.index); + if (use_partitioning) { + d.add_temporary(*exp_sums, s.index); + d.add_temporary(*max_logits, s.index); + d.add_temporary(*tmp_out, s.index); + } + if (use_sinks) { + d.add_temporary(*sinks, s.index); + } +} + +void paged_attention_v2_online_impl( + nb::handle out_h, + nb::handle query_h, + nb::handle key_cache_h, + nb::handle value_cache_h, + int num_kv_heads, + float scale, + float softcap, + nb::handle block_tables_h, + nb::handle seq_lens_h, + nb::handle cu_seqlens_q_h, + int block_size, + int max_seq_len, + int sliding_window +) { + paged_attention_v2_online_impl_common( + out_h, + query_h, + key_cache_h, + value_cache_h, + num_kv_heads, + scale, + softcap, + block_tables_h, + seq_lens_h, + cu_seqlens_q_h, + block_size, + max_seq_len, + sliding_window, + nullptr, + nullptr, + nullptr, + nullptr); +} + +void paged_attention_v2_online_partitioned_impl( + nb::handle out_h, + nb::handle query_h, + nb::handle key_cache_h, + nb::handle value_cache_h, + int num_kv_heads, + float scale, + float softcap, + nb::handle block_tables_h, + nb::handle seq_lens_h, + nb::handle cu_seqlens_q_h, + int block_size, + int max_seq_len, + int sliding_window, + nb::handle exp_sums_h, + nb::handle max_logits_h, + nb::handle tmp_out_h +) { + auto& exp_sums = *nb::inst_ptr(exp_sums_h); + auto& max_logits = *nb::inst_ptr(max_logits_h); + auto& tmp_out = *nb::inst_ptr(tmp_out_h); + paged_attention_v2_online_impl_common( + out_h, + query_h, + key_cache_h, + value_cache_h, + num_kv_heads, + scale, + softcap, + block_tables_h, + seq_lens_h, + cu_seqlens_q_h, + block_size, + max_seq_len, + sliding_window, + &exp_sums, + &max_logits, + &tmp_out, + nullptr); } // --------------------------------------------------------------------------- @@ -383,6 +528,8 @@ void paged_attention_v2_online_impl( // --------------------------------------------------------------------------- NB_MODULE(_paged_ops, m) { + m.attr("PARTITION_SIZE") = nb::int_(kPartitionSize); + m.def("init_libraries", &init_libraries, nb::arg("reshape_src"), nb::arg("paged_attn_src"), "JIT-compile the vendored Metal shaders."); @@ -415,4 +562,18 @@ NB_MODULE(_paged_ops, m) { nb::arg("block_size"), nb::arg("max_seq_len"), nb::arg("sliding_window"), "Online-softmax varlen paged attention (v2, unified prefill+decode)."); + + m.def("paged_attention_v2_online_partitioned", + &paged_attention_v2_online_partitioned_impl, + nb::arg("out"), nb::arg("query"), + nb::arg("key_cache"), nb::arg("value_cache"), + nb::arg("num_kv_heads"), nb::arg("scale"), + nb::arg("softcap"), + nb::arg("block_tables"), nb::arg("seq_lens"), + nb::arg("cu_seqlens_q"), + nb::arg("block_size"), nb::arg("max_seq_len"), + nb::arg("sliding_window"), + nb::arg("exp_sums"), nb::arg("max_logits"), nb::arg("tmp_out"), + "Online-softmax varlen paged attention (v2) with caller-provided " + "partition scratch buffers."); }