Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/test_metal_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest

from tools.attention_bench_utils import ref_paged_attn, run_v1_paged_attention
from vllm_metal.metal import metal_unified_attention
from vllm_metal.metal import PARTITION_THRESHOLD, metal_unified_attention

# Original upstream parameters (vLLM Triton/CUDA test_triton_unified_attention.py):
# HEAD_SIZES = [128, 256]
Expand Down Expand Up @@ -220,9 +220,14 @@ def test_metal_unified_attn_decode_only(
softcap=0,
)

# v2 must match v1 exactly (same algorithm, same precision)
# Partitioned decode changes the reduction order versus v1, so long-context
# cases need a slightly looser tolerance than the no-partition exact-match
# path.
atol = rtol = 1e-4
if max_kv_len >= PARTITION_THRESHOLD:
atol = rtol = 3e-4
np.testing.assert_allclose(
np.array(v2_output), np.array(v1_output), atol=1e-4, rtol=1e-4
np.array(v2_output), np.array(v1_output), atol=atol, rtol=rtol
)


Expand Down
75 changes: 59 additions & 16 deletions vllm_metal/metal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pathlib import Path
from types import ModuleType

from vllm_metal.metal.constants import PARTITION_SIZE, PARTITION_THRESHOLD

logger = logging.getLogger(__name__)

