diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py
index 17ee86dc..4ace8975 100644
--- a/vllm_metal/v1/model_runner.py
+++ b/vllm_metal/v1/model_runner.py
@@ -1427,10 +1427,11 @@ def _unified_prefill_decode_paged(
             tuple[
                 str,
                 list[int],
+                list[int],
                 SamplingParams,
                 list[int],
                 torch.Generator | None,
-                int | None,
+                int,
                 int,
             ]
         ],
@@ -1444,10 +1445,10 @@ def _unified_prefill_decode_paged(
 
         Args:
             prefill_reqs: list of
-                ``(req_id, token_ids, sampling_params, block_ids,
-                generator, prompt_len, start_pos)`` — prefill requests.
-                ``start_pos`` is the RoPE offset / KV slot start (0 for
-                fresh prefill, >0 for continuation chunks).
+                ``(req_id, chunk_token_ids, full_token_ids, sampling_params,
+                block_ids, generator, prompt_len, start_pos)`` — prefill
+                requests. ``start_pos`` is the RoPE offset / KV slot start
+                (0 for fresh prefill, >0 for continuation chunks).
             decode_reqs: list of ``(req_id, RequestState)`` — decode requests.
 
         Returns:
@@ -1471,8 +1472,8 @@ def _unified_prefill_decode_paged(
         all_token_ids.extend(last_tokens)
 
         # Prefill: tokens per request
-        for _, token_ids, _, _, _, _, _ in prefill_reqs:
-            all_token_ids.extend(token_ids)
+        for _, chunk_token_ids, _, _, _, _, _, _ in prefill_reqs:
+            all_token_ids.extend(chunk_token_ids)
 
         # ---- build metadata for prepare_unified ----
         decode_info: list[tuple[list[int], int]] = []
@@ -1481,8 +1482,8 @@ def _unified_prefill_decode_paged(
             decode_info.append((state.block_ids, seq_len))
 
         prefill_info: list[tuple[list[int], int, int]] = []
-        for _, token_ids, _, block_ids, _, _, start_pos in prefill_reqs:
-            prefill_info.append((block_ids, len(token_ids), start_pos))
+        for _, chunk_token_ids, _, _, block_ids, _, _, start_pos in prefill_reqs:
+            prefill_info.append((block_ids, len(chunk_token_ids), start_pos))
 
         prepare_unified(decode_info, prefill_info, self._paged_block_size)
 
@@ -1499,8 +1500,8 @@ def _unified_prefill_decode_paged(
         cu_seqlens: list[int] = [0]
         for _ in decode_reqs:
             cu_seqlens.append(cu_seqlens[-1] + 1)
-        for _, token_ids, _, _, _, _, _ in prefill_reqs:
-            cu_seqlens.append(cu_seqlens[-1] + len(token_ids))
+        for _, chunk_token_ids, _, _, _, _, _, _ in prefill_reqs:
+            cu_seqlens.append(cu_seqlens[-1] + len(chunk_token_ids))
 
         # ---- sample decode tokens ----
         decode_next_tokens: list[int] = []
@@ -1566,18 +1567,19 @@ def _unified_prefill_decode_paged(
         prefill_next_tokens: list[int] = []
         for j, (
             req_id,
-            token_ids,
+            chunk_token_ids,
+            full_token_ids,
             sampling_params,
             _block_ids,
             generator,
             prompt_len,
-            _start_pos,
+            start_pos,
         ) in enumerate(prefill_reqs):
             last_idx = cu_seqlens[num_decode + j + 1] - 1
             last_logits = logits[:, last_idx : last_idx + 1, :]
 
-            if prompt_len is None:
-                prompt_len = len(token_ids)
+            prompt_token_ids = full_token_ids[:prompt_len]
+            output_token_ids = full_token_ids[prompt_len:]
 
             is_greedy = sampling_params.temperature < 1e-5
             needs_advanced = (
@@ -1600,14 +1602,14 @@ def _unified_prefill_decode_paged(
             generators = {} if generator is None else {0: generator}
             metadata = self._make_sampling_metadata(
                 [sampling_params],
-                [token_ids[:prompt_len]],
-                [token_ids[prompt_len:]],
+                [prompt_token_ids],
+                [output_token_ids],
                 generators=generators,
             )
             output = self._sampler.forward(logits_torch, metadata)
             next_token = int(output.sampled_token_ids[0, 0].item())
 
-            self._paged_request_seq_lens[req_id] = len(token_ids)
+            self._paged_request_seq_lens[req_id] = start_pos + len(chunk_token_ids)
             prefill_next_tokens.append(next_token)
 
         return prefill_next_tokens, decode_next_tokens
@@ -1642,14 +1644,16 @@ def execute_model(
         cached_reqs = scheduler_output.scheduled_cached_reqs
 
         # Paged-attention entries collected for the single unified forward.
-        # Each prefill entry: (output_idx, req_id, token_ids, sampling_params,
-        #                      block_ids, generator, is_new, is_intermediate,
-        #                      prompt_len, start_pos)
+        # Each prefill entry: (output_idx, req_id, chunk_token_ids,
+        #                      full_token_ids, sampling_params, block_ids,
+        #                      generator, is_new, is_intermediate, prompt_len,
+        #                      start_pos)
         paged_prefill_entries: list[
             tuple[
                 int,
                 str,
                 list[int],
+                list[int],
                 SamplingParams,
                 list[int],
                 torch.Generator | None,
@@ -1691,6 +1695,7 @@ def execute_model(
                     output_idx,
                     req_id,
                     token_ids[computed_tokens:cur_len],
+                    list(token_ids),
                     sampling_params,
                     sched_block_ids,
                     generator,
@@ -1802,6 +1807,7 @@ def execute_model(
                     output_idx,
                     req_id,
                     state.token_ids[computed:target_len],
+                    list(state.token_ids),
                     state.sampling_params,
                     state.block_ids,
                     state.generator,
@@ -1843,14 +1849,27 @@ def execute_model(
             prefill_pack = [
                 (
                     rid,
-                    tids,
+                    chunk_tids,
+                    full_tids,
                     sp,
                     bids,
                     gen,
-                    prompt_len if not is_intermediate else None,
+                    prompt_len,
                     start_pos,
                 )
-                for _, rid, tids, sp, bids, gen, _is_new, is_intermediate, prompt_len, start_pos in paged_prefill_entries
+                for (
+                    _,
+                    rid,
+                    chunk_tids,
+                    full_tids,
+                    sp,
+                    bids,
+                    gen,
+                    _is_new,
+                    _is_intermediate,
+                    prompt_len,
+                    start_pos,
+                ) in paged_prefill_entries
             ]
             prefill_tokens, decode_tokens = self._unified_prefill_decode_paged(
                 prefill_pack, paged_decode_reqs
@@ -1860,7 +1879,8 @@ def execute_model(
             for i, (
                 idx,
                 rid,
-                tids,
+                _chunk_tids,
+                full_tids,
                 sp,
                 bids,
                 gen,
@@ -1875,22 +1895,20 @@ def execute_model(
                     # KV cache populated; discard sampled token
                     sampled_tokens[idx] = []
                 elif is_new:
-                    assert _start_pos == 0, (
-                        "new complete prefill with start_pos > 0 not supported "
-                        "(prefix caching not yet implemented in unified path)"
-                    )
                     sampled_tokens[idx] = [nt]
                     self._request_states[rid] = RequestState(
-                        token_ids=list(tids) + [nt],
-                        prompt_len=len(tids),
+                        token_ids=list(full_tids) + [nt],
+                        prompt_len=_prompt_len,
                         cache=[],
                         sampling_params=sp,
                         generator=gen,
-                        generated_tokens=1,
+                        generated_tokens=len(full_tids) + 1 - _prompt_len,
                         block_ids=bids,
                     )
                     if self._rust_state_manager is not None:
-                        self._rust_state_manager.add_request(rid, list(tids) + [nt])
+                        self._rust_state_manager.add_request(
+                            rid, list(full_tids) + [nt]
+                        )
                 else:
                     # Cached last chunk — append token to existing state
                     sampled_tokens[idx] = [nt]
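
Note (not part of the patch): a minimal sketch of the chunked-prefill bookkeeping this change fixes. The tuple now carries both the current chunk (`chunk_token_ids`) and the whole prompt so far (`full_token_ids`), and the per-request sequence length is recorded as `start_pos + len(chunk_token_ids)` rather than the full prompt length. The `simulate_chunked_prefill` harness and its `chunk_size` parameter below are hypothetical illustration names, not code from `model_runner.py`.

```python
# Illustrative sketch only — walks a prompt through prefill chunks the way
# the unified path does, assuming a fixed chunk size per scheduler step.

def simulate_chunked_prefill(full_token_ids: list[int], chunk_size: int):
    """Yield (chunk_token_ids, start_pos, seq_len, is_intermediate) per chunk."""
    prompt_len = len(full_token_ids)
    start_pos = 0
    while start_pos < prompt_len:
        chunk_token_ids = full_token_ids[start_pos : start_pos + chunk_size]
        # What _paged_request_seq_lens should record after the forward pass:
        # KV slots filled so far, not len(full_token_ids).
        seq_len = start_pos + len(chunk_token_ids)
        # Intermediate chunks only populate the KV cache; sampling happens
        # on the final chunk, with prompt/output split at prompt_len.
        is_intermediate = seq_len < prompt_len
        yield chunk_token_ids, start_pos, seq_len, is_intermediate
        start_pos = seq_len


for chunk, start, seq_len, mid in simulate_chunked_prefill(list(range(10)), 4):
    print(f"start_pos={start} chunk={chunk} seq_len={seq_len} intermediate={mid}")
# start_pos=0 chunk=[0, 1, 2, 3] seq_len=4 intermediate=True
# start_pos=4 chunk=[4, 5, 6, 7] seq_len=8 intermediate=True
# start_pos=8 chunk=[8, 9] seq_len=10 intermediate=False
```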