Skip to content
157 changes: 157 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,34 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
)


def _schedule_cached_requests(
req_ids: list[str],
num_scheduled_tokens: dict[str, int],
new_token_ids: list[list[int]],
num_computed_tokens: list[int],
num_output_tokens: list[int],
) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData(
req_ids=req_ids,
resumed_req_ids=set(),
new_token_ids=new_token_ids,
all_token_ids={},
new_block_ids=[None] * len(req_ids),
num_computed_tokens=num_computed_tokens,
num_output_tokens=num_output_tokens,
),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)


def _is_req_scheduled(model_runner, req_id: str) -> bool:
return req_id in model_runner.input_batch.req_id_to_index

Expand Down Expand Up @@ -510,6 +538,135 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
assert not _is_req_scheduled(model_runner, req_ids[1])


def test_update_states_pp_non_async_multi_request_keeps_token_buffers_consistent(
model_runner, model_runner_2, dist_init, monkeypatch
):
req_ids = ["req_0", "req_1"]
non_last_runner = model_runner
last_runner = model_runner_2
non_last_runner.use_async_scheduling = False
last_runner.use_async_scheduling = False

# Both ranks start from the same request set.
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=False, world_size=2),
)
non_last_runner._update_states(_schedule_new_request(*req_ids))
last_runner._update_states(_schedule_new_request(*req_ids))

sampled_by_last_rank = {req_ids[0]: 101, req_ids[1]: 201}
# Emulate last-rank bookkeeping result from previous step:
# sampled tokens already cached in CPU token buffers.
for req_id, token_id in sampled_by_last_rank.items():
req_index = last_runner.input_batch.req_id_to_index[req_id]
start_idx = int(last_runner.input_batch.num_tokens_no_spec[req_index])
end_idx = start_idx + 1
last_runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
last_runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
last_runner.input_batch.num_tokens_no_spec[req_index] = end_idx
last_runner.requests[req_id].output_token_ids.append(token_id)

scheduler_output = _schedule_cached_requests(
req_ids=req_ids,
num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
new_token_ids=[[101], [201]],
num_computed_tokens=[3, 3], # prompt tokens only
num_output_tokens=[1, 1],
)
# non-last rank appends new_token_ids in _update_states.
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=False, world_size=2),
)
non_last_runner._update_states(scheduler_output)
# last rank should keep its already-bookkept CPU buffers unchanged.
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=True, world_size=2),
)
last_runner._update_states(scheduler_output)

# Verify consistency between PP ranks after _update_states.
for req_id in req_ids:
non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
last_idx = last_runner.input_batch.req_id_to_index[req_id]
non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
assert non_last_len == last_len
assert (
non_last_runner.input_batch.token_ids_cpu[
non_last_idx, :non_last_len
].tolist()
== last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
)


def test_update_states_pp_async_multi_request_keeps_rank_state_consistent(
model_runner, model_runner_2, dist_init, monkeypatch
):
req_ids = ["req_0", "req_1"]
non_last_runner = model_runner
last_runner = model_runner_2
non_last_runner.use_async_scheduling = True
last_runner.use_async_scheduling = True

# Both ranks start from the same request set.
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=False, world_size=2),
)
non_last_runner._update_states(_schedule_new_request(*req_ids))
last_runner._update_states(_schedule_new_request(*req_ids))

# Simulate async previous-step sampled tokens known on both ranks.
# non-last rank may receive them via PP communication; last rank has
# them from local sampling/bookkeeping.
sampled_by_last_rank = {req_ids[0]: 111, req_ids[1]: 222}
for runner in (non_last_runner, last_runner):
for req_id, token_id in sampled_by_last_rank.items():
req_index = runner.input_batch.req_id_to_index[req_id]
start_idx = int(runner.input_batch.num_tokens_no_spec[req_index])
end_idx = start_idx + 1
runner.input_batch.token_ids_cpu[req_index, start_idx:end_idx] = [token_id]
runner.input_batch.is_token_ids[req_index, start_idx:end_idx] = True
runner.input_batch.num_tokens_no_spec[req_index] = end_idx
runner.requests[req_id].output_token_ids.append(token_id)

scheduler_output = _schedule_cached_requests(
req_ids=req_ids,
num_scheduled_tokens={req_ids[0]: 1, req_ids[1]: 1},
new_token_ids=[],
num_computed_tokens=[4, 4],
num_output_tokens=[1, 1],
)
# non-last rank: async PP branch (new_token_ids empty).
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=False, world_size=2),
)
non_last_runner._update_states(scheduler_output)
# last rank: keep already-bookkept state aligned with scheduler view.
monkeypatch.setattr(
"vllm.v1.worker.gpu_model_runner.get_pp_group",
lambda: SimpleNamespace(is_last_rank=True, world_size=2),
)
last_runner._update_states(scheduler_output)

for req_id in req_ids:
non_last_idx = non_last_runner.input_batch.req_id_to_index[req_id]
last_idx = last_runner.input_batch.req_id_to_index[req_id]
non_last_len = int(non_last_runner.input_batch.num_tokens_no_spec[non_last_idx])
last_len = int(last_runner.input_batch.num_tokens_no_spec[last_idx])
assert non_last_len == last_len
assert (
non_last_runner.input_batch.token_ids_cpu[
non_last_idx, :non_last_len
].tolist()
== last_runner.input_batch.token_ids_cpu[last_idx, :last_len].tolist()
)


def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order.
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def update_req_spec_token_ids(
start_index = self.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
self.is_token_ids[req_index, start_index:end_token_index] = True
cur_spec_token_ids.extend(spec_token_ids)

def remove_request(self, req_id: str) -> int | None:
Expand Down
28 changes: 21 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,13 +1347,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_token_index:end_token_index
] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
start_token_index = self.input_batch.num_tokens_no_spec[req_index]
# For chunked prefill, num_computed_tokens may less
# than num_tokens_no_spec.
# Async scheduled PP: no new_token_ids, advance num_tokens_no_spec
# according to num_computed_tokens.
end_token_index = max(
start_token_index,
num_computed_tokens + len(new_token_ids),
)
if end_token_index > start_token_index:
if new_token_ids:
# Add new_token_ids to token_ids_cpu.
num_new_tokens = end_token_index - start_token_index
tokens_to_append = new_token_ids[-num_new_tokens:]
self.input_batch.token_ids_cpu[
req_index, start_token_index:end_token_index
] = tokens_to_append
self.input_batch.is_token_ids[
req_index, start_token_index:end_token_index
] = True
self.input_batch.num_tokens_no_spec[req_index] = end_token_index

# Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
Expand Down
Loading