Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions tests/test_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
OffsetCache,
clear_context,
get_context,
prepare_decode,
prepare_prefill_packed,
prepare_unified,
)


Expand All @@ -34,38 +33,36 @@ class TestPrepare:
def teardown_method(self):
    """Invoke ``clear_context()`` after every test method in this class.

    ``teardown_method`` is pytest's per-test teardown hook, so this runs
    once after each test regardless of its outcome.
    """
    clear_context()

def test_prepare_prefill_single_request(self):
# Single request via prepare_prefill_packed
prepare_prefill_packed([([10, 11], 5)], block_size=4)
def test_prepare_unified_prefill_single(self):
    """One prefill request (start_pos=0) and no decode requests."""
    # Request layout: (block_ids, num_tokens, start_pos).
    prefill_requests = [([10, 11], 5, 0)]
    no_decodes = []

    prepare_unified(no_decodes, prefill_requests, block_size=4)
    ctx = get_context()

    assert ctx is not None
    assert ctx.is_prefill
    # With block_size=4, block 10 covers slots 40..43 and block 11 starts
    # at slot 44, so the five tokens land on 40, 41, 42, 43, 44.
    assert ctx.slot_mapping == [40, 41, 42, 43, 44]
    assert ctx.block_tables == [[10, 11]]
    assert ctx.context_lens == [5]
    assert ctx.cu_seqlens == [0, 5]
    # Fresh prefill: position offset is 0.
    assert ctx.offsets == [0]

def test_prepare_prefill_packed_slot_mapping(self):
# Two requests: 3 tokens in block 10, 2 tokens in block 20
requests = [([10], 3), ([20], 2)]
prepare_prefill_packed(requests, block_size=4)
def test_prepare_unified_prefill_packed(self):
    """Two prefill requests packed into one batch, no decode requests."""
    first_request = ([10], 3, 0)  # 3 tokens in block 10
    second_request = ([20], 2, 0)  # 2 tokens in block 20

    prepare_unified([], [first_request, second_request], block_size=4)
    ctx = get_context()

    assert ctx is not None
    assert ctx.is_prefill
    # First request  -> block 10 -> slots 40, 41, 42
    # Second request -> block 20 -> slots 80, 81
    assert ctx.slot_mapping == [40, 41, 42, 80, 81]
    # Cumulative query lengths: 0, 3, 3 + 2.
    assert ctx.cu_seqlens == [0, 3, 5]
    assert ctx.block_tables == [[10], [20]]
    assert ctx.context_lens == [3, 2]
    assert ctx.offsets == [0, 0]

def test_prepare_prefill_packed_single_request(self):
# Single request through packed path should produce valid metadata
requests = [([5, 6], 5)]
prepare_prefill_packed(requests, block_size=4)
def test_prepare_unified_prefill_multiblock(self):
# Single prefill spanning two blocks
prepare_unified([], [([5, 6], 5, 0)], block_size=4)
ctx = get_context()

assert ctx is not None
Expand All @@ -75,20 +72,53 @@ def test_prepare_prefill_packed_single_request(self):
assert ctx.block_tables == [[5, 6]]
assert ctx.context_lens == [5]

def test_prepare_decode(self):
# Arrange
requests = [([5, 6], 7)]
def test_prepare_unified_continuation_chunk(self):
    """A continuation chunk: 3 new tokens whose first position is 4.

    Block 10 holds slots 40-43 (positions 0-3, already cached); block 11
    holds slots 44-47, of which positions 4-6 are the new tokens.
    """
    # Request layout: (block_ids, num_tokens, start_pos).
    prepare_unified([], [([10, 11], 3, 4)], block_size=4)
    ctx = get_context()

    assert ctx is not None
    # Only the 3 new tokens (positions 4, 5, 6) are in the query.
    assert ctx.cu_seqlens == [0, 3]
    # Positions 4, 5, 6 map to block 11 slots 44, 45, 46.
    assert ctx.slot_mapping == [44, 45, 46]
    assert ctx.block_tables == [[10, 11]]
    # Total context = start_pos + num_tokens = 4 + 3 = 7.
    assert ctx.context_lens == [7]
    # RoPE offset = start_pos.
    assert ctx.offsets == [4]

def test_prepare_unified_decode_only(self):
    """One decode request and no prefill requests."""
    # Decode request layout: (block_ids, seq_len).
    prepare_unified([([5, 6], 7)], [], block_size=4)
    ctx = get_context()

    assert ctx is not None
    assert not ctx.is_prefill
    # new_pos = 7 -> block_ids[7 // 4] = block_ids[1] = 6
    # slot    = 6 * 4 + (7 % 4) = 27
    assert ctx.slot_mapping == [27]
    assert ctx.context_lens == [8]
    assert ctx.offsets == [7]
    # A decode step contributes exactly one query token.
    assert ctx.cu_seqlens == [0, 1]

