29 changes: 18 additions & 11 deletions tests/v1/core/test_async_scheduler.py
@@ -153,7 +153,6 @@ def test_prefix_caching_for_prefill_dedup():
same_prompt=True,
block_size=BLOCK_SIZE,
)
requests_copy = requests.copy()

# Two requests with the same prompt.
req0 = requests.pop(0)
@@ -167,26 +166,31 @@ def test_prefix_caching_for_prefill_dedup():
# Make sure prefix caching de-duplicates the prompts in the same step,
# so all the blocks except the last are shared between the two requests.
assert len(sched_output.num_scheduled_tokens) == 2
num_blocks = num_prompt_tokens // BLOCK_SIZE
assert req0.num_cached_tokens == 0
assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE
assert sched_output.num_scheduled_tokens[req0.request_id] == num_prompt_tokens
assert (
sched_output.num_scheduled_tokens[req1.request_id]
== num_prompt_tokens % BLOCK_SIZE
)

sched_outputs.append(scheduler.schedule())
while sched_outputs:
added_req = None
if requests:
scheduler.add_request(requests.pop(0))
added_req = requests.pop(0)
scheduler.add_request(added_req)
sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output)
if added_req:
assert (
sched_output.num_scheduled_tokens[added_req.request_id]
== num_prompt_tokens % BLOCK_SIZE
)

# Other requests scheduled after the two requests should also get
# prefix cache hit.
assert scheduler.get_num_unfinished_requests() == 0
for req in requests_copy[1:]:
assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
@@ -243,12 +247,15 @@ def test_prefix_caching_for_multi_turn():
# Schedule the next-turn requests.
for req in next_turn_requests:
scheduler.add_request(req)
sched_outputs.append(scheduler.schedule())
sched_output = scheduler.schedule()
sched_outputs.append(sched_output)

# Make sure the next-turn requests get prefix cache hit by the previous
# requests.
for req in next_turn_requests:
assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE
assert sched_output.num_scheduled_tokens[req.request_id] == (
req.num_prompt_tokens % BLOCK_SIZE
)


def test_abort_request_when_structured_output_fsm_cannot_advance():
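A minimal sketch of the block-boundary arithmetic behind the assertions above, using assumed values (BLOCK_SIZE = 16 and num_prompt_tokens = 53 are illustrative, not the fixture's constants): a deduplicated prompt reuses every full cached block, so only the remainder past the last block boundary is scheduled.

# Illustrative values, not the test fixture's constants.
BLOCK_SIZE = 16
num_prompt_tokens = 53

# Full blocks are shared via the prefix cache once an identical prompt
# has been scheduled in the same step.
num_cached_tokens = (num_prompt_tokens // BLOCK_SIZE) * BLOCK_SIZE  # 48

# Only the tokens past the last full block boundary are scheduled.
num_scheduled_tokens = num_prompt_tokens % BLOCK_SIZE  # 5

assert num_cached_tokens + num_scheduled_tokens == num_prompt_tokens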
14 changes: 10 additions & 4 deletions tests/v1/engine/test_output_processor.py
@@ -84,6 +84,7 @@ def test_incremental_detokenization(

engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests],
)

@@ -506,6 +507,7 @@ def test_logprobs_processor(

engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
@@ -691,6 +693,7 @@ def test_stop_token(

engine_core = MockEngineCore(
tokens_list=[generation_tokens],
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=sampling_params.eos_token_id,
@@ -794,6 +797,7 @@ def test_stop_string(

engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
@@ -917,6 +921,7 @@ def test_iteration_stats(dummy_test_vectors):

engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests],
)

@@ -927,7 +932,7 @@
inactive_request = requests[num_active]

# First iteration has 2 prefills.
outputs = engine_core.get_outputs()[:num_active]
outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = sum(
@@ -941,7 +946,7 @@
assert iteration_stats.num_generation_tokens == num_active

# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)

@@ -951,7 +956,7 @@
# Add a new request - prefill and 2 decodes in this step.
output_processor.add_request(inactive_request, None)
num_active += 1
outputs = engine_core.get_outputs()[:num_active]
outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
@@ -960,7 +965,7 @@
assert iteration_stats.num_generation_tokens == num_active

# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)

@@ -1003,6 +1008,7 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):

engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests],
)

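A self-contained sketch of the iteration-stats accounting these tests check, with assumed prompt lengths standing in for dummy_test_vectors.prompt_tokens: on a prefill step, prompt tokens sum over the active requests, and each active request contributes exactly one sampled token.

# Assumed prompt lengths; the real values come from dummy_test_vectors.
prompt_tokens = [[0] * 7, [0] * 12, [0] * 5]
num_active = 2

# First iteration: one prefill per active request.
total_prompt_tokens = sum(len(p) for p in prompt_tokens[:num_active])
# Each active request samples exactly one token per step.
num_generation_tokens = num_active

assert total_prompt_tokens == 19  # 7 + 12
assert num_generation_tokens == 2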
33 changes: 27 additions & 6 deletions tests/v1/engine/utils.py
@@ -11,6 +11,7 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.metrics.stats import PrefillStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
@@ -330,6 +331,7 @@ class MockEngineCore:
def __init__(
self,
tokens_list: list[list[int]],
prompts_list: list[list[int]],
# For each request, for each sampled token offset,
# a tuple of
# (list of topk token ids, list of sample logprob vals, rank)
@@ -346,12 +348,13 @@ def __init__(
) -> None:
self.num_requests = len(tokens_list)
self.tokens_list = tokens_list
self.current_idx = 0
self.prompts_list = prompts_list
self.generated_logprobs_raw = generated_logprobs_raw
self.do_logprobs = generated_logprobs_raw is not None
self.prompt_logprobs_raw = prompt_logprobs_raw
self.do_prompt_logprobs = prompt_logprobs_raw is not None
self.request_finished = [False for _ in range(self.num_requests)]
self.request_token_idx = [0 for _ in range(self.num_requests)]
self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids
self.request_ids = (
@@ -360,14 +363,18 @@
else [f"request-{i}" for i in range(self.num_requests)]
)

def get_outputs(self) -> list[EngineCoreOutput]:
def get_outputs(self, num_active: int = -1) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs
do_prompt_logprobs = self.do_prompt_logprobs
token_idx = self.current_idx

outputs = []
for req_idx, token_ids in enumerate(self.tokens_list):
for req_idx, (token_ids, prompt_token_ids) in enumerate(
zip(self.tokens_list, self.prompts_list)
):
if num_active != -1 and req_idx >= num_active:
break
if not self.request_finished[req_idx]:
token_idx = self.request_token_idx[req_idx]
if do_logprobs:
assert self.generated_logprobs_raw is not None
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
@@ -381,19 +388,32 @@
else:
logprobs = None
if do_prompt_logprobs:
if self.current_idx == 0:
if token_idx == 0:
assert self.prompt_logprobs_raw is not None
prompt_logprobs = self.prompt_logprobs_raw[req_idx]
else:
prompt_logprobs = None
else:
prompt_logprobs = None

# Add prefill_stats on first output (prefill) for this request
if token_idx == 0:
prefill_stats = PrefillStats()
prefill_stats.set(
num_prompt_tokens=len(prompt_token_ids),
num_local_cached_tokens=0,
num_external_cached_tokens=0,
)
else:
prefill_stats = None

new_token_id = token_ids[token_idx]
output = EngineCoreOutput(
request_id=self.request_ids[req_idx],
new_token_ids=[new_token_id],
new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs,
prefill_stats=prefill_stats,
)
if token_idx == len(token_ids) - 1:
output.finish_reason = FinishReason.LENGTH
@@ -407,5 +427,6 @@ def get_outputs(self) -> list[EngineCoreOutput]:
self.request_finished[req_idx] = True
outputs.append(output)

self.current_idx += 1
self.request_token_idx[req_idx] += 1

return outputs
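A self-contained sketch (illustrative names, not the real MockEngineCore) of why the shared current_idx becomes a per-request request_token_idx: once get_outputs(num_active) can exclude later requests from early steps, each request needs its own cursor so it resumes from its own offset rather than from a global step counter.

tokens_list = [[1, 2, 3], [4, 5, 6, 7]]
request_token_idx = [0] * len(tokens_list)   # was a single shared current_idx
request_finished = [False] * len(tokens_list)

def get_outputs(num_active: int = -1) -> list[int]:
    outputs = []
    for req_idx, token_ids in enumerate(tokens_list):
        if num_active != -1 and req_idx >= num_active:
            break
        if request_finished[req_idx]:
            continue
        token_idx = request_token_idx[req_idx]
        outputs.append(token_ids[token_idx])
        if token_idx == len(token_ids) - 1:
            request_finished[req_idx] = True
        request_token_idx[req_idx] += 1
    return outputs

print(get_outputs(1))  # only request 0 is active: [1]
print(get_outputs())   # both active, each resumes at its own offset: [2, 4]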