84 changes: 51 additions & 33 deletions vllm_metal/v1/model_runner.py
@@ -1427,10 +1427,11 @@ def _unified_prefill_decode_paged(
             tuple[
                 str,
                 list[int],
+                list[int],
                 SamplingParams,
                 list[int],
                 torch.Generator | None,
-                int | None,
+                int,
                 int,
             ]
         ],
@@ -1444,10 +1445,10 @@ def _unified_prefill_decode_paged(
 
         Args:
             prefill_reqs: list of
-                ``(req_id, token_ids, sampling_params, block_ids,
-                generator, prompt_len, start_pos)`` — prefill requests.
-                ``start_pos`` is the RoPE offset / KV slot start (0 for
-                fresh prefill, >0 for continuation chunks).
+                ``(req_id, chunk_token_ids, full_token_ids, sampling_params,
+                block_ids, generator, prompt_len, start_pos)`` — prefill
+                requests. ``start_pos`` is the RoPE offset / KV slot start
+                (0 for fresh prefill, >0 for continuation chunks).
             decode_reqs: list of ``(req_id, RequestState)`` — decode requests.
 
         Returns:
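
To make the widened tuple concrete, here is a minimal sketch of the new eight-field prefill entry; all names and values are invented for illustration, and None stands in for the real SamplingParams and generator objects:

# Hypothetical entries in the new layout:
# (req_id, chunk_token_ids, full_token_ids, sampling_params,
#  block_ids, generator, prompt_len, start_pos)
full_token_ids = [11, 12, 13, 14, 15, 16]  # the whole prompt
block_ids = [0, 1]                         # stand-in KV-cache block ids

# Fresh prefill: the chunk is the entire prompt, start_pos == 0.
fresh = ("req-0", full_token_ids, full_token_ids,
         None, block_ids, None, len(full_token_ids), 0)

# Continuation chunk: only tokens [4:6] run this step; start_pos == 4
# is the RoPE offset / first KV slot for the chunk.
cont = ("req-0", full_token_ids[4:6], full_token_ids,
        None, block_ids, None, len(full_token_ids), 4)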
@@ -1471,8 +1472,8 @@ def _unified_prefill_decode_paged(
         all_token_ids.extend(last_tokens)
 
         # Prefill: tokens per request
-        for _, token_ids, _, _, _, _, _ in prefill_reqs:
-            all_token_ids.extend(token_ids)
+        for _, chunk_token_ids, _, _, _, _, _, _ in prefill_reqs:
+            all_token_ids.extend(chunk_token_ids)
 
         # ---- build metadata for prepare_unified ----
         decode_info: list[tuple[list[int], int]] = []
@@ -1481,8 +1482,8 @@ def _unified_prefill_decode_paged(
             decode_info.append((state.block_ids, seq_len))
 
         prefill_info: list[tuple[list[int], int, int]] = []
-        for _, token_ids, _, block_ids, _, _, start_pos in prefill_reqs:
-            prefill_info.append((block_ids, len(token_ids), start_pos))
+        for _, chunk_token_ids, _, _, block_ids, _, _, start_pos in prefill_reqs:
+            prefill_info.append((block_ids, len(chunk_token_ids), start_pos))
 
         prepare_unified(decode_info, prefill_info, self._paged_block_size)
 
@@ -1499,8 +1500,8 @@ def _unified_prefill_decode_paged(
         cu_seqlens: list[int] = [0]
         for _ in decode_reqs:
             cu_seqlens.append(cu_seqlens[-1] + 1)
-        for _, token_ids, _, _, _, _, _ in prefill_reqs:
-            cu_seqlens.append(cu_seqlens[-1] + len(token_ids))
+        for _, chunk_token_ids, _, _, _, _, _, _ in prefill_reqs:
+            cu_seqlens.append(cu_seqlens[-1] + len(chunk_token_ids))
 
         # ---- sample decode tokens ----
         decode_next_tokens: list[int] = []
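
For intuition on the packing: each decode request contributes exactly one slot to cu_seqlens and each prefill request contributes its chunk length, so the last-token logit index used by the sampler in the next hunk falls out directly. A minimal sketch with invented sizes:

# Two decode requests plus prefill chunks of 3 and 5 tokens.
num_decode = 2
prefill_chunk_lens = [3, 5]

cu_seqlens = [0]
for _ in range(num_decode):       # one token per decode request
    cu_seqlens.append(cu_seqlens[-1] + 1)
for n in prefill_chunk_lens:      # chunk length per prefill request
    cu_seqlens.append(cu_seqlens[-1] + n)
assert cu_seqlens == [0, 1, 2, 5, 10]

# Index of the last logit of prefill request j, as used for sampling:
for j, _ in enumerate(prefill_chunk_lens):
    last_idx = cu_seqlens[num_decode + j + 1] - 1  # 4, then 9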
@@ -1566,18 +1567,19 @@ def _unified_prefill_decode_paged(
         prefill_next_tokens: list[int] = []
         for j, (
             req_id,
-            token_ids,
+            chunk_token_ids,
+            full_token_ids,
             sampling_params,
             _block_ids,
             generator,
             prompt_len,
-            _start_pos,
+            start_pos,
         ) in enumerate(prefill_reqs):
             last_idx = cu_seqlens[num_decode + j + 1] - 1
             last_logits = logits[:, last_idx : last_idx + 1, :]
 
-            if prompt_len is None:
-                prompt_len = len(token_ids)
+            prompt_token_ids = full_token_ids[:prompt_len]
+            output_token_ids = full_token_ids[prompt_len:]
 
             is_greedy = sampling_params.temperature < 1e-5
             needs_advanced = (
@@ -1600,14 +1602,14 @@ def _unified_prefill_decode_paged(
                 generators = {} if generator is None else {0: generator}
                 metadata = self._make_sampling_metadata(
                     [sampling_params],
-                    [token_ids[:prompt_len]],
-                    [token_ids[prompt_len:]],
+                    [prompt_token_ids],
+                    [output_token_ids],
                     generators=generators,
                 )
                 output = self._sampler.forward(logits_torch, metadata)
                 next_token = int(output.sampled_token_ids[0, 0].item())
 
-            self._paged_request_seq_lens[req_id] = len(token_ids)
+            self._paged_request_seq_lens[req_id] = start_pos + len(chunk_token_ids)
             prefill_next_tokens.append(next_token)
 
         return prefill_next_tokens, decode_next_tokens
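
The bookkeeping change near the end is the point of the rename: for a continuation chunk, len(token_ids) measured only the current chunk, while the sequence length held in the KV cache also includes everything before start_pos. A worked example with invented numbers:

# Hypothetical continuation chunk: 8 tokens already in the KV cache,
# 4 more processed this step.
start_pos = 8
chunk_token_ids = [101, 102, 103, 104]

seq_len = start_pos + len(chunk_token_ids)
assert seq_len == 12  # the old len(chunk) form would have recorded 4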
@@ -1642,14 +1644,16 @@ def execute_model(
         cached_reqs = scheduler_output.scheduled_cached_reqs
 
         # Paged-attention entries collected for the single unified forward.
-        # Each prefill entry: (output_idx, req_id, token_ids, sampling_params,
-        #                      block_ids, generator, is_new, is_intermediate,
-        #                      prompt_len, start_pos)
+        # Each prefill entry: (output_idx, req_id, chunk_token_ids,
+        #                      full_token_ids, sampling_params, block_ids,
+        #                      generator, is_new, is_intermediate, prompt_len,
+        #                      start_pos)
         paged_prefill_entries: list[
             tuple[
                 int,
                 str,
                 list[int],
+                list[int],
                 SamplingParams,
                 list[int],
                 torch.Generator | None,
@@ -1691,6 +1695,7 @@ def execute_model(
                     output_idx,
                     req_id,
                     token_ids[computed_tokens:cur_len],
+                    list(token_ids),
                     sampling_params,
                     sched_block_ids,
                     generator,
@@ -1802,6 +1807,7 @@ def execute_model(
                     output_idx,
                     req_id,
                     state.token_ids[computed:target_len],
+                    list(state.token_ids),
                     state.sampling_params,
                     state.block_ids,
                     state.generator,
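
Both entry-building call sites now pass the chunk slice and a full copy of the token list side by side; the slice covers only the tokens scheduled this step. A small sketch with invented values:

# 10-token prompt; 6 tokens already computed, 4 scheduled now.
token_ids = list(range(100, 110))
computed_tokens, cur_len = 6, 10

chunk = token_ids[computed_tokens:cur_len]
assert chunk == [106, 107, 108, 109]
entry_tail = (chunk, list(token_ids))  # chunk plus defensive full copy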
@@ -1843,14 +1849,27 @@ def execute_model(
             prefill_pack = [
                 (
                     rid,
-                    tids,
+                    chunk_tids,
+                    full_tids,
                     sp,
                     bids,
                     gen,
-                    prompt_len if not is_intermediate else None,
+                    prompt_len,
                     start_pos,
                 )
-                for _, rid, tids, sp, bids, gen, _is_new, is_intermediate, prompt_len, start_pos in paged_prefill_entries
+                for (
+                    _,
+                    rid,
+                    chunk_tids,
+                    full_tids,
+                    sp,
+                    bids,
+                    gen,
+                    _is_new,
+                    _is_intermediate,
+                    prompt_len,
+                    start_pos,
+                ) in paged_prefill_entries
             ]
             prefill_tokens, decode_tokens = self._unified_prefill_decode_paged(
                 prefill_pack, paged_decode_reqs
@@ -1860,7 +1879,8 @@
             for i, (
                 idx,
                 rid,
-                tids,
+                _chunk_tids,
+                full_tids,
                 sp,
                 bids,
                 gen,
@@ -1875,22 +1895,20 @@
                     # KV cache populated; discard sampled token
                     sampled_tokens[idx] = []
                 elif is_new:
-                    assert _start_pos == 0, (
-                        "new complete prefill with start_pos > 0 not supported "
-                        "(prefix caching not yet implemented in unified path)"
-                    )
                     sampled_tokens[idx] = [nt]
                     self._request_states[rid] = RequestState(
-                        token_ids=list(tids) + [nt],
-                        prompt_len=len(tids),
+                        token_ids=list(full_tids) + [nt],
+                        prompt_len=_prompt_len,
                         cache=[],
                         sampling_params=sp,
                         generator=gen,
-                        generated_tokens=1,
+                        generated_tokens=len(full_tids) + 1 - _prompt_len,
                         block_ids=bids,
                     )
                     if self._rust_state_manager is not None:
-                        self._rust_state_manager.add_request(rid, list(tids) + [nt])
+                        self._rust_state_manager.add_request(
+                            rid, list(full_tids) + [nt]
+                        )
                 else:
                     # Cached last chunk — append token to existing state
                     sampled_tokens[idx] = [nt]
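
Dropping the start_pos == 0 assert fits the new generated_tokens arithmetic: a new prefill whose full_tids already extend past the prompt (one resumed after preemption, for instance) now gets a correct count rather than a hard failure. A worked example with invented numbers:

# Hypothetical resumed request: 6-token prompt, 2 tokens generated
# before this step, one more sampled now.
full_tids = [1, 2, 3, 4, 5, 6, 50, 51]
_prompt_len = 6
nt = 52

token_ids = list(full_tids) + [nt]
generated_tokens = len(full_tids) + 1 - _prompt_len
assert generated_tokens == 3  # the old hard-coded 1 would under-count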