diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index bcd70f95..30b45bfb 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -11,8 +11,7 @@ OffsetCache, clear_context, get_context, - prepare_decode, - prepare_prefill_packed, + prepare_unified, ) @@ -34,38 +33,36 @@ class TestPrepare: def teardown_method(self): clear_context() - def test_prepare_prefill_single_request(self): - # Single request via prepare_prefill_packed - prepare_prefill_packed([([10, 11], 5)], block_size=4) + def test_prepare_unified_prefill_single(self): + # Single prefill request via prepare_unified (start_pos=0) + prepare_unified([], [([10, 11], 5, 0)], block_size=4) ctx = get_context() # block 10: slots 40,41,42,43; block 11: slot 44 assert ctx is not None - assert ctx.is_prefill assert ctx.slot_mapping == [40, 41, 42, 43, 44] assert ctx.block_tables == [[10, 11]] assert ctx.context_lens == [5] assert ctx.cu_seqlens == [0, 5] + assert ctx.offsets == [0] - def test_prepare_prefill_packed_slot_mapping(self): - # Two requests: 3 tokens in block 10, 2 tokens in block 20 - requests = [([10], 3), ([20], 2)] - prepare_prefill_packed(requests, block_size=4) + def test_prepare_unified_prefill_packed(self): + # Two prefill requests packed together + prepare_unified([], [([10], 3, 0), ([20], 2, 0)], block_size=4) ctx = get_context() assert ctx is not None - assert ctx.is_prefill # Request 0: block 10, slots 40,41,42 # Request 1: block 20, slots 80,81 assert ctx.slot_mapping == [40, 41, 42, 80, 81] assert ctx.cu_seqlens == [0, 3, 5] assert ctx.block_tables == [[10], [20]] assert ctx.context_lens == [3, 2] + assert ctx.offsets == [0, 0] - def test_prepare_prefill_packed_single_request(self): - # Single request through packed path should produce valid metadata - requests = [([5, 6], 5)] - prepare_prefill_packed(requests, block_size=4) + def test_prepare_unified_prefill_multiblock(self): + # Single prefill spanning two blocks + prepare_unified([], [([5, 6], 5, 0)], block_size=4) ctx = get_context() assert ctx is not None @@ -75,20 +72,53 @@ def test_prepare_prefill_packed_single_request(self): assert ctx.block_tables == [[5, 6]] assert ctx.context_lens == [5] - def test_prepare_decode(self): - # Arrange - requests = [([5, 6], 7)] + def test_prepare_unified_continuation_chunk(self): + # Continuation chunk: 3 new tokens starting at position 4 + # block 10 has slots 40-43 (positions 0-3, already cached), + # block 11 has slots 44-47 (positions 4-6 are the new tokens) + prepare_unified([], [([10, 11], 3, 4)], block_size=4) + ctx = get_context() - # Act - prepare_decode(requests, block_size=4) + assert ctx is not None + # Only 3 tokens in the query (positions 4, 5, 6) + assert ctx.cu_seqlens == [0, 3] + # Slots for positions 4, 5, 6: block 11 slots 44, 45, 46 + assert ctx.slot_mapping == [44, 45, 46] + assert ctx.block_tables == [[10, 11]] + # Total context = start_pos + num_tokens = 4 + 3 = 7 + assert ctx.context_lens == [7] + # RoPE offset = start_pos + assert ctx.offsets == [4] + + def test_prepare_unified_decode_only(self): + # Single decode request via prepare_unified + decode_requests = [([5, 6], 7)] + prepare_unified(decode_requests, [], block_size=4) ctx = get_context() - # Assert — new_pos=7, block_ids[7//4]=block_ids[1]=6, slot=6*4+(7%4)=27 + # new_pos=7, block_ids[7//4]=block_ids[1]=6, slot=6*4+(7%4)=27 assert ctx is not None - assert not ctx.is_prefill assert ctx.slot_mapping == [27] assert ctx.context_lens == [8] assert ctx.offsets == [7] + assert ctx.cu_seqlens == [0, 1] + + def test_prepare_unified_mixed(self): + # 1 decode + 1 prefill + decode_requests = [([5, 6], 7)] # seq_len=7 + prefill_requests = [([10, 11], 5, 0)] # 5 tokens from position 0 + + prepare_unified(decode_requests, prefill_requests, block_size=4) + ctx = get_context() + + assert ctx is not None + # Decode slot: pos=7, block 6, slot=6*4+3=27 + # Prefill slots: block 10 slots 40,41,42,43; block 11 slot 44 + assert ctx.slot_mapping == [27, 40, 41, 42, 43, 44] + assert ctx.cu_seqlens == [0, 1, 6] + assert ctx.offsets == [7, 0] + assert ctx.context_lens == [8, 5] + assert ctx.block_tables == [[5, 6], [10, 11]] class TestPackedRoPE: diff --git a/vllm_metal/metal_kernel_backend/packed_prefill_compat.py b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py index 72ec5d47..b41ac0c4 100644 --- a/vllm_metal/metal_kernel_backend/packed_prefill_compat.py +++ b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SCAFFOLDING: remove when varlen kernel handles position encoding natively. -# -# Per-request RoPE helper for packed prefill. +# Per-request RoPE helper for packed / unified forward passes. from __future__ import annotations @@ -13,16 +11,22 @@ def apply_packed_rope( queries: mx.array, keys: mx.array, cu_seqlens: list[int], + offsets: list[int] | None = None, ) -> tuple[mx.array, mx.array]: - """Apply per-request RoPE with position reset for packed prefill. + """Apply per-request RoPE for packed sequences. - SCAFFOLDING: remove when varlen kernel is ready. + Each segment delimited by ``cu_seqlens`` gets its own RoPE application + starting at the corresponding offset. When *offsets* is ``None`` every + segment starts at position 0 (pure prefill). For unified prefill+decode + batches, decode segments carry ``offset=seq_len`` while prefill segments + keep ``offset=0``. """ q_parts = [] k_parts = [] for i in range(len(cu_seqlens) - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] - q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=0)) - k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=0)) + off = offsets[i] if offsets is not None else 0 + q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=off)) + k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=off)) return mx.concatenate(q_parts, axis=2), mx.concatenate(k_parts, axis=2) diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 79813065..936ddfa7 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -10,8 +10,8 @@ All operations use MLX arrays end-to-end — no PyTorch MPS bridge. -Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_prefill_packed``, -``prepare_decode``, ``clear_context`` from ``paged_attention_common``. +Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_unified``, +``clear_context`` from ``paged_attention_common``. Backend replacement guide ------------------------- @@ -107,9 +107,15 @@ def _metal_kernel_prefill_attention( "attribute. Only RoPE-based models are supported by paged attention." ) - # NOTE: apply_packed_rope always uses offset=0 per request. Chunked - # prefill will need per-request offsets (like decode) for continuation chunks. - queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens) + # Per-segment RoPE: offset=0 for fresh prefill, offset=seq_len for decode + # tokens in a unified batch (ctx.offsets populated by prepare_unified). + queries, keys = apply_packed_rope( + attn_module, + queries, + keys, + ctx.cu_seqlens, + offsets=ctx.offsets if ctx.offsets else None, + ) # Reshape to 3D: (1, heads, L, hd) → (L, heads, hd) q_3d = mx.contiguous(queries[0].transpose(1, 0, 2).astype(cache.dtype)) @@ -168,109 +174,6 @@ def _metal_kernel_prefill_attention( return attn_module.o_proj(out) -# --------------------------------------------------------------------------- -# Decode attention (reshape_and_cache + paged_attention_v1) -# --------------------------------------------------------------------------- - - -def _metal_kernel_decode_attention( - attn_module: Any, - queries: mx.array, - keys: mx.array, - values: mx.array, - cache: MetalPagedKVCache, - layer_idx: int, - ctx: PagedAttentionContext, -) -> mx.array: - """Batched decode: B=batch_size, L=1. - - Per-request RoPE, write new token via ``reshape_and_cache``, - then zero-copy attention via ``paged_attention_v1``. - """ - B = queries.shape[0] # noqa: N806 - n_heads = queries.shape[1] - head_dim = queries.shape[3] - - # Per-request RoPE - if not hasattr(attn_module, "rope"): - raise NotImplementedError( - f"Attention module {type(attn_module).__name__} does not have a 'rope' " - "attribute. Only RoPE-based models are supported by paged attention." - ) - q_parts = [] - k_parts = [] - for i in range(B): - q_parts.append(attn_module.rope(queries[i : i + 1], offset=ctx.offsets[i])) - k_parts.append(attn_module.rope(keys[i : i + 1], offset=ctx.offsets[i])) - queries = mx.concatenate(q_parts, axis=0) # (B, heads, 1, head_dim) - keys_new = mx.concatenate(k_parts, axis=0) # (B, kv_heads, 1, head_dim) - - # Squeeze seq dim: (B, heads, 1, hd) → (B, heads, hd) - q_3d = mx.contiguous(queries[:, :, 0, :].astype(cache.dtype)) - k_3d = mx.contiguous(keys_new[:, :, 0, :].astype(cache.dtype)) - v_3d = mx.contiguous(values[:, :, 0, :].astype(cache.dtype)) - - slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) - - # Build block_tables and seq_lens - max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) - block_tables_list = [ - bt + [0] * (max_blocks_per_seq - len(bt)) for bt in ctx.block_tables - ] - block_tables = mx.array(block_tables_list, dtype=mx.int32) - seq_lens = mx.array(ctx.context_lens, dtype=mx.int32) - - # Eval all inputs before kernel dispatch - mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens) - - ops = get_ops() - - # Write new K/V tokens into paged cache - ops.reshape_and_cache( - k_3d, - v_3d, - cache.key_caches[layer_idx], - cache.value_caches[layer_idx], - slot_mapping, - ) - - # Allocate output - out = mx.zeros((B, n_heads, head_dim), dtype=cache.dtype) - mx.eval(out) - - max_seq_len = max(ctx.context_lens) - scale = attn_module.scale - - # Build cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence. - cu_seqlens_q = mx.arange(B + 1, dtype=mx.int32) - mx.eval(cu_seqlens_q) - - # Zero-copy paged attention (v2, online softmax, varlen-capable) - ops.paged_attention_v2_online( - out, - q_3d, - cache.key_caches[layer_idx], - cache.value_caches[layer_idx], - cache.num_kv_heads, - scale, - 0.0, # softcap (0 = disabled) - block_tables, - seq_lens, - cu_seqlens_q, - cache.block_size, - max_seq_len, - -1, # sliding_window (-1 = disabled) - ) - - # Synchronize GPU: paged_attention_v2_online wrote to out's buffer via a raw - # Metal dispatch that MLX's lazy graph doesn't track. mx.eval(out) would - # be a no-op here (out was already evaluated as zeros), so we must use - # mx.synchronize() to flush the command encoder and wait for the kernel. - mx.synchronize() - out = out.reshape(B, 1, n_heads * head_dim) - return attn_module.o_proj(out) - - # --------------------------------------------------------------------------- # Wrapper nn.Module # --------------------------------------------------------------------------- @@ -329,14 +232,9 @@ def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array keys = keys.transpose(0, 2, 1, 3) values = values.transpose(0, 2, 1, 3) - if ctx.is_prefill: - return _metal_kernel_prefill_attention( - inner, queries, keys, values, kv_cache, layer_idx, ctx - ) - else: - return _metal_kernel_decode_attention( - inner, queries, keys, values, kv_cache, layer_idx, ctx - ) + return _metal_kernel_prefill_attention( + inner, queries, keys, values, kv_cache, layer_idx, ctx + ) # --------------------------------------------------------------------------- diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index be148c61..05661d82 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -5,9 +5,9 @@ both the Metal kernel paged attention backend and the model runner. Usage: - 1. Before each forward pass call ``prepare_prefill_packed()`` or ``prepare_decode()`` + 1. Before each forward pass call ``prepare_unified()`` 2. Run ``model(input_ids, cache=offset_caches)`` as normal - 3. The attention wrapper reads ``get_context()`` to decide prefill vs decode + 3. The attention wrapper reads ``get_context()`` for paged metadata 4. Call ``clear_context()`` after the forward pass """ @@ -26,7 +26,7 @@ # Thread-local storage used to pass per-request metadata (slot_mapping, # block_tables, etc.) to attention wrappers buried inside the model. # We cannot add extra arguments to the mlx_lm forward signature, so -# instead: prepare_prefill_packed/decode() stashes context here before the +# instead: prepare_unified() stashes context here before the # forward pass, each attention wrapper reads it via get_context(), and # clear_context() cleans up afterwards. _thread_local = threading.local() @@ -34,17 +34,20 @@ @dataclass class PagedAttentionContext: - """Context set before each forward pass, read by patched attention.""" + """Context set before each forward pass, read by patched attention. + + All forward passes use the varlen kernel with ``cu_seqlens`` to handle + variable-length subsequences (both prefill and decode tokens packed + into a single flat sequence). + """ - is_prefill: bool slot_mapping: list[int] - # decode-only fields block_tables: list[list[int]] = field(default_factory=list) context_lens: list[int] = field(default_factory=list) + # Per-segment RoPE offsets: 0 for fresh prefill, seq_len for decode. offsets: list[int] = field(default_factory=list) - # packed prefill fields — set when multiple requests are packed into - # a single forward pass. cu_seqlens is a cumulative sequence length - # array: [0, len0, len0+len1, ...] (length = num_requests + 1). + # Cumulative sequence length array: [0, len0, len0+len1, ...] + # (length = num_requests + 1). cu_seqlens: list[int] | None = None @@ -148,78 +151,60 @@ def find_layers_and_attr(model: Any) -> tuple[list[Any], str]: # --------------------------------------------------------------------------- -def prepare_prefill_packed( - requests: list[tuple[list[int], int]], +def prepare_unified( + decode_requests: list[tuple[list[int], int]], + prefill_requests: list[tuple[list[int], int, int]], block_size: int, ) -> None: - """Compute slot_mapping, cu_seqlens, block_tables, context_lens for prefill. + """Compute metadata for a unified prefill + decode forward pass. - Packs one or more prefill requests into a single forward pass. The - varlen Metal kernel uses ``cu_seqlens`` to locate each sequence's - query tokens and ``block_tables`` / ``context_lens`` to read K/V - from the paged cache. + Packs decode tokens (1 per request) followed by prefill tokens into a + single flattened sequence. ``cu_seqlens`` marks request boundaries so + the varlen kernel handles both decode (length-1) and prefill (length-N) + subsequences in one dispatch. Args: - requests: list of (block_ids, num_tokens) per request. - block_size: tokens per block. + decode_requests: list of ``(block_ids, seq_len)`` for decode requests. + ``seq_len`` = tokens already cached before this step. + prefill_requests: list of ``(block_ids, num_tokens, start_pos)`` for + prefill. ``start_pos`` is the position of the first token in this + chunk (0 for a fresh prefill, >0 for continuation chunks). + block_size: tokens per KV cache block. """ slot_mapping: list[int] = [] cu_seqlens: list[int] = [0] block_tables: list[list[int]] = [] context_lens: list[int] = [] - - for block_ids, num_tokens in requests: - for pos in range(num_tokens): - block_idx = block_ids[pos // block_size] - slot = block_idx * block_size + (pos % block_size) - slot_mapping.append(slot) - cu_seqlens.append(cu_seqlens[-1] + num_tokens) - block_tables.append(block_ids) - context_lens.append(num_tokens) - - set_context( - PagedAttentionContext( - is_prefill=True, - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - cu_seqlens=cu_seqlens, - ) - ) - - -def prepare_decode( - requests: list[tuple[list[int], int]], - block_size: int, -) -> None: - """Compute slot_mapping, block_tables, context_lens, offsets for decode. - - Args: - requests: list of (block_ids, seq_len) per request. - seq_len = number of tokens already stored (before this step). - block_size: tokens per block - """ - slot_mapping: list[int] = [] - block_tables: list[list[int]] = [] - context_lens: list[int] = [] offsets: list[int] = [] - for block_ids, seq_len in requests: - # Slot for the new token at position seq_len + # Decode requests first (1 token each) + for block_ids, seq_len in decode_requests: new_pos = seq_len block_idx = block_ids[new_pos // block_size] slot = block_idx * block_size + (new_pos % block_size) slot_mapping.append(slot) + cu_seqlens.append(cu_seqlens[-1] + 1) block_tables.append(block_ids) context_lens.append(seq_len + 1) # including new token - offsets.append(seq_len) # RoPE position = seq_len + offsets.append(seq_len) # RoPE position + + # Prefill requests (variable tokens each, starting at start_pos) + for block_ids, num_tokens, start_pos in prefill_requests: + for pos in range(start_pos, start_pos + num_tokens): + block_idx = block_ids[pos // block_size] + slot = block_idx * block_size + (pos % block_size) + slot_mapping.append(slot) + cu_seqlens.append(cu_seqlens[-1] + num_tokens) + block_tables.append(block_ids) + context_lens.append(start_pos + num_tokens) + offsets.append(start_pos) set_context( PagedAttentionContext( - is_prefill=False, slot_mapping=slot_mapping, block_tables=block_tables, context_lens=context_lens, + cu_seqlens=cu_seqlens, offsets=offsets, ) ) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 4be9eff0..17ee86dc 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -51,8 +51,7 @@ from vllm_metal.paged_attention_common import ( OffsetCache, clear_context, - prepare_decode, - prepare_prefill_packed, + prepare_unified, ) from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch from vllm_metal.stt.config import ( @@ -580,12 +579,6 @@ def _extract_kv_cache( return extracted -# SCAFFOLDING: remove when varlen kernel is ready. -# Cap total packed-prefill tokens to bound the O(N²) dense causal mask. -# Batches exceeding this limit are split into multiple forward passes. -MAX_PACKED_PREFILL_TOKENS = 4096 - - class MetalModelRunner: """Model runner for MLX-based inference on Metal. @@ -1425,12 +1418,12 @@ def _sequential_decode( return next_tokens # ------------------------------------------------------------------ - # Paged attention paths + # Unified prefill + decode (single forward pass) # ------------------------------------------------------------------ - def _prefill_packed_paged( + def _unified_prefill_decode_paged( self, - pack_reqs: list[ + prefill_reqs: list[ tuple[ str, list[int], @@ -1438,33 +1431,62 @@ def _prefill_packed_paged( list[int], torch.Generator | None, int | None, + int, ] ], - ) -> list[int]: - """Packed paged-attention prefill for multiple requests. + decode_reqs: list[tuple[str, RequestState]], + ) -> tuple[list[int], list[int]]: + """Single forward pass for mixed prefill + decode requests. - Concatenates token_ids from all requests into a single forward - pass using ``cu_seqlens`` to build a block-diagonal causal mask. - This avoids the overhead of N separate forward passes. + Packs decode tokens (1 per request) followed by prefill tokens into + a flat ``(1, total_tokens)`` input. The varlen kernel uses + ``cu_seqlens`` to handle variable-length subsequences. Args: - pack_reqs: list of - (req_id, token_ids, sampling_params, block_ids, - generator, prompt_len) tuples. + prefill_reqs: list of + ``(req_id, token_ids, sampling_params, block_ids, + generator, prompt_len, start_pos)`` — prefill requests. + ``start_pos`` is the RoPE offset / KV slot start (0 for + fresh prefill, >0 for continuation chunks). + decode_reqs: list of ``(req_id, RequestState)`` — decode requests. Returns: - List of sampled next tokens, one per request. + ``(prefill_next_tokens, decode_next_tokens)`` """ - # Build packed input + num_decode = len(decode_reqs) + + # ---- build unified token sequence: decode first, then prefill ---- all_token_ids: list[int] = [] - block_requests: list[tuple[list[int], int]] = [] - for _, token_ids, _, block_ids, _, _ in pack_reqs: + + # Decode: last token per request + if self._rust_state_manager is not None: + last_tokens = self._rust_state_manager.get_last_tokens_batch( + [rid for rid, _ in decode_reqs] + ) + else: + last_tokens = [ + state.token_ids[-1] if state.token_ids else 0 + for _, state in decode_reqs + ] + all_token_ids.extend(last_tokens) + + # Prefill: tokens per request + for _, token_ids, _, _, _, _, _ in prefill_reqs: all_token_ids.extend(token_ids) - block_requests.append((block_ids, len(token_ids))) - # Stash packed context (slot_mapping + cu_seqlens) - prepare_prefill_packed(block_requests, self._paged_block_size) + # ---- build metadata for prepare_unified ---- + decode_info: list[tuple[list[int], int]] = [] + for req_id, state in decode_reqs: + seq_len = self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 1) + decode_info.append((state.block_ids, seq_len)) + + prefill_info: list[tuple[list[int], int, int]] = [] + for _, token_ids, _, block_ids, _, _, start_pos in prefill_reqs: + prefill_info.append((block_ids, len(token_ids), start_pos)) + prepare_unified(decode_info, prefill_info, self._paged_block_size) + + # ---- forward ---- offset_caches = [OffsetCache(0) for _ in range(self.num_layers)] input_ids = mx.array([all_token_ids], dtype=mx.int32) try: @@ -1473,21 +1495,85 @@ def _prefill_packed_paged( finally: clear_context() - # Extract per-request last-token logits and sample - cu_seqlens = [0] - for _, token_ids, _, _, _, _ in pack_reqs: + # ---- build cu_seqlens for logit extraction ---- + cu_seqlens: list[int] = [0] + for _ in decode_reqs: + cu_seqlens.append(cu_seqlens[-1] + 1) + for _, token_ids, _, _, _, _, _ in prefill_reqs: cu_seqlens.append(cu_seqlens[-1] + len(token_ids)) - next_tokens: list[int] = [] - for i, ( + # ---- sample decode tokens ---- + decode_next_tokens: list[int] = [] + if decode_reqs: + # All decode logits are at positions 0..num_decode-1 + decode_logits = logits[0, :num_decode, :] # (num_decode, vocab) + + sampling_params_list = [state.sampling_params for _, state in decode_reqs] + all_greedy = all(sp.temperature < 1e-5 for sp in sampling_params_list) + any_advanced = any( + sp.top_k > 0 + or sp.top_p < 1.0 + or sp.frequency_penalty != 0 + or sp.presence_penalty != 0 + or sp.repetition_penalty != 1.0 + for sp in sampling_params_list + ) + + if all_greedy and not any_advanced: + next_tokens_mlx = _mlx_greedy_sample(decode_logits) + mx.eval(next_tokens_mlx) + decode_next_tokens = next_tokens_mlx.tolist() + else: + mx.eval(decode_logits) + prompt_token_ids_list = [ + state.token_ids[: state.prompt_len] for _, state in decode_reqs + ] + output_tokens_list = [ + state.token_ids[state.prompt_len :] for _, state in decode_reqs + ] + generators = { + i: state.generator + for i, (_, state) in enumerate(decode_reqs) + if state.generator is not None + } + logits_torch = mlx_to_torch( + decode_logits.astype(mx.float32), device=self.device + ) + metadata = self._make_sampling_metadata( + sampling_params_list, + prompt_token_ids_list, + output_tokens_list, + generators=generators, + ) + output = self._sampler.forward(logits_torch, metadata) + decode_next_tokens = [ + int(output.sampled_token_ids[i, 0].item()) + for i in range(num_decode) + ] + + # Update decode state + for i, (req_id, state) in enumerate(decode_reqs): + state.token_ids.append(decode_next_tokens[i]) + state.generated_tokens += 1 + self._paged_request_seq_lens[req_id] = ( + self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 2) + + 1 + ) + if self._rust_state_manager is not None: + self._rust_state_manager.append_token(req_id, decode_next_tokens[i]) + + # ---- sample prefill tokens ---- + prefill_next_tokens: list[int] = [] + for j, ( req_id, token_ids, sampling_params, - _, + _block_ids, generator, prompt_len, - ) in enumerate(pack_reqs): - last_idx = cu_seqlens[i + 1] - 1 + _start_pos, + ) in enumerate(prefill_reqs): + last_idx = cu_seqlens[num_decode + j + 1] - 1 last_logits = logits[:, last_idx : last_idx + 1, :] if prompt_len is None: @@ -1522,171 +1608,9 @@ def _prefill_packed_paged( next_token = int(output.sampled_token_ids[0, 0].item()) self._paged_request_seq_lens[req_id] = len(token_ids) - next_tokens.append(next_token) - - return next_tokens - - def _run_packed_prefill( - self, - paged_complete: list[ - tuple[ - int, - str, - list[int], - SamplingParams, - list[int], - torch.Generator | None, - ] - ], - sampled_tokens: list[list[int]], - ) -> None: - """Batch, dispatch, and write back state for packed paged prefill. - - Splits *paged_complete* into batches that fit within - ``MAX_PACKED_PREFILL_TOKENS``, runs each batch through - ``_prefill_packed_paged``, and fills *sampled_tokens* in-place. - """ - # Split into batches that fit within the packed-length cap. - batches: list[list[tuple]] = [[]] - batch_tokens = 0 - for entry in paged_complete: - entry_tokens = len(entry[2]) # token_ids - if batch_tokens + entry_tokens > MAX_PACKED_PREFILL_TOKENS and batches[-1]: - batches.append([]) - batch_tokens = 0 - batches[-1].append(entry) - batch_tokens += entry_tokens - - for batch in batches: - pack_input = [ - (rid, tids, sp, bids, gen, None) - for _, rid, tids, sp, bids, gen in batch - ] - next_tokens = self._prefill_packed_paged(pack_input) - for i, (idx, rid, tids, sp, bids, gen) in enumerate(batch): - nt = next_tokens[i] - sampled_tokens[idx] = [nt] - self._request_states[rid] = RequestState( - token_ids=list(tids) + [nt], - prompt_len=len(tids), - cache=[], - sampling_params=sp, - generator=gen, - generated_tokens=1, - block_ids=bids, - ) - if self._rust_state_manager is not None: - self._rust_state_manager.add_request(rid, list(tids) + [nt]) - - def _batched_decode_paged( - self, decode_reqs: list[tuple[str, RequestState]] - ) -> list[int]: - """Paged-attention batched decode. - - Uses MLX for projections + per-request RoPE, then the HF kernel for - reshape_and_cache + paged_attention_v1 (zero-copy from block tables). - """ - - batch_size = len(decode_reqs) - - # Build request info for prepare_decode - requests_info: list[tuple[list[int], int]] = [] - for req_id, state in decode_reqs: - seq_len = self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 1) - requests_info.append((state.block_ids, seq_len)) - - # Stash per-request metadata (slot_mapping, block_tables, context_lens, - # offsets) in thread-local for the attention wrappers. - prepare_decode(requests_info, self._paged_block_size) - - # OffsetCache is a fake cache — no KV stored. The offset value - # only matters for make_mask(); for single-token decode make_mask(1) - # returns None regardless, so a shared max_offset is fine. Actual - # per-request RoPE offsets come from ctx.offsets in the wrapper. - max_offset = max(info[1] for info in requests_info) - offset_caches = [OffsetCache(max_offset) for _ in range(self.num_layers)] - - # Build batched input - if self._rust_state_manager is not None: - last_tokens = self._rust_state_manager.get_last_tokens_batch( - [req_id for req_id, _ in decode_reqs] - ) - else: - last_tokens = [ - state.token_ids[-1] if state.token_ids else 0 - for _, state in decode_reqs - ] - - batched_input = mx.array(last_tokens, dtype=mx.int32)[:, None] + prefill_next_tokens.append(next_token) - # The model forward calls each layer's self_attn, which has been - # replaced by MetalKernelPagedAttentionWrapper. The wrapper: - # - ignores cache= (OffsetCache) for KV storage - # - reads get_context() for block_tables, slot_mapping, offsets - # - applies per-request RoPE using ctx.offsets - # - writes new K/V to MPS paged cache via reshape_and_cache - # - reads all cached K/V via paged_attention_v1 (zero-copy) - try: - model_output = self.model(batched_input, cache=offset_caches) - logits = self._extract_logits(model_output) - next_token_logits = logits[:, -1, :] - finally: - clear_context() - - # Sample - sampling_params_list = [state.sampling_params for _, state in decode_reqs] - all_greedy = all(sp.temperature < 1e-5 for sp in sampling_params_list) - any_advanced = any( - sp.top_k > 0 - or sp.top_p < 1.0 - or sp.frequency_penalty != 0 - or sp.presence_penalty != 0 - or sp.repetition_penalty != 1.0 - for sp in sampling_params_list - ) - - if all_greedy and not any_advanced: - next_tokens_mlx = _mlx_greedy_sample(next_token_logits) - mx.eval(next_tokens_mlx) - next_tokens: list[int] = next_tokens_mlx.tolist() - else: - mx.eval(next_token_logits) - prompt_token_ids_list = [ - state.token_ids[: state.prompt_len] for _, state in decode_reqs - ] - output_tokens_list = [ - state.token_ids[state.prompt_len :] for _, state in decode_reqs - ] - generators = { - i: state.generator - for i, (_, state) in enumerate(decode_reqs) - if state.generator is not None - } - logits_torch = mlx_to_torch( - next_token_logits.astype(mx.float32), device=self.device - ) - metadata = self._make_sampling_metadata( - sampling_params_list, - prompt_token_ids_list, - output_tokens_list, - generators=generators, - ) - output = self._sampler.forward(logits_torch, metadata) - next_tokens = [ - int(output.sampled_token_ids[i, 0].item()) for i in range(batch_size) - ] - - # Update state - for i, (req_id, state) in enumerate(decode_reqs): - state.token_ids.append(next_tokens[i]) - state.generated_tokens += 1 - self._paged_request_seq_lens[req_id] = ( - self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 2) + 1 - ) - if self._rust_state_manager is not None: - self._rust_state_manager.append_token(req_id, next_tokens[i]) - - return next_tokens + return prefill_next_tokens, decode_next_tokens def execute_model( self, scheduler_output: SchedulerOutput @@ -1713,17 +1637,31 @@ def execute_model( req_id_to_index: dict[str, int] = {} sampled_tokens: list[list[int]] = [] - # === PHASE 1: Process new requests (prefill phase) === + # === Collect all requests into one unified batch === new_reqs = scheduler_output.scheduled_new_reqs + cached_reqs = scheduler_output.scheduled_cached_reqs - # First pass: handle intermediate chunks immediately, collect - # complete paged prefill requests for potential packing. - paged_complete: list[ + # Paged-attention entries collected for the single unified forward. + # Each prefill entry: (output_idx, req_id, token_ids, sampling_params, + # block_ids, generator, is_new, is_intermediate, + # prompt_len, start_pos) + paged_prefill_entries: list[ tuple[ - int, str, list[int], SamplingParams, list[int], torch.Generator | None + int, + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + bool, + bool, + int, + int, ] ] = [] + paged_decode_reqs: list[tuple[str, RequestState]] = [] + # --- New requests --- for new_req in new_reqs: req_id = new_req.req_id token_ids = new_req.prompt_token_ids or [] @@ -1734,40 +1672,42 @@ def execute_model( req_id_to_index[req_id] = output_idx if not token_ids: - sampled_tokens.append([0]) # Fallback + sampled_tokens.append([0]) continue generator = _create_request_generator(self.device, sampling_params) if self._paged_kv_cache is not None: - # Paged attention path (Metal kernel) sched_block_ids = list(new_req.block_ids[0]) scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) computed_tokens = new_req.num_computed_tokens prompt_len = len(token_ids) + cur_len = computed_tokens + scheduled_tokens + is_intermediate = cur_len < prompt_len - if computed_tokens + scheduled_tokens < prompt_len: - # Intermediate chunk: sample then drop (async scheduler - # allocates no placeholder for intermediate chunks). - cur_len = computed_tokens + scheduled_tokens - _discarded = self._prefill_packed_paged( - [ - ( - req_id, - token_ids[:cur_len], - sampling_params, - sched_block_ids, - generator, - None, - ), - ] - )[0] - cache: list = [] - sampled_tokens.append([]) + sampled_tokens.append([]) # placeholder + paged_prefill_entries.append( + ( + output_idx, + req_id, + token_ids[computed_tokens:cur_len], + sampling_params, + sched_block_ids, + generator, + True, # is_new + is_intermediate, + prompt_len, + computed_tokens, # start_pos / RoPE offset + ) + ) + + # Create state immediately for intermediate chunks + # (needed if the request appears as cached next step). + if is_intermediate: self._request_states[req_id] = RequestState( token_ids=list(token_ids), prompt_len=prompt_len, - cache=cache, + cache=[], sampling_params=sampling_params, generator=generator, generated_tokens=0, @@ -1777,20 +1717,6 @@ def execute_model( self._rust_state_manager.add_request( req_id, list(token_ids[:cur_len]) ) - continue - - # Complete prefill — collect for packed processing - sampled_tokens.append([]) # placeholder, filled below - paged_complete.append( - ( - output_idx, - req_id, - token_ids, - sampling_params, - sched_block_ids, - generator, - ) - ) else: next_token, cache = self._prefill_single( req_id, @@ -1813,23 +1739,14 @@ def execute_model( req_id, list(token_ids) + [next_token] ) - # Process collected complete paged prefill requests via unified - # packed path (handles 1 or more requests). - if paged_complete: - self._run_packed_prefill(paged_complete, sampled_tokens) - - # === PHASE 2: Process cached requests (TRUE batched decode) === - cached_reqs = scheduler_output.scheduled_cached_reqs + # --- Cached requests --- decode_req_ids = list(cached_reqs.req_ids) if decode_req_ids: if self._paged_kv_cache is not None: - # Paged attention path: unified flow using model-runner-local - # state (state.generated_tokens) instead of is_context_phase(). req_id_to_cached_idx = { rid: i for i, rid in enumerate(cached_reqs.req_ids) } - paged_decode_reqs: list[tuple[str, RequestState]] = [] # Update block_ids from scheduler (append or replace on resume) for i, req_id in enumerate(cached_reqs.req_ids): @@ -1843,13 +1760,6 @@ def execute_model( if new_block_ids is not None: state.block_ids.extend(new_block_ids[0]) else: - # Preempted → full recompute with fresh blocks. - # Keep prompt_len at the original prompt boundary - # (used for sampling penalty split). The prefill - # loop uses len(state.token_ids) — which already - # includes previously generated output tokens — - # to determine the recompute scope, matching - # upstream vLLM's use of request.num_tokens. assert new_block_ids is not None state.block_ids = list(new_block_ids[0]) state.generated_tokens = 0 @@ -1860,17 +1770,17 @@ def execute_model( req_id, list(state.token_ids) ) + # Categorise each cached request for req_id in decode_req_ids: state = self._request_states.get(req_id) if state is None: - # Edge case: no state — emit dummy token req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([0]) continue if state.generated_tokens == 0: - # Still prefilling prompt (or re-prefilling after preemption) + # Still prefilling (or re-prefilling after preemption) idx = req_id_to_cached_idx.get(req_id) if idx is not None and idx < len( cached_reqs.num_computed_tokens @@ -1879,63 +1789,32 @@ def execute_model( else: computed = self._paged_request_seq_lens.get(req_id, 0) scheduled = scheduler_output.num_scheduled_tokens.get(req_id, 0) - target_len = computed + scheduled # FIX: was just `computed` - - if target_len < len(state.token_ids): - # Intermediate chunk: sample then drop - _discarded = self._prefill_packed_paged( - [ - ( - req_id, - state.token_ids[:target_len], - state.sampling_params, - state.block_ids, - state.generator, - None, - ), - ] - )[0] - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([]) - else: - # Last chunk: sample and keep (drains async placeholder) - next_token = self._prefill_packed_paged( - [ - ( - req_id, - state.token_ids, - state.sampling_params, - state.block_ids, - state.generator, - state.prompt_len, - ), - ] - )[0] - state.token_ids.append(next_token) - state.generated_tokens = ( - len(state.token_ids) - state.prompt_len + target_len = computed + scheduled + is_intermediate = target_len < len(state.token_ids) + + req_ids.append(req_id) + output_idx = len(req_ids) - 1 + req_id_to_index[req_id] = output_idx + sampled_tokens.append([]) # placeholder + + paged_prefill_entries.append( + ( + output_idx, + req_id, + state.token_ids[computed:target_len], + state.sampling_params, + state.block_ids, + state.generator, + False, # is_new + is_intermediate, + state.prompt_len, + computed, # start_pos / RoPE offset ) - if self._rust_state_manager is not None: - self._rust_state_manager.append_token( - req_id, next_token - ) - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([next_token]) + ) else: - # Decode phase: collect for batched decode paged_decode_reqs.append((req_id, state)) - - # Batch decode all generation-phase requests - if paged_decode_reqs: - decode_tokens = self._batched_decode_paged(paged_decode_reqs) - for i, (req_id, _) in enumerate(paged_decode_reqs): - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([decode_tokens[i]]) else: - # Collect all valid decode requests + # Non-paged (MLX) path — unchanged valid_decode_reqs = [] for req_id in decode_req_ids: state = self._request_states.get(req_id) @@ -1948,19 +1827,85 @@ def execute_model( else: decode_tokens = self._sequential_decode(valid_decode_reqs) - # Add decode results to output for i, (req_id, _) in enumerate(valid_decode_reqs): req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([decode_tokens[i]]) - # Handle requests with no cached state (edge case) for req_id in decode_req_ids: if req_id not in req_id_to_index: req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([0]) + # === Single unified forward pass (paged path) === + if paged_prefill_entries or paged_decode_reqs: + prefill_pack = [ + ( + rid, + tids, + sp, + bids, + gen, + prompt_len if not is_intermediate else None, + start_pos, + ) + for _, rid, tids, sp, bids, gen, _is_new, is_intermediate, prompt_len, start_pos in paged_prefill_entries + ] + prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( + prefill_pack, paged_decode_reqs + ) + + # Post-process prefill results + for i, ( + idx, + rid, + tids, + sp, + bids, + gen, + is_new, + is_intermediate, + _prompt_len, + _start_pos, + ) in enumerate(paged_prefill_entries): + nt = prefill_tokens[i] + + if is_intermediate: + # KV cache populated; discard sampled token + sampled_tokens[idx] = [] + elif is_new: + assert _start_pos == 0, ( + "new complete prefill with start_pos > 0 not supported " + "(prefix caching not yet implemented in unified path)" + ) + sampled_tokens[idx] = [nt] + self._request_states[rid] = RequestState( + token_ids=list(tids) + [nt], + prompt_len=len(tids), + cache=[], + sampling_params=sp, + generator=gen, + generated_tokens=1, + block_ids=bids, + ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request(rid, list(tids) + [nt]) + else: + # Cached last chunk — append token to existing state + sampled_tokens[idx] = [nt] + state = self._request_states[rid] + state.token_ids.append(nt) + state.generated_tokens = len(state.token_ids) - state.prompt_len + if self._rust_state_manager is not None: + self._rust_state_manager.append_token(rid, nt) + + # Post-process decode results + for i, (req_id, _) in enumerate(paged_decode_reqs): + req_ids.append(req_id) + req_id_to_index[req_id] = len(req_ids) - 1 + sampled_tokens.append([decode_tokens[i]]) + # Consistency check: every scheduled request must be represented in # req_ids, and decode-phase scheduled requests should not emit empty # token lists. Missing/empty outputs here can leave placeholders stale.