7 changes: 7 additions & 0 deletions tests/v1/core/test_async_scheduler.py
@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
     requests = create_requests(num_requests=2, max_tokens=max_tokens)
     req0, req1 = requests
 
+    expected_total_num_scheduled_tokens = 0
     sched_outputs: deque[SchedulerOutput] = deque()
     scheduler.add_request(req0)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1
 
     scheduler.add_request(req1)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1
 
+    total_num_scheduled_tokens = 0
     while sched_outputs:
         sched_output = sched_outputs.popleft()
+        total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)
 
@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert scheduler.get_num_unfinished_requests() == 0
     assert req0.num_output_tokens == max_tokens
     assert req1.num_output_tokens == max_tokens
+    # Ensure we aren't scheduling more tokens than necessary.
+    assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens
 
 
 def test_abort():
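The new assertion checks a simple piece of accounting: under async scheduling, each request schedules its full prompt plus one token per decode step, except for the final output token, which is sampled but never fed back as an input position. A request should therefore contribute exactly num_prompt_tokens + max_tokens - 1 scheduled tokens in total. A minimal sketch of that arithmetic, separate from the test itself (the helper name is illustrative, not part of the diff):

def expected_scheduled_tokens(num_prompt_tokens: int, max_tokens: int) -> int:
    # The prompt is scheduled once; every later decode step schedules one new
    # input token, except the final output token, which is sampled but never
    # re-scheduled as input.
    return num_prompt_tokens + max_tokens - 1


# Example: a 16-token prompt generating 4 output tokens accounts for
# 16 + 4 - 1 = 19 scheduled tokens across all scheduler steps.
assert expected_scheduled_tokens(16, 4) == 19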
1 change: 0 additions & 1 deletion tests/v1/e2e/test_spec_decode.py
@@ -155,7 +155,6 @@ def test_suffix_decoding_acceptance(
     )
 
     # Run several times and check that the accepted tokens increase.
-    spec_llm.chat(test_prompts, sampling_config)
     num_draft = []
     num_accept = []
     for i in range(10):  # Run multiple times to warm up the cache.
10 changes: 7 additions & 3 deletions vllm/v1/core/sched/scheduler.py
@@ -217,10 +217,14 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
 
-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
+            # Make sure the input position does not exceed the max model len or
+            # request's max_tokens.
+            # This is necessary when using spec decoding and/or async scheduling.
+            max_total_tokens = min(
+                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+            )
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
             )
 
             # Schedule encoder inputs.
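The scheduler change is a tighter clamp on how far ahead a request may be scheduled: a request never needs more total positions than its prompt plus its requested max_tokens, and never more than the model's context window. A self-contained sketch of the same arithmetic outside the Scheduler class (parameter names mirror the diff, but the standalone function is illustrative only):

def clamp_num_new_tokens(
    num_new_tokens: int,
    num_prompt_tokens: int,
    max_tokens: int,
    num_computed_tokens: int,
    max_model_len: int,
) -> int:
    # A request never needs more positions than its prompt plus its requested
    # output tokens, nor more than the model's context window.
    max_total_tokens = min(num_prompt_tokens + max_tokens, max_model_len)
    # The "- 1" accounts for the final output token, which is sampled but
    # never scheduled as an input position.
    return min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)


# Example: a 16-token prompt with max_tokens=4 whose 16 prompt tokens are
# already computed may schedule at most 3 more positions, even if speculative
# decoding or async look-ahead would otherwise ask for more.
assert clamp_num_new_tokens(8, 16, 4, 16, 2048) == 3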