def test_prepare_unified_mixed(self):
    """One decode request and one prefill request in the same batch."""
    decode_batch = [([5, 6], 7)]  # seq_len=7
    prefill_batch = [([10, 11], 5, 0)]  # 5 tokens from position 0

    prepare_unified(decode_batch, prefill_batch, block_size=4)
    ctx = get_context()

    assert ctx is not None
    # The decode token comes first: pos=7 -> block 6 -> slot 6*4+3 = 27.
    # The prefill tokens follow: block 10 slots 40..43, block 11 slot 44.
    assert ctx.slot_mapping == [27, 40, 41, 42, 43, 44]
    # 1 decode token, then 5 prefill tokens.
    assert ctx.cu_seqlens == [0, 1, 6]
    assert ctx.offsets == [7, 0]
    assert ctx.context_lens == [8, 5]
    assert ctx.block_tables == [[5, 6], [10, 11]]


class TestPackedRoPE:
Expand Down
18 changes: 11 additions & 7 deletions vllm_metal/metal_kernel_backend/packed_prefill_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SCAFFOLDING: remove when varlen kernel handles position encoding natively.
#
# Per-request RoPE helper for packed prefill.
# Per-request RoPE helper for packed / unified forward passes.

from __future__ import annotations

Expand All @@ -13,16 +11,22 @@ def apply_packed_rope(
queries: mx.array,
keys: mx.array,
cu_seqlens: list[int],
offsets: list[int] | None = None,
) -> tuple[mx.array, mx.array]:
"""Apply per-request RoPE with position reset for packed prefill.
"""Apply per-request RoPE for packed sequences.

