Skip to content
14 changes: 11 additions & 3 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,9 @@ def update_from_output(
):
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))

if new_token_ids and self.structured_output_manager.should_advance(request):
if new_token_ids and self.structured_output_manager.should_advance(
request, new_token_ids
):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
Expand Down Expand Up @@ -1611,7 +1613,10 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None:
continue

# Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request):
if self.structured_output_manager.should_advance(
request,
spec_token_ids,
):
metadata = request.structured_output_request
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr]
request.spec_token_ids = spec_token_ids
Expand Down Expand Up @@ -1640,7 +1645,10 @@ def update_draft_token_ids_in_output(
# (needed for chunked prefill case for example).
del spec_token_ids[orig_num_spec_tokens:]
# Filter out spec tokens which do not adhere to the grammar.
if self.structured_output_manager.should_advance(request):
if self.structured_output_manager.should_advance(
request,
spec_token_ids,
):
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids)
Expand Down
50 changes: 30 additions & 20 deletions vllm/v1/structured_output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,16 @@ def grammar_bitmask(
apply_bitmask = False
if apply_bitmask and not grammar.is_terminated():
accepted = grammar.accept_tokens(req_id, [token])
assert accepted, (token, req_id, scheduled_spec_decode_tokens)
state_advancements += 1
if not accepted:
# Speculative token violates grammar constraint.
# Dont crash: stop validating for remaining slots.
apply_bitmask = False
else:
# Only increment state if token was accepted by grammar
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
grammar.rollback(state_advancements)

bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index]
Expand Down Expand Up @@ -299,37 +303,43 @@ def should_fill_bitmask(self, request: "Request") -> bool:
return request.structured_output_request.reasoning_ended
return True

def should_advance(self, request: "Request") -> bool:
def should_advance(
self,
request: "Request",
new_token_ids: list[int] | None = None,
) -> bool:
if not request.use_structured_output:
return False

# To determine whether we can advance the FSM.
# Supports thinking usage where we skip the reasoning components.
if TYPE_CHECKING:
assert request.structured_output_request is not None
assert request.structured_output_request.grammar is not None
# by default, we should always advance
# for cases that don't use thinking mode.
if self.reasoner is None:
return True

# if the model needs structured in reasoning, we should advance
if self.enable_in_reasoning:
return True

structured_req = request.structured_output_request
if structured_req is None or structured_req.grammar is None:
return True

# If reasoning already ended, advance.
if structured_req.reasoning_ended:
return True

# Check if reasoning ends in *this* step
delta_from = request.num_computed_tokens - request.num_output_placeholders
all_token_ids = request.all_token_ids
if self.reasoner.is_reasoning_end_streaming(
all_token_ids, all_token_ids[delta_from:]
):
# Reasoning just ended, so we shouldn't advance til
# next pass
# Detect reasoning end from committed tokens (request.all_token_ids).
all_ids = request.all_token_ids or []
scan_from = structured_req.reasoning_scan_idx
if scan_from > len(all_ids):
scan_from = len(all_ids)

delta = all_ids[scan_from:]
structured_req.reasoning_scan_idx = len(all_ids)

if not delta:
return False

if self.reasoner.is_reasoning_end_streaming(all_ids, delta):
structured_req.reasoning_ended = True
return False # Stop here; advance next step.

return False

Expand Down
8 changes: 5 additions & 3 deletions vllm/v1/structured_output/backend_xgrammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,11 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
return False
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.",
# Under speculative decoding, the draft model may propose tokens
# that violate the grammar (e.g., "</think>" after reasoning ends).
# This is not fatal: treat as a rejected token sequence.
logger.debug(
"Grammar rejected token req=%s token=%s",
request_id,
token,
)
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/structured_output/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class StructuredOutputRequest:
params: StructuredOutputsParams
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
reasoning_ended: bool | None = None
reasoning_scan_idx: int = 0

@staticmethod
def from_sampling_params(
Expand Down