Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,132 @@ def test_stop_via_update_from_output():
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_reasoning_spec_decode_grammar():
"""
Test for speculative decoding with reasoning parser and grammar constraints.

This test validates the fix for the bug where combining reasoning parser,
structured output (JSON schema), and speculative decoding resulted in
invalid JSON output.

Key behaviors tested:
- When reasoning_end is detected in spec tokens, ALL spec tokens are cleared.
- Grammar FSM does NOT accept reasoning_end token.
- Spec tokens validated even when reasoning_ended flag is stale.
"""
REASONING_END_TOKEN_ID = 151660
Comment thread
sfbemerk marked this conversation as resolved.

scheduler = create_scheduler(num_speculative_tokens=3)
requests = create_requests(num_requests=1, max_tokens=20)
request = requests[0]

structured_params = StructuredOutputsParams(
json='{"type": "object", "properties": {"answer": {"type": "string"}}}'
Comment thread
sfbemerk marked this conversation as resolved.
)

from vllm.v1.structured_output.request import StructuredOutputRequest

structured_req = StructuredOutputRequest(params=structured_params)
structured_req.reasoning_ended = False

mock_grammar = Mock()
mock_grammar.is_terminated = Mock(return_value=False)
mock_grammar.validate_tokens = Mock(side_effect=lambda tokens: tokens)
mock_grammar.accept_tokens = Mock(return_value=True)
structured_req.grammar = mock_grammar

request.structured_output_request = structured_req

mock_reasoner = Mock()
mock_reasoner.is_reasoning_end = (
lambda token_ids: REASONING_END_TOKEN_ID in token_ids
)
mock_reasoner.is_reasoning_end_streaming = (
lambda full_ids, delta_ids: REASONING_END_TOKEN_ID in delta_ids
)

def extract_content_ids(delta_ids):
if REASONING_END_TOKEN_ID not in delta_ids:
return []
idx = delta_ids.index(REASONING_END_TOKEN_ID)
return delta_ids[idx + 1 :]

mock_reasoner.extract_content_ids = extract_content_ids
scheduler.structured_output_manager.reasoner = mock_reasoner
scheduler.structured_output_manager.enable_in_reasoning = False

request.num_computed_tokens = request.num_tokens + 3
request.append_output_token_ids([100, 101, 102])

scheduler.requests[request.request_id] = request
scheduler.running.append(request)
request.status = RequestStatus.RUNNING

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={request.request_id: 1},
total_num_scheduled_tokens=1,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)

# Test case 1: spec tokens cleared when reasoning_end detected in draft tokens
model_output = ModelRunnerOutput(
req_ids=[request.request_id],
req_id_to_index={request.request_id: 0},
sampled_token_ids=[[104]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)

draft_tokens = DraftTokenIds(
[request.request_id], [[105, REASONING_END_TOKEN_ID, 200, 201]]
)
scheduler.update_draft_token_ids(draft_tokens)

assert request.spec_token_ids == []

# Test case 2: drafted tokens after reasoning end are not accepted
scheduler_output2 = scheduler.schedule()
model_output2 = ModelRunnerOutput(
req_ids=[request.request_id],
req_id_to_index={request.request_id: 0},
sampled_token_ids=[[REASONING_END_TOKEN_ID, 200]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

mock_grammar.accept_tokens.reset_mock()
scheduler.update_from_output(scheduler_output2, model_output2)

assert request.all_token_ids[-1] == REASONING_END_TOKEN_ID

# Test case 3: grammar accepts content tokens after reasoning ended
scheduler_output3 = scheduler.schedule()
model_output3 = ModelRunnerOutput(
req_ids=[request.request_id],
req_id_to_index={request.request_id: 0},
sampled_token_ids=[[300]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

mock_grammar.accept_tokens.reset_mock()
scheduler.update_from_output(scheduler_output3, model_output3)

assert mock_grammar.accept_tokens.call_count == 1
accepted_tokens = mock_grammar.accept_tokens.call_args[0][1]
assert accepted_tokens == [300]


def test_check_stop_min_tokens():
"""Test that requests don't stop when min_tokens requirement isn't met."""
from vllm.v1.core.sched.utils import check_stop
Expand Down
47 changes: 46 additions & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,34 @@ def update_from_output(
kv_transfer_params = None
status_before_stop = request.status

# If reasoning ends in this step and content tokens were emitted,
# drop those content tokens so constrained decoding can start
# in the next step. This prevents unconstrained prefixes like
# markdown fences when structured output is requested.
reasoner = self.structured_output_manager.reasoner
in_unconstraint_reasoning = (
request.structured_output_request is not None
and not request.structured_output_request.reasoning_ended
and not self.structured_output_manager.enable_in_reasoning
)
if (
new_token_ids
and reasoner is not None
and in_unconstraint_reasoning
and reasoner.is_reasoning_end_streaming(
request.all_token_ids, new_token_ids
)
):
Comment thread
sfbemerk marked this conversation as resolved.
content_ids = reasoner.extract_content_ids(new_token_ids)
if (
content_ids
and len(content_ids) <= len(new_token_ids)
and new_token_ids[-len(content_ids) :] == content_ids
):
new_token_ids = new_token_ids[: -len(content_ids)]
# we've accepted the reasoning end token
request.structured_output_request.reasoning_ended = True # type: ignore[union-attr]

# Check for stop and update request status.
if new_token_ids:
new_token_ids, stopped = self._update_request_with_output(
Expand Down Expand Up @@ -1608,8 +1636,25 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None:
request.spec_token_ids = []
continue

# For non-reasoning case or when reasoning has ended, validate all tokens,
# when reasoning end is deteted within spec tokens, clear them and
# allow grammar to activate cleanly for non-reasoning tokens
should_validate = self.structured_output_manager.should_advance(request)
reasoner = self.structured_output_manager.reasoner
if (
not should_validate
and request.use_structured_output
and reasoner is not None
):
for i, token_id in enumerate(spec_token_ids):
if reasoner.is_reasoning_end_streaming(
request.all_token_ids, spec_token_ids[: i + 1]
):
spec_token_ids = []
break

# Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request):
if should_validate and spec_token_ids:
metadata = request.structured_output_request
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr]
request.spec_token_ids = spec_token_ids
Expand Down
5 changes: 5 additions & 0 deletions vllm/v1/structured_output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ def should_fill_bitmask(self, request: "Request") -> bool:
request.structured_output_request.reasoning_ended = (
self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
)

# Check if reasoning has actually ended by looking at tokens
# This handles async scheduling where flags might be stale
if self.reasoner.is_reasoning_end(request.all_token_ids):
request.structured_output_request.reasoning_ended = True
return request.structured_output_request.reasoning_ended
return True

Expand Down