SCAFFOLDING: remove when varlen kernel is ready.
Each segment delimited by ``cu_seqlens`` gets its own RoPE application
starting at the corresponding offset. When *offsets* is ``None`` every
segment starts at position 0 (pure prefill). For unified prefill+decode
batches, decode segments carry ``offset=seq_len``, continuation-chunk
segments carry ``offset=start_pos``, and fresh prefill segments keep
``offset=0``.
"""
q_parts = []
k_parts = []
for i in range(len(cu_seqlens) - 1):
start = cu_seqlens[i]
end = cu_seqlens[i + 1]
q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=0))
k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=0))
off = offsets[i] if offsets is not None else 0
q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=off))
k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=off))
return mx.concatenate(q_parts, axis=2), mx.concatenate(k_parts, axis=2)
130 changes: 14 additions & 116 deletions vllm_metal/metal_kernel_backend/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

All operations use MLX arrays end-to-end — no PyTorch MPS bridge.

Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_prefill_packed``,
``prepare_decode``, ``clear_context`` from ``paged_attention_common``.
Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_unified``,
``clear_context`` from ``paged_attention_common``.

Backend replacement guide
-------------------------
Expand Down Expand Up @@ -107,9 +107,15 @@ def _metal_kernel_prefill_attention(
"attribute. Only RoPE-based models are supported by paged attention."
)

# NOTE: apply_packed_rope always uses offset=0 per request. Chunked
# prefill will need per-request offsets (like decode) for continuation chunks.
queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens)
# Per-segment RoPE: offset=0 for fresh prefill, offset=start_pos for a
# continuation chunk, offset=seq_len for decode tokens in a unified batch
# (ctx.offsets populated by prepare_unified).
queries, keys = apply_packed_rope(
attn_module,
queries,
keys,
ctx.cu_seqlens,
offsets=ctx.offsets if ctx.offsets else None,
)

# Reshape to 3D: (1, heads, L, hd) → (L, heads, hd)
q_3d = mx.contiguous(queries[0].transpose(1, 0, 2).astype(cache.dtype))
Expand Down Expand Up @@ -168,109 +174,6 @@ def _metal_kernel_prefill_attention(
return attn_module.o_proj(out)


# ---------------------------------------------------------------------------
# Decode attention (reshape_and_cache + paged_attention_v2_online)
# ---------------------------------------------------------------------------


def _metal_kernel_decode_attention(
    attn_module: Any,
    queries: mx.array,
    keys: mx.array,
    values: mx.array,
    cache: MetalPagedKVCache,
    layer_idx: int,
    ctx: PagedAttentionContext,
) -> mx.array:
    """Run one batched decode step (B=batch_size, L=1) on the paged cache.

    Steps, in order:

    1. Apply RoPE to each request's query/key row with ``ctx.offsets[i]``.
    2. Write the new K/V token into layer ``layer_idx`` of the paged cache
       via ``reshape_and_cache`` at ``ctx.slot_mapping``.
    3. Attend over the cache with ``paged_attention_v2_online``.
    4. Project the result through ``attn_module.o_proj``.

    Raises:
        NotImplementedError: if ``attn_module`` has no ``rope`` attribute.
    """
    q_shape = queries.shape
    batch_size, n_heads, head_dim = q_shape[0], q_shape[1], q_shape[3]

    # Only RoPE-based attention modules are supported on this path.
    if not hasattr(attn_module, "rope"):
        raise NotImplementedError(
            f"Attention module {type(attn_module).__name__} does not have a 'rope' "
            "attribute. Only RoPE-based models are supported by paged attention."
        )

    # Each request has its own position offset, so rotate one row at a time.
    rotated_q = []
    rotated_k = []
    for row in range(batch_size):
        position = ctx.offsets[row]
        window = slice(row, row + 1)
        rotated_q.append(attn_module.rope(queries[window], offset=position))
        rotated_k.append(attn_module.rope(keys[window], offset=position))
    queries = mx.concatenate(rotated_q, axis=0)  # (B, heads, 1, head_dim)
    new_keys = mx.concatenate(rotated_k, axis=0)  # (B, kv_heads, 1, head_dim)

    # Drop the length-1 sequence axis: (B, heads, 1, hd) -> (B, heads, hd),
    # cast to the cache dtype.
    kv_dtype = cache.dtype
    q_3d = mx.contiguous(queries[:, :, 0, :].astype(kv_dtype))
    k_3d = mx.contiguous(new_keys[:, :, 0, :].astype(kv_dtype))
    v_3d = mx.contiguous(values[:, :, 0, :].astype(kv_dtype))

    slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64)

    # Right-pad every block table with 0 up to the longest one so the
    # tables form a rectangular array.
    widest = max(map(len, ctx.block_tables))
    padded_tables = [
        table + [0] * (widest - len(table)) for table in ctx.block_tables
    ]
    block_tables = mx.array(padded_tables, dtype=mx.int32)
    seq_lens = mx.array(ctx.context_lens, dtype=mx.int32)

    # Materialise every input before the kernels are dispatched.
    mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens)

    ops = get_ops()
    layer_keys = cache.key_caches[layer_idx]
    layer_values = cache.value_caches[layer_idx]

    # Store the new token's K/V in the paged cache.
    ops.reshape_and_cache(k_3d, v_3d, layer_keys, layer_values, slot_mapping)

    # Output buffer that the attention kernel writes into.
    out = mx.zeros((batch_size, n_heads, head_dim), dtype=kv_dtype)
    mx.eval(out)

    max_seq_len = max(ctx.context_lens)
    scale = attn_module.scale

    # cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence,
    # so the cumulative lengths are simply 0..B.
    cu_seqlens_q = mx.arange(batch_size + 1, dtype=mx.int32)
    mx.eval(cu_seqlens_q)

    # Zero-copy paged attention (v2, online softmax, varlen-capable).
    ops.paged_attention_v2_online(
        out,
        q_3d,
        layer_keys,
        layer_values,
        cache.num_kv_heads,
        scale,
        0.0,  # softcap (0 = disabled)
        block_tables,
        seq_lens,
        cu_seqlens_q,
        cache.block_size,
        max_seq_len,
        -1,  # sliding_window (-1 = disabled)
    )

    # The kernel wrote into ``out`` through a raw Metal dispatch that MLX's
    # lazy graph does not track. ``mx.eval(out)`` would be a no-op here
    # (``out`` was already evaluated as zeros), so flush the command encoder
    # and wait for the kernel with ``mx.synchronize()``.
    mx.synchronize()
    out = out.reshape(batch_size, 1, n_heads * head_dim)
    return attn_module.o_proj(out)


# ---------------------------------------------------------------------------
# Wrapper nn.Module
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -329,14 +232,9 @@ def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array
keys = keys.transpose(0, 2, 1, 3)
values = values.transpose(0, 2, 1, 3)

if ctx.is_prefill:
return _metal_kernel_prefill_attention(
inner, queries, keys, values, kv_cache, layer_idx, ctx
)
else:
return _metal_kernel_decode_attention(
inner, queries, keys, values, kv_cache, layer_idx, ctx
)
return _metal_kernel_prefill_attention(
inner, queries, keys, values, kv_cache, layer_idx, ctx
)


# ---------------------------------------------------------------------------
Expand Down
Loading
Loading