_THIS_DIR = Path(__file__).resolve().parent
Expand Down Expand Up @@ -53,6 +55,7 @@ def _build_reshape_cache_source() -> str:
def _build_paged_attention_source() -> str:
"""Concatenate the partition-size define + float8 + utils + paged_attention into a single source."""
parts = [
f"#define VLLM_METAL_PARTITION_SIZE {PARTITION_SIZE}",
_read_metal_source(_KERNELS_DIR / "float8.metal"),
_read_metal_source(_KERNELS_DIR / "utils.metal"),
_read_metal_source(_KERNELS_DIR / "pagedattention.metal"),
Expand All @@ -63,6 +66,7 @@ def _build_paged_attention_source() -> str:
def _build_v2_paged_attention_source() -> str:
"""Concatenate the partition-size define + float8 + utils + v2 paged_attention (online softmax)."""
parts = [
f"#define VLLM_METAL_PARTITION_SIZE {PARTITION_SIZE}",
_read_metal_source(_KERNELS_V2_DIR / "float8.metal"),
_read_metal_source(_KERNELS_V2_DIR / "utils.metal"),
_read_metal_source(_KERNELS_V2_DIR / "pagedattention.metal"),
Expand Down Expand Up @@ -114,23 +118,62 @@ def metal_unified_attention(

# Ensure all inputs are evaluated before raw Metal dispatch
mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q)

ops.paged_attention_v2_online(
out,
q,
k,
v,
num_kv_heads,
softmax_scale,
softcap,
block_table,
seqused_k,
cu_seqlens_q,
block_size,
max_seqlen_k,
sliding_window,
max_num_partitions = max(1, (max_seqlen_k + PARTITION_SIZE - 1) // PARTITION_SIZE)
use_partitioning = (
PARTITION_SIZE % block_size == 0
and max_seqlen_q == 1
and max_seqlen_k >= PARTITION_THRESHOLD
and max_num_partitions > 1
)
mx.synchronize()

if use_partitioning:
exp_sums = mx.zeros(
(q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32
)
max_logits = mx.zeros(
(q.shape[0], q.shape[1], max_num_partitions), dtype=mx.float32
)
tmp_out = mx.zeros(
(q.shape[0], q.shape[1], max_num_partitions, q.shape[2]),
dtype=q.dtype,
)
mx.eval(exp_sums, max_logits, tmp_out)
ops.paged_attention_v2_online_partitioned(
out,
q,
k,
v,
num_kv_heads,
softmax_scale,
softcap,
block_table,
seqused_k,
cu_seqlens_q,
block_size,
max_seqlen_k,
sliding_window,
exp_sums,
max_logits,
tmp_out,
)
mx.synchronize()
else:
ops.paged_attention_v2_online(
out,
q,
k,
v,
num_kv_heads,
softmax_scale,
softcap,
block_table,
seqused_k,
cu_seqlens_q,
block_size,
max_seqlen_k,
sliding_window,
)
mx.synchronize()


def get_ops() -> ModuleType:
Expand Down
13 changes: 11 additions & 2 deletions vllm_metal/metal/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
import sysconfig
from pathlib import Path

from vllm_metal.metal.constants import PARTITION_SIZE

logger = logging.getLogger(__name__)

_THIS_DIR = Path(__file__).resolve().parent
_SRC = _THIS_DIR / "paged_ops.cpp"
_BUILD = _THIS_DIR / "build.py"
_CONSTANTS = _THIS_DIR / "constants.py"
_EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX") or ".so"
_CACHE_DIR = Path.home() / ".cache" / "vllm-metal"
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Expand All @@ -40,8 +44,12 @@ def needs_rebuild() -> bool:
"""Return True if the .so is missing or older than any build input (paged_ops.cpp, build.py, constants.py)."""
if not _OUT.exists():
return True
src_mtime = _SRC.stat().st_mtime
return _OUT.stat().st_mtime < src_mtime
latest_input_mtime = max(
_SRC.stat().st_mtime,
_BUILD.stat().st_mtime,
_CONSTANTS.stat().st_mtime,
)
return _OUT.stat().st_mtime < latest_input_mtime


def build() -> Path:
Expand Down Expand Up @@ -94,6 +102,7 @@ def build() -> Path:
f"-Wl,-rpath,{mlx_lib}",
"-D_METAL_",
"-DACCELERATE_NEW_LAPACK",
f"-DVLLM_METAL_PARTITION_SIZE={PARTITION_SIZE}",
"-undefined",
"dynamic_lookup",
str(nb_src),
Expand Down
5 changes: 5 additions & 0 deletions vllm_metal/metal/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
"""Shared constants for Metal paged-attention partitioning."""

PARTITION_SIZE = 512
PARTITION_THRESHOLD = 4096
10 changes: 6 additions & 4 deletions vllm_metal/metal/kernels_v1/pagedattention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1367,15 +1367,17 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
num_simd_lanes, 0);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
// NOTE: partition_size = VLLM_METAL_PARTITION_SIZE
#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, cache_type, 256, \
num_simd_lanes, 512);
num_simd_lanes, \
VLLM_METAL_PARTITION_SIZE);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
// NOTE: partition_size = VLLM_METAL_PARTITION_SIZE
#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
instantiate_paged_attention_v2_reduce_heads( \
type, 256, num_simd_lanes, VLLM_METAL_PARTITION_SIZE);

instantiate_paged_attention_v1(float, float, 32);
instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
Expand Down
36 changes: 19 additions & 17 deletions vllm_metal/metal/kernels_v2/pagedattention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1107,19 +1107,6 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
// merges sequentially. Simple and barrier-safe (all barriers are
// reached by all threads in the threadgroup).

// If partitioning is enabled, store the partial result for the reduce kernel.
// Indexed by q_token_idx (not seq_idx) for varlen compatibility.
if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
device float *max_logits_ptr =
max_logits + q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = warp_m;
device float *exp_sums_ptr = exp_sums +
q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = warp_l;
}

// For non-partitioned mode, include the sink in each warp's state.
if (!USE_PARTITIONING && use_sinks) {
float sink_val = sinks[head_idx];
Expand Down Expand Up @@ -1190,6 +1177,19 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
warp_l = warp_l * my_corr + other_l * other_corr;
}

// For partitioned mode, persist the merged partition statistics for the
// reduce kernel. These must match the normalized tmp_out written below.
if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
device float *max_logits_ptr =
max_logits + q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = warp_m;
device float *exp_sums_ptr = exp_sums +
q_token_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = warp_l;
}

// Final normalization: O = O / l
const float inv_l = 1.f / (warp_l + 1e-6f);

Expand Down Expand Up @@ -1463,15 +1463,17 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
num_simd_lanes, 0);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
// NOTE: partition_size = VLLM_METAL_PARTITION_SIZE
#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, cache_type, 256, \
num_simd_lanes, 512);
num_simd_lanes, \
VLLM_METAL_PARTITION_SIZE);

// TODO: tune num_threads = 256
// NOTE: partition_size = 512
// NOTE: partition_size = VLLM_METAL_PARTITION_SIZE
#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
instantiate_paged_attention_v2_reduce_heads( \
type, 256, num_simd_lanes, VLLM_METAL_PARTITION_SIZE);

instantiate_paged_attention_v1(float, float, 32);
instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
Expand Down
Loading
Loading