From 5c46844ef21db94a160bec0793723b49031c3eb8 Mon Sep 17 00:00:00 2001 From: ran Date: Wed, 18 Mar 2026 01:58:30 -0500 Subject: [PATCH 01/12] unified prefilling & decoding prototype Signed-off-by: ran --- tests/test_unified_batching.py | 105 +++++++ .../packed_prefill_compat.py | 18 +- .../metal_kernel_backend/paged_attention.py | 12 +- vllm_metal/paged_attention_common.py | 58 ++++ vllm_metal/v1/model_runner.py | 269 +++++++++++++++++- 5 files changed, 441 insertions(+), 21 deletions(-) create mode 100644 tests/test_unified_batching.py diff --git a/tests/test_unified_batching.py b/tests/test_unified_batching.py new file mode 100644 index 00000000..44610290 --- /dev/null +++ b/tests/test_unified_batching.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Smoke test for unified prefill+decode forward pass (continuous batching). + +Runs vLLM offline inference with max_num_seqs > 1 so the scheduler batches +multiple requests together, triggering the unified forward pass where prefill +and decode happen in a single model call. + +Due to floating-point non-determinism when batching on Metal (MLX GEMM uses +different internal kernels for different batch sizes), exact golden-token +matching is NOT expected. Instead, this test: + 1. Verifies all requests complete without errors. + 2. Prints the generated text for manual inspection (not gibberish). + 3. Optionally checks whether outputs still match the single-request golden. + +Run: + python -m pytest tests/test_unified_batching.py -v -s +""" + +from __future__ import annotations + +import os + +import pytest +from vllm import LLM, SamplingParams + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_TOKENS = 10 +MAX_NUM_SEQS = 4 # key: allow concurrent requests + +PROMPTS = [ + "The capital of France is", + "The weather today is not", + "One plus one equals", + "The largest planet in our solar system is", + "Water boils at a temperature of", + "Machine learning is", +] + +# fmt: off +# Golden from max_num_seqs=1 (single-request, deterministic). +# Used only for informational comparison — NOT asserted. +GOLDEN_SINGLE = { + "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722], + "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], + "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], + "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], + "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], + "Machine learning is": [264, 7988, 5392, 429, 702, 13791, 1506, 279, 2070, 315], +} +# fmt: on + + +@pytest.fixture(autouse=True, scope="module") +def _set_env(): + with pytest.MonkeyPatch.context() as mp: + mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") + mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") + yield + + +@pytest.fixture(scope="module") +def vllm_outputs(): + """Run vLLM offline inference with concurrent batching.""" + llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=MAX_NUM_SEQS) + + # Verify paged KV + attention wrapper are active + runner = llm.llm_engine.model_executor.driver_worker.model_runner + assert runner._paged_kv_cache is not None, "Paged KV cache not initialised" + + from vllm_metal.metal_kernel_backend.paged_attention import ( + MetalKernelPagedAttentionWrapper, + ) + + attn = runner.model.model.layers[0].self_attn + assert isinstance(attn, MetalKernelPagedAttentionWrapper) + + sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate(PROMPTS, sp) + return {o.prompt: o for o in outputs} + + +class TestUnifiedBatching: + @pytest.mark.slow + @pytest.mark.parametrize("prompt", PROMPTS) + def test_generate_coherent(self, vllm_outputs, prompt): + """Verify output is non-empty and print for manual inspection.""" + output = vllm_outputs[prompt] + token_ids = list(output.outputs[0].token_ids) + text = output.outputs[0].text + + golden = GOLDEN_SINGLE.get(prompt, []) + match = token_ids == golden + + print(f"\n prompt: {prompt!r}") + print(f" output: {text!r}") + print(f" ids: {token_ids}") + print(f" golden: {golden}") + print(f" match: {'YES' if match else 'no (expected with batching)'}") + + # Basic sanity: output should not be empty + assert len(token_ids) == MAX_TOKENS, ( + f"Expected {MAX_TOKENS} tokens, got {len(token_ids)}" + ) + assert len(text.strip()) > 0, "Generated text is empty" diff --git a/vllm_metal/metal_kernel_backend/packed_prefill_compat.py b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py index 72ec5d47..b41ac0c4 100644 --- a/vllm_metal/metal_kernel_backend/packed_prefill_compat.py +++ b/vllm_metal/metal_kernel_backend/packed_prefill_compat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SCAFFOLDING: remove when varlen kernel handles position encoding natively. -# -# Per-request RoPE helper for packed prefill. +# Per-request RoPE helper for packed / unified forward passes. from __future__ import annotations @@ -13,16 +11,22 @@ def apply_packed_rope( queries: mx.array, keys: mx.array, cu_seqlens: list[int], + offsets: list[int] | None = None, ) -> tuple[mx.array, mx.array]: - """Apply per-request RoPE with position reset for packed prefill. + """Apply per-request RoPE for packed sequences. - SCAFFOLDING: remove when varlen kernel is ready. + Each segment delimited by ``cu_seqlens`` gets its own RoPE application + starting at the corresponding offset. When *offsets* is ``None`` every + segment starts at position 0 (pure prefill). For unified prefill+decode + batches, decode segments carry ``offset=seq_len`` while prefill segments + keep ``offset=0``. """ q_parts = [] k_parts = [] for i in range(len(cu_seqlens) - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] - q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=0)) - k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=0)) + off = offsets[i] if offsets is not None else 0 + q_parts.append(attn_module.rope(queries[:, :, start:end, :], offset=off)) + k_parts.append(attn_module.rope(keys[:, :, start:end, :], offset=off)) return mx.concatenate(q_parts, axis=2), mx.concatenate(k_parts, axis=2) diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 79813065..3f8380a1 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -107,9 +107,15 @@ def _metal_kernel_prefill_attention( "attribute. Only RoPE-based models are supported by paged attention." ) - # NOTE: apply_packed_rope always uses offset=0 per request. Chunked - # prefill will need per-request offsets (like decode) for continuation chunks. - queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens) + # Per-segment RoPE: offset=0 for fresh prefill, offset=seq_len for decode + # tokens in a unified batch (ctx.offsets populated by prepare_unified). + queries, keys = apply_packed_rope( + attn_module, + queries, + keys, + ctx.cu_seqlens, + offsets=ctx.offsets if ctx.offsets else None, + ) # Reshape to 3D: (1, heads, L, hd) → (L, heads, hd) q_3d = mx.contiguous(queries[0].transpose(1, 0, 2).astype(cache.dtype)) diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index be148c61..f80952f3 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -188,6 +188,64 @@ def prepare_prefill_packed( ) +def prepare_unified( + decode_requests: list[tuple[list[int], int]], + prefill_requests: list[tuple[list[int], int]], + block_size: int, +) -> None: + """Compute metadata for a unified prefill + decode forward pass. + + Packs decode tokens (1 per request) followed by prefill tokens into a + single flattened sequence. ``cu_seqlens`` marks request boundaries so + the varlen kernel handles both decode (length-1) and prefill (length-N) + subsequences in one dispatch. + + Args: + decode_requests: list of ``(block_ids, seq_len)`` for decode requests. + ``seq_len`` = tokens already cached before this step. + prefill_requests: list of ``(block_ids, num_tokens)`` for prefill. + block_size: tokens per KV cache block. + """ + slot_mapping: list[int] = [] + cu_seqlens: list[int] = [0] + block_tables: list[list[int]] = [] + context_lens: list[int] = [] + offsets: list[int] = [] + + # Decode requests first (1 token each) + for block_ids, seq_len in decode_requests: + new_pos = seq_len + block_idx = block_ids[new_pos // block_size] + slot = block_idx * block_size + (new_pos % block_size) + slot_mapping.append(slot) + cu_seqlens.append(cu_seqlens[-1] + 1) + block_tables.append(block_ids) + context_lens.append(seq_len + 1) # including new token + offsets.append(seq_len) # RoPE position + + # Prefill requests (variable tokens each) + for block_ids, num_tokens in prefill_requests: + for pos in range(num_tokens): + block_idx = block_ids[pos // block_size] + slot = block_idx * block_size + (pos % block_size) + slot_mapping.append(slot) + cu_seqlens.append(cu_seqlens[-1] + num_tokens) + block_tables.append(block_ids) + context_lens.append(num_tokens) + offsets.append(0) # prefill starts at position 0 + + set_context( + PagedAttentionContext( + is_prefill=True, # use varlen code path + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + cu_seqlens=cu_seqlens, + offsets=offsets, + ) + ) + + def prepare_decode( requests: list[tuple[list[int], int]], block_size: int, diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 350e3abe..ba86b355 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -54,6 +54,7 @@ clear_context, prepare_decode, prepare_prefill_packed, + prepare_unified, ) from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch from vllm_metal.stt.config import ( @@ -1745,6 +1746,198 @@ def _run_packed_prefill( if self._rust_state_manager is not None: self._rust_state_manager.add_request(rid, list(tids) + [nt]) + # ------------------------------------------------------------------ + # Unified prefill + decode (single forward pass) + # ------------------------------------------------------------------ + + def _unified_prefill_decode_paged( + self, + prefill_reqs: list[ + tuple[ + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + int | None, + ] + ], + decode_reqs: list[tuple[str, RequestState]], + ) -> tuple[list[int], list[int]]: + """Single forward pass for mixed prefill + decode requests. + + Packs decode tokens (1 per request) followed by prefill tokens into + a flat ``(1, total_tokens)`` input. The varlen kernel uses + ``cu_seqlens`` to handle variable-length subsequences. + + Args: + prefill_reqs: list of + ``(req_id, token_ids, sampling_params, block_ids, + generator, prompt_len)`` — complete prefill requests. + decode_reqs: list of ``(req_id, RequestState)`` — decode requests. + + Returns: + ``(prefill_next_tokens, decode_next_tokens)`` + """ + num_decode = len(decode_reqs) + num_prefill = len(prefill_reqs) + + # ---- build unified token sequence: decode first, then prefill ---- + all_token_ids: list[int] = [] + + # Decode: last token per request + if self._rust_state_manager is not None: + last_tokens = self._rust_state_manager.get_last_tokens_batch( + [rid for rid, _ in decode_reqs] + ) + else: + last_tokens = [ + state.token_ids[-1] if state.token_ids else 0 + for _, state in decode_reqs + ] + all_token_ids.extend(last_tokens) + + # Prefill: all tokens per request + for _, token_ids, _, _, _, _ in prefill_reqs: + all_token_ids.extend(token_ids) + + # ---- build metadata for prepare_unified ---- + decode_info: list[tuple[list[int], int]] = [] + for req_id, state in decode_reqs: + seq_len = self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 1) + decode_info.append((state.block_ids, seq_len)) + + prefill_info: list[tuple[list[int], int]] = [] + for _, token_ids, _, block_ids, _, _ in prefill_reqs: + prefill_info.append((block_ids, len(token_ids))) + + prepare_unified(decode_info, prefill_info, self._paged_block_size) + + # ---- forward ---- + offset_caches = [OffsetCache(0) for _ in range(self.num_layers)] + input_ids = mx.array([all_token_ids], dtype=mx.int32) + try: + model_output = self.model(input_ids, cache=offset_caches) + logits = self._extract_logits(model_output) + finally: + clear_context() + + # ---- build cu_seqlens for logit extraction ---- + cu_seqlens: list[int] = [0] + for _ in decode_reqs: + cu_seqlens.append(cu_seqlens[-1] + 1) + for _, token_ids, _, _, _, _ in prefill_reqs: + cu_seqlens.append(cu_seqlens[-1] + len(token_ids)) + + # ---- sample decode tokens ---- + decode_next_tokens: list[int] = [] + if decode_reqs: + # All decode logits are at positions 0..num_decode-1 + decode_logits = logits[0, :num_decode, :] # (num_decode, vocab) + + sampling_params_list = [state.sampling_params for _, state in decode_reqs] + all_greedy = all(sp.temperature < 1e-5 for sp in sampling_params_list) + any_advanced = any( + sp.top_k > 0 + or sp.top_p < 1.0 + or sp.frequency_penalty != 0 + or sp.presence_penalty != 0 + or sp.repetition_penalty != 1.0 + for sp in sampling_params_list + ) + + if all_greedy and not any_advanced: + next_tokens_mlx = _mlx_greedy_sample(decode_logits) + mx.eval(next_tokens_mlx) + decode_next_tokens = next_tokens_mlx.tolist() + else: + mx.eval(decode_logits) + prompt_token_ids_list = [ + state.token_ids[: state.prompt_len] for _, state in decode_reqs + ] + output_tokens_list = [ + state.token_ids[state.prompt_len :] for _, state in decode_reqs + ] + generators = { + i: state.generator + for i, (_, state) in enumerate(decode_reqs) + if state.generator is not None + } + logits_torch = mlx_to_torch( + decode_logits.astype(mx.float32), device=self.device + ) + metadata = self._make_sampling_metadata( + sampling_params_list, + prompt_token_ids_list, + output_tokens_list, + generators=generators, + ) + output = self._sampler.forward(logits_torch, metadata) + decode_next_tokens = [ + int(output.sampled_token_ids[i, 0].item()) + for i in range(num_decode) + ] + + # Update decode state + for i, (req_id, state) in enumerate(decode_reqs): + state.token_ids.append(decode_next_tokens[i]) + state.generated_tokens += 1 + self._paged_request_seq_lens[req_id] = ( + self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 2) + + 1 + ) + if self._rust_state_manager is not None: + self._rust_state_manager.append_token(req_id, decode_next_tokens[i]) + + # ---- sample prefill tokens ---- + prefill_next_tokens: list[int] = [] + for j, ( + req_id, + token_ids, + sampling_params, + _block_ids, + generator, + prompt_len, + ) in enumerate(prefill_reqs): + last_idx = cu_seqlens[num_decode + j + 1] - 1 + last_logits = logits[:, last_idx : last_idx + 1, :] + + if prompt_len is None: + prompt_len = len(token_ids) + + is_greedy = sampling_params.temperature < 1e-5 + needs_advanced = ( + sampling_params.top_k > 0 + or sampling_params.top_p < 1.0 + or sampling_params.frequency_penalty != 0 + or sampling_params.presence_penalty != 0 + or sampling_params.repetition_penalty != 1.0 + ) + + if is_greedy and not needs_advanced: + next_token_mlx = _mlx_greedy_sample(last_logits[0]) + mx.eval(next_token_mlx) + next_token = int(next_token_mlx.item()) + else: + mx.eval(last_logits) + logits_torch = mlx_to_torch( + last_logits[0].astype(mx.float32), device=self.device + ) + generators = {} if generator is None else {0: generator} + metadata = self._make_sampling_metadata( + [sampling_params], + [token_ids[:prompt_len]], + [token_ids[prompt_len:]], + generators=generators, + ) + output = self._sampler.forward(logits_torch, metadata) + next_token = int(output.sampled_token_ids[0, 0].item()) + + self._paged_request_seq_lens[req_id] = len(token_ids) + prefill_next_tokens.append(next_token) + + return prefill_next_tokens, decode_next_tokens + def _batched_decode_paged( self, decode_reqs: list[tuple[str, RequestState]] ) -> list[int]: @@ -1980,10 +2173,8 @@ def execute_model( req_id, list(token_ids) + [next_token] ) - # Process collected complete paged prefill requests via unified - # packed path (handles 1 or more requests). - if paged_complete: - self._run_packed_prefill(paged_complete, sampled_tokens) + # Defer paged_complete to unified forward pass during PHASE 2. + # If no cached requests exist, it is consumed after PHASE 2. # === PHASE 2: Process cached requests (TRUE batched decode) === cached_reqs = scheduler_output.scheduled_cached_reqs @@ -2094,13 +2285,64 @@ def execute_model( # Decode phase: collect for batched decode paged_decode_reqs.append((req_id, state)) - # Batch decode all generation-phase requests - if paged_decode_reqs: - decode_tokens = self._batched_decode_paged(paged_decode_reqs) - for i, (req_id, _) in enumerate(paged_decode_reqs): - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([decode_tokens[i]]) + # === Unified forward: complete prefills + decode === + if paged_decode_reqs or paged_complete: + total_prefill_tokens = sum( + len(entry[2]) for entry in paged_complete + ) + total_tokens = total_prefill_tokens + len(paged_decode_reqs) + + if total_tokens <= MAX_PACKED_PREFILL_TOKENS: + # Build pack input for prefill requests + prefill_pack = [ + (rid, tids, sp, bids, gen, None) + for _, rid, tids, sp, bids, gen in paged_complete + ] + prefill_tokens, decode_tokens = ( + self._unified_prefill_decode_paged( + prefill_pack, paged_decode_reqs + ) + ) + + # Record prefill results + create state + for i, (idx, rid, tids, sp, bids, gen) in enumerate( + paged_complete + ): + nt = prefill_tokens[i] + sampled_tokens[idx] = [nt] + self._request_states[rid] = RequestState( + token_ids=list(tids) + [nt], + prompt_len=len(tids), + cache=[], + sampling_params=sp, + generator=gen, + generated_tokens=1, + block_ids=bids, + ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request( + rid, list(tids) + [nt] + ) + paged_complete = [] # consumed + + # Record decode results + for i, (req_id, _) in enumerate(paged_decode_reqs): + req_ids.append(req_id) + req_id_to_index[req_id] = len(req_ids) - 1 + sampled_tokens.append([decode_tokens[i]]) + else: + # Fallback: separate passes for large batches + if paged_complete: + self._run_packed_prefill(paged_complete, sampled_tokens) + paged_complete = [] # consumed + if paged_decode_reqs: + decode_tokens = self._batched_decode_paged( + paged_decode_reqs + ) + for i, (req_id, _) in enumerate(paged_decode_reqs): + req_ids.append(req_id) + req_id_to_index[req_id] = len(req_ids) - 1 + sampled_tokens.append([decode_tokens[i]]) else: # Collect all valid decode requests valid_decode_reqs = [] @@ -2128,6 +2370,11 @@ def execute_model( req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([0]) + # Handle any paged_complete not consumed by the unified pass + # (happens when no cached requests exist in this step). + if paged_complete: + self._run_packed_prefill(paged_complete, sampled_tokens) + # Consistency check: every scheduled request must be represented in # req_ids, and decode-phase scheduled requests should not emit empty # token lists. Missing/empty outputs here can leave placeholders stale. From 245c7dc934e4105d5ac3b036836bf3a9e6f67813 Mon Sep 17 00:00:00 2001 From: ran Date: Wed, 18 Mar 2026 02:12:31 -0500 Subject: [PATCH 02/12] remove easy dead code Signed-off-by: ran --- tests/test_paged_attention.py | 34 +++- .../metal_kernel_backend/paged_attention.py | 118 +---------- vllm_metal/paged_attention_common.py | 58 ++---- vllm_metal/v1/model_runner.py | 189 +++--------------- 4 files changed, 69 insertions(+), 330 deletions(-) diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index bcd70f95..2dee33c9 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -11,8 +11,8 @@ OffsetCache, clear_context, get_context, - prepare_decode, prepare_prefill_packed, + prepare_unified, ) @@ -75,20 +75,36 @@ def test_prepare_prefill_packed_single_request(self): assert ctx.block_tables == [[5, 6]] assert ctx.context_lens == [5] - def test_prepare_decode(self): - # Arrange - requests = [([5, 6], 7)] - - # Act - prepare_decode(requests, block_size=4) + def test_prepare_unified_decode_only(self): + # Single decode request via prepare_unified + decode_requests = [([5, 6], 7)] + prepare_unified(decode_requests, [], block_size=4) ctx = get_context() - # Assert — new_pos=7, block_ids[7//4]=block_ids[1]=6, slot=6*4+(7%4)=27 + # new_pos=7, block_ids[7//4]=block_ids[1]=6, slot=6*4+(7%4)=27 assert ctx is not None - assert not ctx.is_prefill + assert ctx.is_prefill # unified always sets True assert ctx.slot_mapping == [27] assert ctx.context_lens == [8] assert ctx.offsets == [7] + assert ctx.cu_seqlens == [0, 1] + + def test_prepare_unified_mixed(self): + # 1 decode + 1 prefill + decode_requests = [([5, 6], 7)] # seq_len=7 + prefill_requests = [([10, 11], 5)] # 5 tokens + + prepare_unified(decode_requests, prefill_requests, block_size=4) + ctx = get_context() + + assert ctx is not None + # Decode slot: pos=7, block 6, slot=6*4+3=27 + # Prefill slots: block 10 slots 40,41,42,43; block 11 slot 44 + assert ctx.slot_mapping == [27, 40, 41, 42, 43, 44] + assert ctx.cu_seqlens == [0, 1, 6] + assert ctx.offsets == [7, 0] + assert ctx.context_lens == [8, 5] + assert ctx.block_tables == [[5, 6], [10, 11]] class TestPackedRoPE: diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index 3f8380a1..c59ce563 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -10,8 +10,8 @@ All operations use MLX arrays end-to-end — no PyTorch MPS bridge. -Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_prefill_packed``, -``prepare_decode``, ``clear_context`` from ``paged_attention_common``. +Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_unified``, +``prepare_prefill_packed``, ``clear_context`` from ``paged_attention_common``. Backend replacement guide ------------------------- @@ -174,109 +174,6 @@ def _metal_kernel_prefill_attention( return attn_module.o_proj(out) -# --------------------------------------------------------------------------- -# Decode attention (reshape_and_cache + paged_attention_v1) -# --------------------------------------------------------------------------- - - -def _metal_kernel_decode_attention( - attn_module: Any, - queries: mx.array, - keys: mx.array, - values: mx.array, - cache: MetalPagedKVCache, - layer_idx: int, - ctx: PagedAttentionContext, -) -> mx.array: - """Batched decode: B=batch_size, L=1. - - Per-request RoPE, write new token via ``reshape_and_cache``, - then zero-copy attention via ``paged_attention_v1``. - """ - B = queries.shape[0] # noqa: N806 - n_heads = queries.shape[1] - head_dim = queries.shape[3] - - # Per-request RoPE - if not hasattr(attn_module, "rope"): - raise NotImplementedError( - f"Attention module {type(attn_module).__name__} does not have a 'rope' " - "attribute. Only RoPE-based models are supported by paged attention." - ) - q_parts = [] - k_parts = [] - for i in range(B): - q_parts.append(attn_module.rope(queries[i : i + 1], offset=ctx.offsets[i])) - k_parts.append(attn_module.rope(keys[i : i + 1], offset=ctx.offsets[i])) - queries = mx.concatenate(q_parts, axis=0) # (B, heads, 1, head_dim) - keys_new = mx.concatenate(k_parts, axis=0) # (B, kv_heads, 1, head_dim) - - # Squeeze seq dim: (B, heads, 1, hd) → (B, heads, hd) - q_3d = mx.contiguous(queries[:, :, 0, :].astype(cache.dtype)) - k_3d = mx.contiguous(keys_new[:, :, 0, :].astype(cache.dtype)) - v_3d = mx.contiguous(values[:, :, 0, :].astype(cache.dtype)) - - slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64) - - # Build block_tables and seq_lens - max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables) - block_tables_list = [ - bt + [0] * (max_blocks_per_seq - len(bt)) for bt in ctx.block_tables - ] - block_tables = mx.array(block_tables_list, dtype=mx.int32) - seq_lens = mx.array(ctx.context_lens, dtype=mx.int32) - - # Eval all inputs before kernel dispatch - mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens) - - ops = get_ops() - - # Write new K/V tokens into paged cache - ops.reshape_and_cache( - k_3d, - v_3d, - cache.key_caches[layer_idx], - cache.value_caches[layer_idx], - slot_mapping, - ) - - # Allocate output - out = mx.zeros((B, n_heads, head_dim), dtype=cache.dtype) - mx.eval(out) - - max_seq_len = max(ctx.context_lens) - scale = attn_module.scale - - # Build cu_seqlens_q for varlen dispatch: decode has q_len=1 per sequence. - cu_seqlens_q = mx.arange(B + 1, dtype=mx.int32) - mx.eval(cu_seqlens_q) - - # Zero-copy paged attention (v2, online softmax, varlen-capable) - ops.paged_attention_v2_online( - out, - q_3d, - cache.key_caches[layer_idx], - cache.value_caches[layer_idx], - cache.num_kv_heads, - scale, - 0.0, # softcap (0 = disabled) - block_tables, - seq_lens, - cu_seqlens_q, - cache.block_size, - max_seq_len, - -1, # sliding_window (-1 = disabled) - ) - - # Synchronize GPU: paged_attention_v2_online wrote to out's buffer via a raw - # Metal dispatch that MLX's lazy graph doesn't track. mx.eval(out) would - # be a no-op here (out was already evaluated as zeros), so we must use - # mx.synchronize() to flush the command encoder and wait for the kernel. - mx.synchronize() - out = out.reshape(B, 1, n_heads * head_dim) - return attn_module.o_proj(out) - - # --------------------------------------------------------------------------- # Wrapper nn.Module # --------------------------------------------------------------------------- @@ -335,14 +232,9 @@ def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array keys = keys.transpose(0, 2, 1, 3) values = values.transpose(0, 2, 1, 3) - if ctx.is_prefill: - return _metal_kernel_prefill_attention( - inner, queries, keys, values, kv_cache, layer_idx, ctx - ) - else: - return _metal_kernel_decode_attention( - inner, queries, keys, values, kv_cache, layer_idx, ctx - ) + return _metal_kernel_prefill_attention( + inner, queries, keys, values, kv_cache, layer_idx, ctx + ) # --------------------------------------------------------------------------- diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index f80952f3..994f4224 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -5,9 +5,10 @@ both the Metal kernel paged attention backend and the model runner. Usage: - 1. Before each forward pass call ``prepare_prefill_packed()`` or ``prepare_decode()`` + 1. Before each forward pass call ``prepare_unified()`` (or + ``prepare_prefill_packed()`` for prefill-only intermediate chunks) 2. Run ``model(input_ids, cache=offset_caches)`` as normal - 3. The attention wrapper reads ``get_context()`` to decide prefill vs decode + 3. The attention wrapper reads ``get_context()`` for paged metadata 4. Call ``clear_context()`` after the forward pass """ @@ -34,17 +35,21 @@ @dataclass class PagedAttentionContext: - """Context set before each forward pass, read by patched attention.""" + """Context set before each forward pass, read by patched attention. - is_prefill: bool + All forward passes use the varlen kernel with ``cu_seqlens`` to handle + variable-length subsequences (both prefill and decode tokens packed + into a single flat sequence). + """ + + is_prefill: bool # kept for compatibility; always True slot_mapping: list[int] - # decode-only fields block_tables: list[list[int]] = field(default_factory=list) context_lens: list[int] = field(default_factory=list) + # Per-segment RoPE offsets: 0 for fresh prefill, seq_len for decode. offsets: list[int] = field(default_factory=list) - # packed prefill fields — set when multiple requests are packed into - # a single forward pass. cu_seqlens is a cumulative sequence length - # array: [0, len0, len0+len1, ...] (length = num_requests + 1). + # Cumulative sequence length array: [0, len0, len0+len1, ...] + # (length = num_requests + 1). cu_seqlens: list[int] | None = None @@ -244,40 +249,3 @@ def prepare_unified( offsets=offsets, ) ) - - -def prepare_decode( - requests: list[tuple[list[int], int]], - block_size: int, -) -> None: - """Compute slot_mapping, block_tables, context_lens, offsets for decode. - - Args: - requests: list of (block_ids, seq_len) per request. - seq_len = number of tokens already stored (before this step). - block_size: tokens per block - """ - slot_mapping: list[int] = [] - block_tables: list[list[int]] = [] - context_lens: list[int] = [] - offsets: list[int] = [] - - for block_ids, seq_len in requests: - # Slot for the new token at position seq_len - new_pos = seq_len - block_idx = block_ids[new_pos // block_size] - slot = block_idx * block_size + (new_pos % block_size) - slot_mapping.append(slot) - block_tables.append(block_ids) - context_lens.append(seq_len + 1) # including new token - offsets.append(seq_len) # RoPE position = seq_len - - set_context( - PagedAttentionContext( - is_prefill=False, - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - offsets=offsets, - ) - ) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index ba86b355..d365d458 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -52,7 +52,6 @@ from vllm_metal.paged_attention_common import ( OffsetCache, clear_context, - prepare_decode, prepare_prefill_packed, prepare_unified, ) @@ -1938,116 +1937,6 @@ def _unified_prefill_decode_paged( return prefill_next_tokens, decode_next_tokens - def _batched_decode_paged( - self, decode_reqs: list[tuple[str, RequestState]] - ) -> list[int]: - """Paged-attention batched decode. - - Uses MLX for projections + per-request RoPE, then the HF kernel for - reshape_and_cache + paged_attention_v1 (zero-copy from block tables). - """ - - batch_size = len(decode_reqs) - - # Build request info for prepare_decode - requests_info: list[tuple[list[int], int]] = [] - for req_id, state in decode_reqs: - seq_len = self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 1) - requests_info.append((state.block_ids, seq_len)) - - # Stash per-request metadata (slot_mapping, block_tables, context_lens, - # offsets) in thread-local for the attention wrappers. - prepare_decode(requests_info, self._paged_block_size) - - # OffsetCache is a fake cache — no KV stored. The offset value - # only matters for make_mask(); for single-token decode make_mask(1) - # returns None regardless, so a shared max_offset is fine. Actual - # per-request RoPE offsets come from ctx.offsets in the wrapper. - max_offset = max(info[1] for info in requests_info) - offset_caches = [OffsetCache(max_offset) for _ in range(self.num_layers)] - - # Build batched input - if self._rust_state_manager is not None: - last_tokens = self._rust_state_manager.get_last_tokens_batch( - [req_id for req_id, _ in decode_reqs] - ) - else: - last_tokens = [ - state.token_ids[-1] if state.token_ids else 0 - for _, state in decode_reqs - ] - - batched_input = mx.array(last_tokens, dtype=mx.int32)[:, None] - - # The model forward calls each layer's self_attn, which has been - # replaced by MetalKernelPagedAttentionWrapper. The wrapper: - # - ignores cache= (OffsetCache) for KV storage - # - reads get_context() for block_tables, slot_mapping, offsets - # - applies per-request RoPE using ctx.offsets - # - writes new K/V to MPS paged cache via reshape_and_cache - # - reads all cached K/V via paged_attention_v1 (zero-copy) - try: - model_output = self.model(batched_input, cache=offset_caches) - logits = self._extract_logits(model_output) - next_token_logits = logits[:, -1, :] - finally: - clear_context() - - # Sample - sampling_params_list = [state.sampling_params for _, state in decode_reqs] - all_greedy = all(sp.temperature < 1e-5 for sp in sampling_params_list) - any_advanced = any( - sp.top_k > 0 - or sp.top_p < 1.0 - or sp.frequency_penalty != 0 - or sp.presence_penalty != 0 - or sp.repetition_penalty != 1.0 - for sp in sampling_params_list - ) - - if all_greedy and not any_advanced: - next_tokens_mlx = _mlx_greedy_sample(next_token_logits) - mx.eval(next_tokens_mlx) - next_tokens: list[int] = next_tokens_mlx.tolist() - else: - mx.eval(next_token_logits) - prompt_token_ids_list = [ - state.token_ids[: state.prompt_len] for _, state in decode_reqs - ] - output_tokens_list = [ - state.token_ids[state.prompt_len :] for _, state in decode_reqs - ] - generators = { - i: state.generator - for i, (_, state) in enumerate(decode_reqs) - if state.generator is not None - } - logits_torch = mlx_to_torch( - next_token_logits.astype(mx.float32), device=self.device - ) - metadata = self._make_sampling_metadata( - sampling_params_list, - prompt_token_ids_list, - output_tokens_list, - generators=generators, - ) - output = self._sampler.forward(logits_torch, metadata) - next_tokens = [ - int(output.sampled_token_ids[i, 0].item()) for i in range(batch_size) - ] - - # Update state - for i, (req_id, state) in enumerate(decode_reqs): - state.token_ids.append(next_tokens[i]) - state.generated_tokens += 1 - self._paged_request_seq_lens[req_id] = ( - self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 2) + 1 - ) - if self._rust_state_manager is not None: - self._rust_state_manager.append_token(req_id, next_tokens[i]) - - return next_tokens - def execute_model( self, scheduler_output: SchedulerOutput ) -> ModelRunnerOutput | None: @@ -2287,62 +2176,36 @@ def execute_model( # === Unified forward: complete prefills + decode === if paged_decode_reqs or paged_complete: - total_prefill_tokens = sum( - len(entry[2]) for entry in paged_complete + prefill_pack = [ + (rid, tids, sp, bids, gen, None) + for _, rid, tids, sp, bids, gen in paged_complete + ] + prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( + prefill_pack, paged_decode_reqs ) - total_tokens = total_prefill_tokens + len(paged_decode_reqs) - if total_tokens <= MAX_PACKED_PREFILL_TOKENS: - # Build pack input for prefill requests - prefill_pack = [ - (rid, tids, sp, bids, gen, None) - for _, rid, tids, sp, bids, gen in paged_complete - ] - prefill_tokens, decode_tokens = ( - self._unified_prefill_decode_paged( - prefill_pack, paged_decode_reqs - ) + # Record prefill results + create state + for i, (idx, rid, tids, sp, bids, gen) in enumerate(paged_complete): + nt = prefill_tokens[i] + sampled_tokens[idx] = [nt] + self._request_states[rid] = RequestState( + token_ids=list(tids) + [nt], + prompt_len=len(tids), + cache=[], + sampling_params=sp, + generator=gen, + generated_tokens=1, + block_ids=bids, ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request(rid, list(tids) + [nt]) + paged_complete = [] # consumed - # Record prefill results + create state - for i, (idx, rid, tids, sp, bids, gen) in enumerate( - paged_complete - ): - nt = prefill_tokens[i] - sampled_tokens[idx] = [nt] - self._request_states[rid] = RequestState( - token_ids=list(tids) + [nt], - prompt_len=len(tids), - cache=[], - sampling_params=sp, - generator=gen, - generated_tokens=1, - block_ids=bids, - ) - if self._rust_state_manager is not None: - self._rust_state_manager.add_request( - rid, list(tids) + [nt] - ) - paged_complete = [] # consumed - - # Record decode results - for i, (req_id, _) in enumerate(paged_decode_reqs): - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([decode_tokens[i]]) - else: - # Fallback: separate passes for large batches - if paged_complete: - self._run_packed_prefill(paged_complete, sampled_tokens) - paged_complete = [] # consumed - if paged_decode_reqs: - decode_tokens = self._batched_decode_paged( - paged_decode_reqs - ) - for i, (req_id, _) in enumerate(paged_decode_reqs): - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([decode_tokens[i]]) + # Record decode results + for i, (req_id, _) in enumerate(paged_decode_reqs): + req_ids.append(req_id) + req_id_to_index[req_id] = len(req_ids) - 1 + sampled_tokens.append([decode_tokens[i]]) else: # Collect all valid decode requests valid_decode_reqs = [] From 909e16f43124d2fa8f705a3386bfa5bed6b6b7cd Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 19 Mar 2026 16:27:58 -0500 Subject: [PATCH 03/12] unified everything, multi-arm post-processing Signed-off-by: ran --- tests/test_paged_attention.py | 23 +- .../metal_kernel_backend/paged_attention.py | 2 +- vllm_metal/paged_attention_common.py | 45 +- vllm_metal/v1/model_runner.py | 418 ++++++------------ 4 files changed, 137 insertions(+), 351 deletions(-) diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index 2dee33c9..b1d3a275 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -11,7 +11,6 @@ OffsetCache, clear_context, get_context, - prepare_prefill_packed, prepare_unified, ) @@ -34,9 +33,9 @@ class TestPrepare: def teardown_method(self): clear_context() - def test_prepare_prefill_single_request(self): - # Single request via prepare_prefill_packed - prepare_prefill_packed([([10, 11], 5)], block_size=4) + def test_prepare_unified_prefill_single(self): + # Single prefill request via prepare_unified + prepare_unified([], [([10, 11], 5)], block_size=4) ctx = get_context() # block 10: slots 40,41,42,43; block 11: slot 44 @@ -46,11 +45,11 @@ def test_prepare_prefill_single_request(self): assert ctx.block_tables == [[10, 11]] assert ctx.context_lens == [5] assert ctx.cu_seqlens == [0, 5] + assert ctx.offsets == [0] - def test_prepare_prefill_packed_slot_mapping(self): - # Two requests: 3 tokens in block 10, 2 tokens in block 20 - requests = [([10], 3), ([20], 2)] - prepare_prefill_packed(requests, block_size=4) + def test_prepare_unified_prefill_packed(self): + # Two prefill requests packed together + prepare_unified([], [([10], 3), ([20], 2)], block_size=4) ctx = get_context() assert ctx is not None @@ -61,11 +60,11 @@ def test_prepare_prefill_packed_slot_mapping(self): assert ctx.cu_seqlens == [0, 3, 5] assert ctx.block_tables == [[10], [20]] assert ctx.context_lens == [3, 2] + assert ctx.offsets == [0, 0] - def test_prepare_prefill_packed_single_request(self): - # Single request through packed path should produce valid metadata - requests = [([5, 6], 5)] - prepare_prefill_packed(requests, block_size=4) + def test_prepare_unified_prefill_multiblock(self): + # Single prefill spanning two blocks + prepare_unified([], [([5, 6], 5)], block_size=4) ctx = get_context() assert ctx is not None diff --git a/vllm_metal/metal_kernel_backend/paged_attention.py b/vllm_metal/metal_kernel_backend/paged_attention.py index c59ce563..936ddfa7 100644 --- a/vllm_metal/metal_kernel_backend/paged_attention.py +++ b/vllm_metal/metal_kernel_backend/paged_attention.py @@ -11,7 +11,7 @@ All operations use MLX arrays end-to-end — no PyTorch MPS bridge. Reuses ``PagedAttentionContext``, ``OffsetCache``, ``prepare_unified``, -``prepare_prefill_packed``, ``clear_context`` from ``paged_attention_common``. +``clear_context`` from ``paged_attention_common``. Backend replacement guide ------------------------- diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index 994f4224..4ce7622f 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -5,8 +5,7 @@ both the Metal kernel paged attention backend and the model runner. Usage: - 1. Before each forward pass call ``prepare_unified()`` (or - ``prepare_prefill_packed()`` for prefill-only intermediate chunks) + 1. Before each forward pass call ``prepare_unified()`` 2. Run ``model(input_ids, cache=offset_caches)`` as normal 3. The attention wrapper reads ``get_context()`` for paged metadata 4. Call ``clear_context()`` after the forward pass @@ -27,7 +26,7 @@ # Thread-local storage used to pass per-request metadata (slot_mapping, # block_tables, etc.) to attention wrappers buried inside the model. # We cannot add extra arguments to the mlx_lm forward signature, so -# instead: prepare_prefill_packed/decode() stashes context here before the +# instead: prepare_unified() stashes context here before the # forward pass, each attention wrapper reads it via get_context(), and # clear_context() cleans up afterwards. _thread_local = threading.local() @@ -153,46 +152,6 @@ def find_layers_and_attr(model: Any) -> tuple[list[Any], str]: # --------------------------------------------------------------------------- -def prepare_prefill_packed( - requests: list[tuple[list[int], int]], - block_size: int, -) -> None: - """Compute slot_mapping, cu_seqlens, block_tables, context_lens for prefill. - - Packs one or more prefill requests into a single forward pass. The - varlen Metal kernel uses ``cu_seqlens`` to locate each sequence's - query tokens and ``block_tables`` / ``context_lens`` to read K/V - from the paged cache. - - Args: - requests: list of (block_ids, num_tokens) per request. - block_size: tokens per block. - """ - slot_mapping: list[int] = [] - cu_seqlens: list[int] = [0] - block_tables: list[list[int]] = [] - context_lens: list[int] = [] - - for block_ids, num_tokens in requests: - for pos in range(num_tokens): - block_idx = block_ids[pos // block_size] - slot = block_idx * block_size + (pos % block_size) - slot_mapping.append(slot) - cu_seqlens.append(cu_seqlens[-1] + num_tokens) - block_tables.append(block_ids) - context_lens.append(num_tokens) - - set_context( - PagedAttentionContext( - is_prefill=True, - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - cu_seqlens=cu_seqlens, - ) - ) - - def prepare_unified( decode_requests: list[tuple[list[int], int]], prefill_requests: list[tuple[list[int], int]], diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index d365d458..632b2f2f 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -52,7 +52,6 @@ from vllm_metal.paged_attention_common import ( OffsetCache, clear_context, - prepare_prefill_packed, prepare_unified, ) from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch @@ -1592,159 +1591,6 @@ def _sequential_decode( return next_tokens # ------------------------------------------------------------------ - # Paged attention paths - # ------------------------------------------------------------------ - - def _prefill_packed_paged( - self, - pack_reqs: list[ - tuple[ - str, - list[int], - SamplingParams, - list[int], - torch.Generator | None, - int | None, - ] - ], - ) -> list[int]: - """Packed paged-attention prefill for multiple requests. - - Concatenates token_ids from all requests into a single forward - pass using ``cu_seqlens`` to build a block-diagonal causal mask. - This avoids the overhead of N separate forward passes. - - Args: - pack_reqs: list of - (req_id, token_ids, sampling_params, block_ids, - generator, prompt_len) tuples. - - Returns: - List of sampled next tokens, one per request. - """ - # Build packed input - all_token_ids: list[int] = [] - block_requests: list[tuple[list[int], int]] = [] - for _, token_ids, _, block_ids, _, _ in pack_reqs: - all_token_ids.extend(token_ids) - block_requests.append((block_ids, len(token_ids))) - - # Stash packed context (slot_mapping + cu_seqlens) - prepare_prefill_packed(block_requests, self._paged_block_size) - - offset_caches = [OffsetCache(0) for _ in range(self.num_layers)] - input_ids = mx.array([all_token_ids], dtype=mx.int32) - try: - model_output = self.model(input_ids, cache=offset_caches) - logits = self._extract_logits(model_output) - finally: - clear_context() - - # Extract per-request last-token logits and sample - cu_seqlens = [0] - for _, token_ids, _, _, _, _ in pack_reqs: - cu_seqlens.append(cu_seqlens[-1] + len(token_ids)) - - next_tokens: list[int] = [] - for i, ( - req_id, - token_ids, - sampling_params, - _, - generator, - prompt_len, - ) in enumerate(pack_reqs): - last_idx = cu_seqlens[i + 1] - 1 - last_logits = logits[:, last_idx : last_idx + 1, :] - - if prompt_len is None: - prompt_len = len(token_ids) - - is_greedy = sampling_params.temperature < 1e-5 - needs_advanced = ( - sampling_params.top_k > 0 - or sampling_params.top_p < 1.0 - or sampling_params.frequency_penalty != 0 - or sampling_params.presence_penalty != 0 - or sampling_params.repetition_penalty != 1.0 - ) - - if is_greedy and not needs_advanced: - next_token_mlx = _mlx_greedy_sample(last_logits[0]) - mx.eval(next_token_mlx) - next_token = int(next_token_mlx.item()) - else: - mx.eval(last_logits) - logits_torch = mlx_to_torch( - last_logits[0].astype(mx.float32), device=self.device - ) - generators = {} if generator is None else {0: generator} - metadata = self._make_sampling_metadata( - [sampling_params], - [token_ids[:prompt_len]], - [token_ids[prompt_len:]], - generators=generators, - ) - output = self._sampler.forward(logits_torch, metadata) - next_token = int(output.sampled_token_ids[0, 0].item()) - - self._paged_request_seq_lens[req_id] = len(token_ids) - next_tokens.append(next_token) - - return next_tokens - - def _run_packed_prefill( - self, - paged_complete: list[ - tuple[ - int, - str, - list[int], - SamplingParams, - list[int], - torch.Generator | None, - ] - ], - sampled_tokens: list[list[int]], - ) -> None: - """Batch, dispatch, and write back state for packed paged prefill. - - Splits *paged_complete* into batches that fit within - ``MAX_PACKED_PREFILL_TOKENS``, runs each batch through - ``_prefill_packed_paged``, and fills *sampled_tokens* in-place. - """ - # Split into batches that fit within the packed-length cap. - batches: list[list[tuple]] = [[]] - batch_tokens = 0 - for entry in paged_complete: - entry_tokens = len(entry[2]) # token_ids - if batch_tokens + entry_tokens > MAX_PACKED_PREFILL_TOKENS and batches[-1]: - batches.append([]) - batch_tokens = 0 - batches[-1].append(entry) - batch_tokens += entry_tokens - - for batch in batches: - pack_input = [ - (rid, tids, sp, bids, gen, None) - for _, rid, tids, sp, bids, gen in batch - ] - next_tokens = self._prefill_packed_paged(pack_input) - for i, (idx, rid, tids, sp, bids, gen) in enumerate(batch): - nt = next_tokens[i] - sampled_tokens[idx] = [nt] - self._request_states[rid] = RequestState( - token_ids=list(tids) + [nt], - prompt_len=len(tids), - cache=[], - sampling_params=sp, - generator=gen, - generated_tokens=1, - block_ids=bids, - ) - if self._rust_state_manager is not None: - self._rust_state_manager.add_request(rid, list(tids) + [nt]) - # ------------------------------------------------------------------ # Unified prefill + decode (single forward pass) # ------------------------------------------------------------------ @@ -1962,17 +1808,31 @@ def execute_model( req_id_to_index: dict[str, int] = {} sampled_tokens: list[list[int]] = [] - # === PHASE 1: Process new requests (prefill phase) === + # === Collect all requests into one unified batch === new_reqs = scheduler_output.scheduled_new_reqs + cached_reqs = scheduler_output.scheduled_cached_reqs - # First pass: handle intermediate chunks immediately, collect - # complete paged prefill requests for potential packing. - paged_complete: list[ + # Paged-attention entries collected for the single unified forward. + # Each prefill entry: (output_idx, req_id, token_ids, sampling_params, + # block_ids, generator, entry_type, prompt_len) + # entry_type is one of: "new_intermediate", "new_complete", + # "cached_intermediate", "cached_last_chunk" + _ENTRY_TYPE = str # noqa: N806 — alias for readability + paged_prefill_entries: list[ tuple[ - int, str, list[int], SamplingParams, list[int], torch.Generator | None + int, + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + _ENTRY_TYPE, + int, ] ] = [] + paged_decode_reqs: list[tuple[str, RequestState]] = [] + # --- New requests --- for new_req in new_reqs: req_id = new_req.req_id token_ids = new_req.prompt_token_ids or [] @@ -1983,40 +1843,40 @@ def execute_model( req_id_to_index[req_id] = output_idx if not token_ids: - sampled_tokens.append([0]) # Fallback + sampled_tokens.append([0]) continue generator = _create_request_generator(self.device, sampling_params) if self._paged_kv_cache is not None: - # Paged attention path (Metal kernel) sched_block_ids = list(new_req.block_ids[0]) scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) computed_tokens = new_req.num_computed_tokens prompt_len = len(token_ids) + cur_len = computed_tokens + scheduled_tokens + is_intermediate = cur_len < prompt_len - if computed_tokens + scheduled_tokens < prompt_len: - # Intermediate chunk: sample then drop (async scheduler - # allocates no placeholder for intermediate chunks). - cur_len = computed_tokens + scheduled_tokens - _discarded = self._prefill_packed_paged( - [ - ( - req_id, - token_ids[:cur_len], - sampling_params, - sched_block_ids, - generator, - None, - ), - ] - )[0] - cache: list = [] - sampled_tokens.append([]) + sampled_tokens.append([]) # placeholder + paged_prefill_entries.append( + ( + output_idx, + req_id, + token_ids[:cur_len] if is_intermediate else token_ids, + sampling_params, + sched_block_ids, + generator, + "new_intermediate" if is_intermediate else "new_complete", + prompt_len, + ) + ) + + # Create state immediately for intermediate chunks + # (needed if the request appears as cached next step). + if is_intermediate: self._request_states[req_id] = RequestState( token_ids=list(token_ids), prompt_len=prompt_len, - cache=cache, + cache=[], sampling_params=sampling_params, generator=generator, generated_tokens=0, @@ -2026,20 +1886,6 @@ def execute_model( self._rust_state_manager.add_request( req_id, list(token_ids[:cur_len]) ) - continue - - # Complete prefill — collect for packed processing - sampled_tokens.append([]) # placeholder, filled below - paged_complete.append( - ( - output_idx, - req_id, - token_ids, - sampling_params, - sched_block_ids, - generator, - ) - ) else: next_token, cache = self._prefill_single( req_id, @@ -2062,21 +1908,14 @@ def execute_model( req_id, list(token_ids) + [next_token] ) - # Defer paged_complete to unified forward pass during PHASE 2. - # If no cached requests exist, it is consumed after PHASE 2. - - # === PHASE 2: Process cached requests (TRUE batched decode) === - cached_reqs = scheduler_output.scheduled_cached_reqs + # --- Cached requests --- decode_req_ids = list(cached_reqs.req_ids) if decode_req_ids: if self._paged_kv_cache is not None: - # Paged attention path: unified flow using model-runner-local - # state (state.generated_tokens) instead of is_context_phase(). req_id_to_cached_idx = { rid: i for i, rid in enumerate(cached_reqs.req_ids) } - paged_decode_reqs: list[tuple[str, RequestState]] = [] # Update block_ids from scheduler (append or replace on resume) for i, req_id in enumerate(cached_reqs.req_ids): @@ -2090,13 +1929,6 @@ def execute_model( if new_block_ids is not None: state.block_ids.extend(new_block_ids[0]) else: - # Preempted → full recompute with fresh blocks. - # Keep prompt_len at the original prompt boundary - # (used for sampling penalty split). The prefill - # loop uses len(state.token_ids) — which already - # includes previously generated output tokens — - # to determine the recompute scope, matching - # upstream vLLM's use of request.num_tokens. assert new_block_ids is not None state.block_ids = list(new_block_ids[0]) state.generated_tokens = 0 @@ -2107,17 +1939,17 @@ def execute_model( req_id, list(state.token_ids) ) + # Categorise each cached request for req_id in decode_req_ids: state = self._request_states.get(req_id) if state is None: - # Edge case: no state — emit dummy token req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([0]) continue if state.generated_tokens == 0: - # Still prefilling prompt (or re-prefilling after preemption) + # Still prefilling (or re-prefilling after preemption) idx = req_id_to_cached_idx.get(req_id) if idx is not None and idx < len( cached_reqs.num_computed_tokens @@ -2126,88 +1958,38 @@ def execute_model( else: computed = self._paged_request_seq_lens.get(req_id, 0) scheduled = scheduler_output.num_scheduled_tokens.get(req_id, 0) - target_len = computed + scheduled # FIX: was just `computed` - - if target_len < len(state.token_ids): - # Intermediate chunk: sample then drop - _discarded = self._prefill_packed_paged( - [ - ( - req_id, - state.token_ids[:target_len], - state.sampling_params, - state.block_ids, - state.generator, - None, - ), - ] - )[0] - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([]) - else: - # Last chunk: sample and keep (drains async placeholder) - next_token = self._prefill_packed_paged( - [ - ( - req_id, - state.token_ids, - state.sampling_params, - state.block_ids, - state.generator, - state.prompt_len, - ), - ] - )[0] - state.token_ids.append(next_token) - state.generated_tokens = ( - len(state.token_ids) - state.prompt_len - ) - if self._rust_state_manager is not None: - self._rust_state_manager.append_token( - req_id, next_token - ) - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([next_token]) - else: - # Decode phase: collect for batched decode - paged_decode_reqs.append((req_id, state)) + target_len = computed + scheduled + is_intermediate = target_len < len(state.token_ids) - # === Unified forward: complete prefills + decode === - if paged_decode_reqs or paged_complete: - prefill_pack = [ - (rid, tids, sp, bids, gen, None) - for _, rid, tids, sp, bids, gen in paged_complete - ] - prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( - prefill_pack, paged_decode_reqs - ) + req_ids.append(req_id) + output_idx = len(req_ids) - 1 + req_id_to_index[req_id] = output_idx + sampled_tokens.append([]) # placeholder - # Record prefill results + create state - for i, (idx, rid, tids, sp, bids, gen) in enumerate(paged_complete): - nt = prefill_tokens[i] - sampled_tokens[idx] = [nt] - self._request_states[rid] = RequestState( - token_ids=list(tids) + [nt], - prompt_len=len(tids), - cache=[], - sampling_params=sp, - generator=gen, - generated_tokens=1, - block_ids=bids, + paged_prefill_entries.append( + ( + output_idx, + req_id, + ( + state.token_ids[:target_len] + if is_intermediate + else state.token_ids + ), + state.sampling_params, + state.block_ids, + state.generator, + ( + "cached_intermediate" + if is_intermediate + else "cached_last_chunk" + ), + state.prompt_len, + ) ) - if self._rust_state_manager is not None: - self._rust_state_manager.add_request(rid, list(tids) + [nt]) - paged_complete = [] # consumed - - # Record decode results - for i, (req_id, _) in enumerate(paged_decode_reqs): - req_ids.append(req_id) - req_id_to_index[req_id] = len(req_ids) - 1 - sampled_tokens.append([decode_tokens[i]]) + else: + paged_decode_reqs.append((req_id, state)) else: - # Collect all valid decode requests + # Non-paged (MLX) path — unchanged valid_decode_reqs = [] for req_id in decode_req_ids: state = self._request_states.get(req_id) @@ -2220,23 +2002,69 @@ def execute_model( else: decode_tokens = self._sequential_decode(valid_decode_reqs) - # Add decode results to output for i, (req_id, _) in enumerate(valid_decode_reqs): req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([decode_tokens[i]]) - # Handle requests with no cached state (edge case) for req_id in decode_req_ids: if req_id not in req_id_to_index: req_ids.append(req_id) req_id_to_index[req_id] = len(req_ids) - 1 sampled_tokens.append([0]) - # Handle any paged_complete not consumed by the unified pass - # (happens when no cached requests exist in this step). - if paged_complete: - self._run_packed_prefill(paged_complete, sampled_tokens) + # === Single unified forward pass (paged path) === + if paged_prefill_entries or paged_decode_reqs: + prefill_pack = [ + (rid, tids, sp, bids, gen, None) + for _, rid, tids, sp, bids, gen, _, _ in paged_prefill_entries + ] + prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( + prefill_pack, paged_decode_reqs + ) + + # Post-process prefill results + for i, ( + idx, + rid, + tids, + sp, + bids, + gen, + entry_type, + prompt_len, + ) in enumerate(paged_prefill_entries): + nt = prefill_tokens[i] + + if entry_type.endswith("_intermediate"): + # KV cache populated; discard sampled token + sampled_tokens[idx] = [] + elif entry_type == "new_complete": + sampled_tokens[idx] = [nt] + self._request_states[rid] = RequestState( + token_ids=list(tids) + [nt], + prompt_len=len(tids), + cache=[], + sampling_params=sp, + generator=gen, + generated_tokens=1, + block_ids=bids, + ) + if self._rust_state_manager is not None: + self._rust_state_manager.add_request(rid, list(tids) + [nt]) + elif entry_type == "cached_last_chunk": + sampled_tokens[idx] = [nt] + state = self._request_states[rid] + state.token_ids.append(nt) + state.generated_tokens = len(state.token_ids) - state.prompt_len + if self._rust_state_manager is not None: + self._rust_state_manager.append_token(rid, nt) + + # Post-process decode results + for i, (req_id, _) in enumerate(paged_decode_reqs): + req_ids.append(req_id) + req_id_to_index[req_id] = len(req_ids) - 1 + sampled_tokens.append([decode_tokens[i]]) # Consistency check: every scheduled request must be represented in # req_ids, and decode-phase scheduled requests should not emit empty From 7358872c66a6044a4bd8742c65f17b12753e7e45 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 19 Mar 2026 16:30:18 -0500 Subject: [PATCH 04/12] fix linter Signed-off-by: ran --- tests/test_unified_batching.py | 2 -- vllm_metal/v1/model_runner.py | 14 +++----------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/tests/test_unified_batching.py b/tests/test_unified_batching.py index 44610290..f7f1035f 100644 --- a/tests/test_unified_batching.py +++ b/tests/test_unified_batching.py @@ -18,8 +18,6 @@ from __future__ import annotations -import os - import pytest from vllm import LLM, SamplingParams diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 632b2f2f..367dc705 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -1625,7 +1625,6 @@ def _unified_prefill_decode_paged( ``(prefill_next_tokens, decode_next_tokens)`` """ num_decode = len(decode_reqs) - num_prefill = len(prefill_reqs) # ---- build unified token sequence: decode first, then prefill ---- all_token_ids: list[int] = [] @@ -1817,17 +1816,10 @@ def execute_model( # block_ids, generator, entry_type, prompt_len) # entry_type is one of: "new_intermediate", "new_complete", # "cached_intermediate", "cached_last_chunk" - _ENTRY_TYPE = str # noqa: N806 — alias for readability paged_prefill_entries: list[ tuple[ - int, - str, - list[int], - SamplingParams, - list[int], - torch.Generator | None, - _ENTRY_TYPE, - int, + int, str, list[int], SamplingParams, list[int], + torch.Generator | None, str, int, ] ] = [] paged_decode_reqs: list[tuple[str, RequestState]] = [] @@ -2032,7 +2024,7 @@ def execute_model( bids, gen, entry_type, - prompt_len, + _prompt_len, ) in enumerate(paged_prefill_entries): nt = prefill_tokens[i] From 7cd4cb67b639bc93f4af96bce8c3c4d56fe84e81 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 19 Mar 2026 16:38:21 -0500 Subject: [PATCH 05/12] fix everything Signed-off-by: ran --- vllm_metal/v1/model_runner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 367dc705..264a738d 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -1818,8 +1818,14 @@ def execute_model( # "cached_intermediate", "cached_last_chunk" paged_prefill_entries: list[ tuple[ - int, str, list[int], SamplingParams, list[int], - torch.Generator | None, str, int, + int, + str, + list[int], + SamplingParams, + list[int], + torch.Generator | None, + str, + int, ] ] = [] paged_decode_reqs: list[tuple[str, RequestState]] = [] From e22c3672d22c4a298f4787ae63594d6900fa99b5 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 19 Mar 2026 17:04:55 -0500 Subject: [PATCH 06/12] fix repeative chunked prefilling from O(n2) to O(n) Signed-off-by: ran --- tests/test_paged_attention.py | 28 +++++++++++++++++---- vllm_metal/paged_attention_common.py | 16 ++++++------ vllm_metal/v1/model_runner.py | 37 ++++++++++++++++------------ 3 files changed, 53 insertions(+), 28 deletions(-) diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index b1d3a275..454ca8c1 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -34,8 +34,8 @@ def teardown_method(self): clear_context() def test_prepare_unified_prefill_single(self): - # Single prefill request via prepare_unified - prepare_unified([], [([10, 11], 5)], block_size=4) + # Single prefill request via prepare_unified (start_pos=0) + prepare_unified([], [([10, 11], 5, 0)], block_size=4) ctx = get_context() # block 10: slots 40,41,42,43; block 11: slot 44 @@ -49,7 +49,7 @@ def test_prepare_unified_prefill_single(self): def test_prepare_unified_prefill_packed(self): # Two prefill requests packed together - prepare_unified([], [([10], 3), ([20], 2)], block_size=4) + prepare_unified([], [([10], 3, 0), ([20], 2, 0)], block_size=4) ctx = get_context() assert ctx is not None @@ -64,7 +64,7 @@ def test_prepare_unified_prefill_packed(self): def test_prepare_unified_prefill_multiblock(self): # Single prefill spanning two blocks - prepare_unified([], [([5, 6], 5)], block_size=4) + prepare_unified([], [([5, 6], 5, 0)], block_size=4) ctx = get_context() assert ctx is not None @@ -74,6 +74,24 @@ def test_prepare_unified_prefill_multiblock(self): assert ctx.block_tables == [[5, 6]] assert ctx.context_lens == [5] + def test_prepare_unified_continuation_chunk(self): + # Continuation chunk: 3 new tokens starting at position 4 + # block 10 has slots 40-43 (positions 0-3, already cached), + # block 11 has slots 44-47 (positions 4-6 are the new tokens) + prepare_unified([], [([10, 11], 3, 4)], block_size=4) + ctx = get_context() + + assert ctx is not None + # Only 3 tokens in the query (positions 4, 5, 6) + assert ctx.cu_seqlens == [0, 3] + # Slots for positions 4, 5, 6: block 11 slots 44, 45, 46 + assert ctx.slot_mapping == [44, 45, 46] + assert ctx.block_tables == [[10, 11]] + # Total context = start_pos + num_tokens = 4 + 3 = 7 + assert ctx.context_lens == [7] + # RoPE offset = start_pos + assert ctx.offsets == [4] + def test_prepare_unified_decode_only(self): # Single decode request via prepare_unified decode_requests = [([5, 6], 7)] @@ -91,7 +109,7 @@ def test_prepare_unified_decode_only(self): def test_prepare_unified_mixed(self): # 1 decode + 1 prefill decode_requests = [([5, 6], 7)] # seq_len=7 - prefill_requests = [([10, 11], 5)] # 5 tokens + prefill_requests = [([10, 11], 5, 0)] # 5 tokens from position 0 prepare_unified(decode_requests, prefill_requests, block_size=4) ctx = get_context() diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index 4ce7622f..ef6ac345 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -154,7 +154,7 @@ def find_layers_and_attr(model: Any) -> tuple[list[Any], str]: def prepare_unified( decode_requests: list[tuple[list[int], int]], - prefill_requests: list[tuple[list[int], int]], + prefill_requests: list[tuple[list[int], int, int]], block_size: int, ) -> None: """Compute metadata for a unified prefill + decode forward pass. @@ -167,7 +167,9 @@ def prepare_unified( Args: decode_requests: list of ``(block_ids, seq_len)`` for decode requests. ``seq_len`` = tokens already cached before this step. - prefill_requests: list of ``(block_ids, num_tokens)`` for prefill. + prefill_requests: list of ``(block_ids, num_tokens, start_pos)`` for + prefill. ``start_pos`` is the position of the first token in this + chunk (0 for a fresh prefill, >0 for continuation chunks). block_size: tokens per KV cache block. """ slot_mapping: list[int] = [] @@ -187,16 +189,16 @@ def prepare_unified( context_lens.append(seq_len + 1) # including new token offsets.append(seq_len) # RoPE position - # Prefill requests (variable tokens each) - for block_ids, num_tokens in prefill_requests: - for pos in range(num_tokens): + # Prefill requests (variable tokens each, starting at start_pos) + for block_ids, num_tokens, start_pos in prefill_requests: + for pos in range(start_pos, start_pos + num_tokens): block_idx = block_ids[pos // block_size] slot = block_idx * block_size + (pos % block_size) slot_mapping.append(slot) cu_seqlens.append(cu_seqlens[-1] + num_tokens) block_tables.append(block_ids) - context_lens.append(num_tokens) - offsets.append(0) # prefill starts at position 0 + context_lens.append(start_pos + num_tokens) + offsets.append(start_pos) set_context( PagedAttentionContext( diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 264a738d..dbf68a3d 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -1605,6 +1605,7 @@ def _unified_prefill_decode_paged( list[int], torch.Generator | None, int | None, + int, ] ], decode_reqs: list[tuple[str, RequestState]], @@ -1618,7 +1619,9 @@ def _unified_prefill_decode_paged( Args: prefill_reqs: list of ``(req_id, token_ids, sampling_params, block_ids, - generator, prompt_len)`` — complete prefill requests. + generator, prompt_len, start_pos)`` — prefill requests. + ``start_pos`` is the RoPE offset / KV slot start (0 for + fresh prefill, >0 for continuation chunks). decode_reqs: list of ``(req_id, RequestState)`` — decode requests. Returns: @@ -1641,8 +1644,8 @@ def _unified_prefill_decode_paged( ] all_token_ids.extend(last_tokens) - # Prefill: all tokens per request - for _, token_ids, _, _, _, _ in prefill_reqs: + # Prefill: tokens per request + for _, token_ids, _, _, _, _, _ in prefill_reqs: all_token_ids.extend(token_ids) # ---- build metadata for prepare_unified ---- @@ -1651,9 +1654,9 @@ def _unified_prefill_decode_paged( seq_len = self._paged_request_seq_lens.get(req_id, len(state.token_ids) - 1) decode_info.append((state.block_ids, seq_len)) - prefill_info: list[tuple[list[int], int]] = [] - for _, token_ids, _, block_ids, _, _ in prefill_reqs: - prefill_info.append((block_ids, len(token_ids))) + prefill_info: list[tuple[list[int], int, int]] = [] + for _, token_ids, _, block_ids, _, _, start_pos in prefill_reqs: + prefill_info.append((block_ids, len(token_ids), start_pos)) prepare_unified(decode_info, prefill_info, self._paged_block_size) @@ -1670,7 +1673,7 @@ def _unified_prefill_decode_paged( cu_seqlens: list[int] = [0] for _ in decode_reqs: cu_seqlens.append(cu_seqlens[-1] + 1) - for _, token_ids, _, _, _, _ in prefill_reqs: + for _, token_ids, _, _, _, _, _ in prefill_reqs: cu_seqlens.append(cu_seqlens[-1] + len(token_ids)) # ---- sample decode tokens ---- @@ -1742,6 +1745,7 @@ def _unified_prefill_decode_paged( _block_ids, generator, prompt_len, + _start_pos, ) in enumerate(prefill_reqs): last_idx = cu_seqlens[num_decode + j + 1] - 1 last_logits = logits[:, last_idx : last_idx + 1, :] @@ -1813,7 +1817,8 @@ def execute_model( # Paged-attention entries collected for the single unified forward. # Each prefill entry: (output_idx, req_id, token_ids, sampling_params, - # block_ids, generator, entry_type, prompt_len) + # block_ids, generator, entry_type, prompt_len, + # start_pos) # entry_type is one of: "new_intermediate", "new_complete", # "cached_intermediate", "cached_last_chunk" paged_prefill_entries: list[ @@ -1826,6 +1831,7 @@ def execute_model( torch.Generator | None, str, int, + int, ] ] = [] paged_decode_reqs: list[tuple[str, RequestState]] = [] @@ -1859,12 +1865,13 @@ def execute_model( ( output_idx, req_id, - token_ids[:cur_len] if is_intermediate else token_ids, + token_ids[computed_tokens:cur_len], sampling_params, sched_block_ids, generator, "new_intermediate" if is_intermediate else "new_complete", prompt_len, + computed_tokens, # start_pos / RoPE offset ) ) @@ -1968,11 +1975,7 @@ def execute_model( ( output_idx, req_id, - ( - state.token_ids[:target_len] - if is_intermediate - else state.token_ids - ), + state.token_ids[computed:target_len], state.sampling_params, state.block_ids, state.generator, @@ -1982,6 +1985,7 @@ def execute_model( else "cached_last_chunk" ), state.prompt_len, + computed, # start_pos / RoPE offset ) ) else: @@ -2014,8 +2018,8 @@ def execute_model( # === Single unified forward pass (paged path) === if paged_prefill_entries or paged_decode_reqs: prefill_pack = [ - (rid, tids, sp, bids, gen, None) - for _, rid, tids, sp, bids, gen, _, _ in paged_prefill_entries + (rid, tids, sp, bids, gen, None, start_pos) + for _, rid, tids, sp, bids, gen, _, _, start_pos in paged_prefill_entries ] prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( prefill_pack, paged_decode_reqs @@ -2031,6 +2035,7 @@ def execute_model( gen, entry_type, _prompt_len, + _start_pos, ) in enumerate(paged_prefill_entries): nt = prefill_tokens[i] From a31d8547b9fb5b974ff3d69b1e56b05698c928fe Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 19 Mar 2026 22:50:32 -0500 Subject: [PATCH 07/12] delete scaffolding test for batch infer Signed-off-by: ran --- tests/test_unified_batching.py | 103 --------------------------------- 1 file changed, 103 deletions(-) delete mode 100644 tests/test_unified_batching.py diff --git a/tests/test_unified_batching.py b/tests/test_unified_batching.py deleted file mode 100644 index f7f1035f..00000000 --- a/tests/test_unified_batching.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Smoke test for unified prefill+decode forward pass (continuous batching). - -Runs vLLM offline inference with max_num_seqs > 1 so the scheduler batches -multiple requests together, triggering the unified forward pass where prefill -and decode happen in a single model call. - -Due to floating-point non-determinism when batching on Metal (MLX GEMM uses -different internal kernels for different batch sizes), exact golden-token -matching is NOT expected. Instead, this test: - 1. Verifies all requests complete without errors. - 2. Prints the generated text for manual inspection (not gibberish). - 3. Optionally checks whether outputs still match the single-request golden. - -Run: - python -m pytest tests/test_unified_batching.py -v -s -""" - -from __future__ import annotations - -import pytest -from vllm import LLM, SamplingParams - -MODEL_NAME = "Qwen/Qwen3-0.6B" -MAX_TOKENS = 10 -MAX_NUM_SEQS = 4 # key: allow concurrent requests - -PROMPTS = [ - "The capital of France is", - "The weather today is not", - "One plus one equals", - "The largest planet in our solar system is", - "Water boils at a temperature of", - "Machine learning is", -] - -# fmt: off -# Golden from max_num_seqs=1 (single-request, deterministic). -# Used only for informational comparison — NOT asserted. -GOLDEN_SINGLE = { - "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722], - "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], - "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], - "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], - "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], - "Machine learning is": [264, 7988, 5392, 429, 702, 13791, 1506, 279, 2070, 315], -} -# fmt: on - - -@pytest.fixture(autouse=True, scope="module") -def _set_env(): - with pytest.MonkeyPatch.context() as mp: - mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") - mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") - yield - - -@pytest.fixture(scope="module") -def vllm_outputs(): - """Run vLLM offline inference with concurrent batching.""" - llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=MAX_NUM_SEQS) - - # Verify paged KV + attention wrapper are active - runner = llm.llm_engine.model_executor.driver_worker.model_runner - assert runner._paged_kv_cache is not None, "Paged KV cache not initialised" - - from vllm_metal.metal_kernel_backend.paged_attention import ( - MetalKernelPagedAttentionWrapper, - ) - - attn = runner.model.model.layers[0].self_attn - assert isinstance(attn, MetalKernelPagedAttentionWrapper) - - sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) - outputs = llm.generate(PROMPTS, sp) - return {o.prompt: o for o in outputs} - - -class TestUnifiedBatching: - @pytest.mark.slow - @pytest.mark.parametrize("prompt", PROMPTS) - def test_generate_coherent(self, vllm_outputs, prompt): - """Verify output is non-empty and print for manual inspection.""" - output = vllm_outputs[prompt] - token_ids = list(output.outputs[0].token_ids) - text = output.outputs[0].text - - golden = GOLDEN_SINGLE.get(prompt, []) - match = token_ids == golden - - print(f"\n prompt: {prompt!r}") - print(f" output: {text!r}") - print(f" ids: {token_ids}") - print(f" golden: {golden}") - print(f" match: {'YES' if match else 'no (expected with batching)'}") - - # Basic sanity: output should not be empty - assert len(token_ids) == MAX_TOKENS, ( - f"Expected {MAX_TOKENS} tokens, got {len(token_ids)}" - ) - assert len(text.strip()) > 0, "Generated text is empty" From 23c4c20a108d9e5c5a099ac356de83dbcc7b3686 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 20 Mar 2026 01:58:20 -0500 Subject: [PATCH 08/12] fix sampling bug, remove is_prefilling dead code Signed-off-by: ran --- tests/test_paged_attention.py | 3 --- vllm_metal/paged_attention_common.py | 2 -- vllm_metal/v1/model_runner.py | 16 ++++++++++++++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py index 454ca8c1..30b45bfb 100644 --- a/tests/test_paged_attention.py +++ b/tests/test_paged_attention.py @@ -40,7 +40,6 @@ def test_prepare_unified_prefill_single(self): # block 10: slots 40,41,42,43; block 11: slot 44 assert ctx is not None - assert ctx.is_prefill assert ctx.slot_mapping == [40, 41, 42, 43, 44] assert ctx.block_tables == [[10, 11]] assert ctx.context_lens == [5] @@ -53,7 +52,6 @@ def test_prepare_unified_prefill_packed(self): ctx = get_context() assert ctx is not None - assert ctx.is_prefill # Request 0: block 10, slots 40,41,42 # Request 1: block 20, slots 80,81 assert ctx.slot_mapping == [40, 41, 42, 80, 81] @@ -100,7 +98,6 @@ def test_prepare_unified_decode_only(self): # new_pos=7, block_ids[7//4]=block_ids[1]=6, slot=6*4+(7%4)=27 assert ctx is not None - assert ctx.is_prefill # unified always sets True assert ctx.slot_mapping == [27] assert ctx.context_lens == [8] assert ctx.offsets == [7] diff --git a/vllm_metal/paged_attention_common.py b/vllm_metal/paged_attention_common.py index ef6ac345..05661d82 100644 --- a/vllm_metal/paged_attention_common.py +++ b/vllm_metal/paged_attention_common.py @@ -41,7 +41,6 @@ class PagedAttentionContext: into a single flat sequence). """ - is_prefill: bool # kept for compatibility; always True slot_mapping: list[int] block_tables: list[list[int]] = field(default_factory=list) context_lens: list[int] = field(default_factory=list) @@ -202,7 +201,6 @@ def prepare_unified( set_context( PagedAttentionContext( - is_prefill=True, # use varlen code path slot_mapping=slot_mapping, block_tables=block_tables, context_lens=context_lens, diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index dbf68a3d..eb838ce6 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -2018,8 +2018,16 @@ def execute_model( # === Single unified forward pass (paged path) === if paged_prefill_entries or paged_decode_reqs: prefill_pack = [ - (rid, tids, sp, bids, gen, None, start_pos) - for _, rid, tids, sp, bids, gen, _, _, start_pos in paged_prefill_entries + ( + rid, + tids, + sp, + bids, + gen, + prompt_len if not entry_type.endswith("_intermediate") else None, + start_pos, + ) + for _, rid, tids, sp, bids, gen, entry_type, prompt_len, start_pos in paged_prefill_entries ] prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( prefill_pack, paged_decode_reqs @@ -2043,6 +2051,10 @@ def execute_model( # KV cache populated; discard sampled token sampled_tokens[idx] = [] elif entry_type == "new_complete": + assert _start_pos == 0, ( + "new_complete with start_pos > 0 not supported " + "(prefix caching not yet implemented in unified path)" + ) sampled_tokens[idx] = [nt] self._request_states[rid] = RequestState( token_ids=list(tids) + [nt], From 7d3a570e32553719fc3dc5fcd6883920be7b2ab0 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 20 Mar 2026 02:43:51 -0500 Subject: [PATCH 09/12] delete bad comment lines Signed-off-by: ran --- vllm_metal/v1/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index eb838ce6..6ab028de 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -1590,7 +1590,6 @@ def _sequential_decode( return next_tokens - # ------------------------------------------------------------------ # ------------------------------------------------------------------ # Unified prefill + decode (single forward pass) # ------------------------------------------------------------------ From c66a39fc103d74e8adc78585a1aaafbbcb076eb3 Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 20 Mar 2026 02:49:12 -0500 Subject: [PATCH 10/12] remove deadcode prefilling cap Signed-off-by: ran --- vllm_metal/v1/model_runner.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 6ab028de..4f33c57c 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -742,12 +742,6 @@ def _extract_asr_text_tokens(self, tokens: list[int]) -> list[int]: return tokens[start:end] -# Cap total packed-prefill tokens per forward pass to bound activation -# memory (QKV projections + FFN intermediates scale linearly with total -# tokens) and avoid Metal GPU command-buffer timeouts on large dispatches. -MAX_PACKED_PREFILL_TOKENS = 4096 - - class MetalModelRunner: """Model runner for MLX-based inference on Metal. From 55d5e62f8826d41415d06b104197e2535fd7496c Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 20 Mar 2026 02:59:05 -0500 Subject: [PATCH 11/12] refactor the stage string Signed-off-by: ran --- vllm_metal/v1/model_runner.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index 4f33c57c..1f057759 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -1810,10 +1810,8 @@ def execute_model( # Paged-attention entries collected for the single unified forward. # Each prefill entry: (output_idx, req_id, token_ids, sampling_params, - # block_ids, generator, entry_type, prompt_len, - # start_pos) - # entry_type is one of: "new_intermediate", "new_complete", - # "cached_intermediate", "cached_last_chunk" + # block_ids, generator, is_new, is_intermediate, + # prompt_len, start_pos) paged_prefill_entries: list[ tuple[ int, @@ -1822,7 +1820,8 @@ def execute_model( SamplingParams, list[int], torch.Generator | None, - str, + bool, + bool, int, int, ] @@ -1862,7 +1861,8 @@ def execute_model( sampling_params, sched_block_ids, generator, - "new_intermediate" if is_intermediate else "new_complete", + True, # is_new + is_intermediate, prompt_len, computed_tokens, # start_pos / RoPE offset ) @@ -1972,11 +1972,8 @@ def execute_model( state.sampling_params, state.block_ids, state.generator, - ( - "cached_intermediate" - if is_intermediate - else "cached_last_chunk" - ), + False, # is_new + is_intermediate, state.prompt_len, computed, # start_pos / RoPE offset ) @@ -2017,10 +2014,10 @@ def execute_model( sp, bids, gen, - prompt_len if not entry_type.endswith("_intermediate") else None, + prompt_len if not is_intermediate else None, start_pos, ) - for _, rid, tids, sp, bids, gen, entry_type, prompt_len, start_pos in paged_prefill_entries + for _, rid, tids, sp, bids, gen, _is_new, is_intermediate, prompt_len, start_pos in paged_prefill_entries ] prefill_tokens, decode_tokens = self._unified_prefill_decode_paged( prefill_pack, paged_decode_reqs @@ -2034,18 +2031,19 @@ def execute_model( sp, bids, gen, - entry_type, + is_new, + is_intermediate, _prompt_len, _start_pos, ) in enumerate(paged_prefill_entries): nt = prefill_tokens[i] - if entry_type.endswith("_intermediate"): + if is_intermediate: # KV cache populated; discard sampled token sampled_tokens[idx] = [] - elif entry_type == "new_complete": + elif is_new: assert _start_pos == 0, ( - "new_complete with start_pos > 0 not supported " + "new complete prefill with start_pos > 0 not supported " "(prefix caching not yet implemented in unified path)" ) sampled_tokens[idx] = [nt] @@ -2060,7 +2058,8 @@ def execute_model( ) if self._rust_state_manager is not None: self._rust_state_manager.add_request(rid, list(tids) + [nt]) - elif entry_type == "cached_last_chunk": + else: + # Cached last chunk — append token to existing state sampled_tokens[idx] = [nt] state = self._request_states[rid] state.token_ids.append(nt) From c3beae836e2e63a8c1e5fea5e81aa62321101d7e Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 20 Mar 2026 03:33:38 -0500 Subject: [PATCH 12/12] compensate the merge mistake Signed-off-by: ran --- vllm_metal/v1/model_runner.py | 163 ---------------------------------- 1 file changed, 163 deletions(-) diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index abb4ad78..17ee86dc 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -579,169 +579,6 @@ def _extract_kv_cache( return extracted -# ------------------------------------------------------------------ -# STTExecutor — owns audio feature extraction and decode delegation -# ------------------------------------------------------------------ - - -class STTExecutor: - """Encapsulates STT-specific audio extraction and decoding. - - Holds a lazily-created :class:`WhisperTranscriber` and provides - :meth:`extract_audio_features` and :meth:`decode` so that - :class:`MetalModelRunner` can delegate without STT-specific logic. - """ - - def __init__(self, model: Any, model_path: str) -> None: - self.model = model - self._model_path = model_path - self._transcriber: Any = None - self._model_type: str = getattr(model, "model_type", "whisper") - # Cached Qwen3-ASR special token IDs (resolved once on first use) - self._asr_text_token_id: int | None = None - self._im_end_token_id: int | None = None - - @property - def transcriber(self): - """Lazily-created transcriber (Whisper or Qwen3-ASR).""" - if self._transcriber is None: - if self._model_type == "qwen3_asr": - from vllm_metal.stt.transcribe import Qwen3ASRTranscriber - - self._transcriber = Qwen3ASRTranscriber( - self.model, model_path=self._model_path - ) - else: - from vllm_metal.stt.transcribe import WhisperTranscriber - - self._transcriber = WhisperTranscriber( - self.model, model_path=self._model_path - ) - return self._transcriber - - @property - def eot_token(self) -> int: - """End-of-text token ID resolved from the tokenizer or config.""" - if self._model_type == "qwen3_asr": - return self.model.config.eos_token_id - return self.transcriber.tokenizer.convert_tokens_to_ids("<|endoftext|>") - - def extract_audio_features(self, input_features: Any) -> "mx.array": - """Extract and encode STT input features.""" - # Convert to MLX array — handle numpy, torch, and lists - if isinstance(input_features, np.ndarray): - mel = mx.array(input_features, dtype=mx.float16) - elif isinstance(input_features, torch.Tensor): - # .cpu() for device safety, .float() because bfloat16 has - # no numpy dtype support. - mel = mx.array(input_features.cpu().float().numpy(), dtype=mx.float16) - else: - mel = mx.array(np.array(input_features), dtype=mx.float16) - - if self._model_type == "qwen3_asr": - # Qwen3-ASR encoder expects: (n_mels, time) or (batch, n_mels, time) - # HF WhisperFeatureExtractor output shape is already (n_mels, time) - if mel.ndim == 3: - mel = mel[0] # drop batch dim → (n_mels, time) - elif mel.ndim != 2: - raise ValueError(f"Qwen3-ASR expects 2D or 3D mel, got rank {mel.ndim}") - features = self.model.encode(mel) - mx.eval(features) - return features - else: - # Whisper encoder expects: (batch, time, n_mels) - # HF WhisperFeatureExtractor output shape: (n_mels, time) - if mel.ndim == 2: - mel = mel[None, ...] # add batch dim → (1, n_mels, time) - mel = mel.transpose(0, 2, 1) # → (1, time, n_mels) - elif mel.ndim == 3: - mel = mel.transpose( - 0, 2, 1 - ) # (batch, n_mels, time) → (batch, time, n_mels) - else: - raise ValueError( - f"Unexpected mel spectrogram rank {mel.ndim}; expected 2D or 3D" - ) - - features = self.model.encode(mel) - mx.eval(features) - return features - - def decode( - self, - audio_features: "mx.array", - prompt_token_ids: list[int], - ) -> list[int]: - """Decode audio features into token IDs (ending with EOT). - - Delegates the core decode loop to the transcriber. - - Args: - audio_features: Encoded audio from the encoder. - prompt_token_ids: Prefix tokens (language, task, etc.). - - Returns: - List of decoded token IDs ending with EOT. - """ - eot = self.eot_token - - if self._model_type == "qwen3_asr": - # Qwen3-ASR uses a fixed prompt format — language, task, and - # user prompt controls are not supported by this model. - # Rebuild prompt with the correct number of audio_pad tokens - # matching the audio encoder output length. - n_audio_frames = audio_features.shape[0] - prompt_token_ids = self.transcriber.build_prompt_tokens(n_audio_frames) - elif not prompt_token_ids: - logger.warning("STT: empty prompt_token_ids, returning EOT") - return [eot] - - tokens = self.transcriber.greedy_decode_tokens(audio_features, prompt_token_ids) - - if self._model_type == "qwen3_asr": - # Extract tokens between and <|im_end|> - tokens = self._extract_asr_text_tokens(tokens) - - # Always end with EOT so vLLM marks the request as finished - tokens.append(eot) - return tokens - - def _extract_asr_text_tokens(self, tokens: list[int]) -> list[int]: - """Extract content tokens between and <|im_end|>. - - Qwen3-ASR outputs: ``language {lang}{text}<|im_end|>`` - We extract only the ``{text}`` portion. - """ - if self._asr_text_token_id is None: - tok = self.transcriber.tokenizer - self._asr_text_token_id = tok.encode( - "", add_special_tokens=False - )[0] - self._im_end_token_id = tok.encode("<|im_end|>", add_special_tokens=False)[ - 0 - ] - asr_text_token = self._asr_text_token_id - im_end_token = self._im_end_token_id - - # Find last tag - start = -1 - for i, t in enumerate(tokens): - if t == asr_text_token: - start = i + 1 - - if start < 0 or start >= len(tokens): - return tokens # No found; return as-is - - # Find first <|im_end|> after - end = len(tokens) - for i in range(start, len(tokens)): - if tokens[i] == im_end_token: - end = i - break - - return tokens[start:end] - - class MetalModelRunner: """Model runner for MLX-based inference on Metal.