Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tests/v1/e2e/test_async_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@

default_params = dict(
temperature=0.0, # greedy
max_tokens=23,
min_tokens=18,
max_tokens=30,
# spec decoding currently doesn't support min_tokens
# min_tokens=28,
)


Expand Down Expand Up @@ -86,7 +87,7 @@ def test_without_spec_decoding(
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
Expand All @@ -100,9 +101,16 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50}

struct_outputs = StructuredOutputsParams(json=sample_json_schema)

test_sampling_params = [
dict(),
dict(logprobs=2),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
),
]

# test_preemption, executor, async_scheduling,
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/core/sched/async_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
class AsyncScheduler(Scheduler):
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output)
has_structured_output_requests = False
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
has_structured_output_requests |= request.use_structured_output
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
Expand All @@ -33,6 +35,7 @@ def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens

scheduler_output.has_structured_output_requests = has_structured_output_requests
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
Expand Down
21 changes: 20 additions & 1 deletion vllm/v1/core/sched/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,26 @@ def update_from_output(

@abstractmethod
def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None:
"""Update the draft token ids for the scheduled requests."""
"""Update requests with newly generated draft token ids, applying
structured output grammar validation if needed.

Args:
draft_token_ids: The input draft token ids for each request.
"""
raise NotImplementedError

@abstractmethod
def update_draft_token_ids_in_output(
self, draft_token_ids: "DraftTokenIds", scheduler_output: "SchedulerOutput"
) -> None:
"""Update scheduler output with newly generated draft token ids, applying
structured output grammar validation if needed.

Args:
draft_token_ids: The input draft token ids for each request.
scheduler_output: Update the given scheduler_output
with the corresponding draft token ids.
"""
raise NotImplementedError

@abstractmethod
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,17 @@ class SchedulerOutput:
# Only used for v2 model runner.
preempted_req_ids: set[str] | None = None

# Whether any of the scheduled requests use structured output.
# Set only in async scheduling case.
has_structured_output_requests: bool = False

# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False

# Used for adjusting acceptance rate calculation.
num_invalid_spec_tokens: dict[str, int] | None = None

# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None

Expand Down
61 changes: 54 additions & 7 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,8 @@ def update_from_output(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted,
num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
request_id=req_id,
)

stopped = False
Expand Down Expand Up @@ -1168,7 +1170,13 @@ def update_from_output(
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)

if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
Expand Down Expand Up @@ -1330,11 +1338,46 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None:
# Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids
)
else:
request.spec_token_ids = spec_token_ids
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr]
request.spec_token_ids = spec_token_ids

def update_draft_token_ids_in_output(
self, draft_token_ids: DraftTokenIds, scheduler_output: SchedulerOutput
) -> None:
num_invalid_spec_tokens: dict[str, int] = {}

sched_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
):
request = self.requests.get(req_id)
if request is None or request.is_finished():
# The request may have been finished. Skip.
continue

placeholder_spec_tokens = sched_spec_tokens.get(req_id)
if not placeholder_spec_tokens:
continue

orig_num_spec_tokens = len(placeholder_spec_tokens)
# Trim drafts to scheduled number of spec tokens
# (needed for chunked prefill case for example).
del spec_token_ids[orig_num_spec_tokens:]
# Filter out spec tokens which do not adhere to the grammar.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids)
# Pad to original number of spec tokens.
num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids)
if num_invalid_tokens:
spec_token_ids.extend([-1] * num_invalid_tokens)
num_invalid_spec_tokens[req_id] = num_invalid_tokens

sched_spec_tokens[req_id] = spec_token_ids

scheduler_output.num_invalid_spec_tokens = num_invalid_spec_tokens

def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
Expand Down Expand Up @@ -1513,11 +1556,15 @@ def make_spec_decoding_stats(
spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int,
num_accepted_tokens: int,
num_invalid_spec_tokens: dict[str, int] | None,
request_id: str,
) -> SpecDecodingStats | None:
if not self.log_stats:
if not self.log_stats or not num_draft_tokens:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
if num_invalid_spec_tokens:
num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0)
spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens
)
Expand Down
12 changes: 12 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,18 @@ def step_with_batch_queue(
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# If we are doing speculative decoding with structured output,
# we need to get the draft token ids from the prior step before
# we can compute the grammar bitmask for the deferred request.
if self.use_spec_decode:
draft_token_ids = self.model_executor.take_draft_token_ids()
assert draft_token_ids is not None
# Update the draft token ids in the scheduler output to
# filter out the invalid spec tokens, which will be padded
# with -1 and skipped by the grammar bitmask computation.
self.scheduler.update_draft_token_ids_in_output(
draft_token_ids, deferred_scheduler_output
)
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
Expand Down
3 changes: 1 addition & 2 deletions vllm/v1/engine/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,11 @@ def _validate_supported_sampling_params(
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
"penalties or bad words in sampling parameters."
)

def _validate_params(
Expand Down
Loading