diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0713aa8abdc2..83dd6478d912 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -503,6 +503,132 @@ def test_stop_via_update_from_output(): assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] +def test_reasoning_spec_decode_grammar(): + """ + Test for speculative decoding with reasoning parser and grammar constraints. + + This test validates the fix for the bug where combining reasoning parser, + structured output (JSON schema), and speculative decoding resulted in + invalid JSON output. + + Key behaviors tested: + - When reasoning_end is detected in spec tokens, ALL spec tokens are cleared. + - Grammar FSM does NOT accept reasoning_end token. + - Spec tokens validated even when reasoning_ended flag is stale. + """ + REASONING_END_TOKEN_ID = 151660 + + scheduler = create_scheduler(num_speculative_tokens=3) + requests = create_requests(num_requests=1, max_tokens=20) + request = requests[0] + + structured_params = StructuredOutputsParams( + json='{"type": "object", "properties": {"answer": {"type": "string"}}}' + ) + + from vllm.v1.structured_output.request import StructuredOutputRequest + + structured_req = StructuredOutputRequest(params=structured_params) + structured_req.reasoning_ended = False + + mock_grammar = Mock() + mock_grammar.is_terminated = Mock(return_value=False) + mock_grammar.validate_tokens = Mock(side_effect=lambda tokens: tokens) + mock_grammar.accept_tokens = Mock(return_value=True) + structured_req.grammar = mock_grammar + + request.structured_output_request = structured_req + + mock_reasoner = Mock() + mock_reasoner.is_reasoning_end = ( + lambda token_ids: REASONING_END_TOKEN_ID in token_ids + ) + mock_reasoner.is_reasoning_end_streaming = ( + lambda full_ids, delta_ids: REASONING_END_TOKEN_ID in delta_ids + ) + + def extract_content_ids(delta_ids): + if REASONING_END_TOKEN_ID not in delta_ids: + return [] + idx = 
delta_ids.index(REASONING_END_TOKEN_ID) + return delta_ids[idx + 1 :] + + mock_reasoner.extract_content_ids = extract_content_ids + scheduler.structured_output_manager.reasoner = mock_reasoner + scheduler.structured_output_manager.enable_in_reasoning = False + + request.num_computed_tokens = request.num_tokens + 3 + request.append_output_token_ids([100, 101, 102]) + + scheduler.requests[request.request_id] = request + scheduler.running.append(request) + request.status = RequestStatus.RUNNING + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={request.request_id: 1}, + total_num_scheduled_tokens=1, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) + + # Test case 1: spec tokens cleared when reasoning_end detected in draft tokens + model_output = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[104]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(scheduler_output, model_output) + + draft_tokens = DraftTokenIds( + [request.request_id], [[105, REASONING_END_TOKEN_ID, 200, 201]] + ) + scheduler.update_draft_token_ids(draft_tokens) + + assert request.spec_token_ids == [] + + # Test case 2: drafted tokens after reasoning end are not accepted + scheduler_output2 = scheduler.schedule() + model_output2 = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[REASONING_END_TOKEN_ID, 200]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + mock_grammar.accept_tokens.reset_mock() + scheduler.update_from_output(scheduler_output2, model_output2) + + assert request.all_token_ids[-1] == REASONING_END_TOKEN_ID + + # Test case 3: grammar accepts content tokens after reasoning ended + 
scheduler_output3 = scheduler.schedule() + model_output3 = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[300]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + mock_grammar.accept_tokens.reset_mock() + scheduler.update_from_output(scheduler_output3, model_output3) + + assert mock_grammar.accept_tokens.call_count == 1 + accepted_tokens = mock_grammar.accept_tokens.call_args[0][1] + assert accepted_tokens == [300] + + def test_check_stop_min_tokens(): """Test that requests don't stop when min_tokens requirement isn't met.""" from vllm.v1.core.sched.utils import check_stop diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b2e09d2ffb74..4608e5ade40c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1349,6 +1349,34 @@ def update_from_output( kv_transfer_params = None status_before_stop = request.status + # If reasoning ends in this step and content tokens were emitted, + # drop those content tokens so constrained decoding can start + # in the next step. This prevents unconstrained prefixes like + # markdown fences when structured output is requested. 
+ reasoner = self.structured_output_manager.reasoner + in_unconstraint_reasoning = ( + request.structured_output_request is not None + and not request.structured_output_request.reasoning_ended + and not self.structured_output_manager.enable_in_reasoning + ) + if ( + new_token_ids + and reasoner is not None + and in_unconstraint_reasoning + and reasoner.is_reasoning_end_streaming( + request.all_token_ids, new_token_ids + ) + ): + content_ids = reasoner.extract_content_ids(new_token_ids) + if ( + content_ids + and len(content_ids) <= len(new_token_ids) + and new_token_ids[-len(content_ids) :] == content_ids + ): + new_token_ids = new_token_ids[: -len(content_ids)] + # We've accepted the reasoning end token. + request.structured_output_request.reasoning_ended = True # type: ignore[union-attr] + # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( @@ -1608,8 +1636,25 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None: request.spec_token_ids = [] continue + # For the non-reasoning case, or once reasoning has ended, validate all + # spec tokens. If the reasoning end is detected within the spec tokens, + # clear them so the grammar can activate cleanly on non-reasoning tokens. + should_validate = self.structured_output_manager.should_advance(request) + reasoner = self.structured_output_manager.reasoner + if ( + not should_validate + and request.use_structured_output + and reasoner is not None + ): + for i, token_id in enumerate(spec_token_ids): + if reasoner.is_reasoning_end_streaming( + request.all_token_ids, spec_token_ids[: i + 1] + ): + spec_token_ids = [] + break + # Add newly generated spec token ids to the request. 
- if self.structured_output_manager.should_advance(request): + if should_validate and spec_token_ids: metadata = request.structured_output_request spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr] request.spec_token_ids = spec_token_ids diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 921bee6a647a..d614d708340f 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -296,6 +296,11 @@ def should_fill_bitmask(self, request: "Request") -> bool: request.structured_output_request.reasoning_ended = ( self.reasoner.is_reasoning_end(request.prompt_token_ids or []) ) + + # Check if reasoning has actually ended by looking at tokens + # This handles async scheduling where flags might be stale + if self.reasoner.is_reasoning_end(request.all_token_ids): + request.structured_output_request.reasoning_ended = True return request.structured_output_request.reasoning_ended return True