diff --git a/tests/v1/structured_output/test_reasoning_structured_output.py b/tests/v1/structured_output/test_reasoning_structured_output.py index 98a25e41dfe0..a9b7cd43b454 100644 --- a/tests/v1/structured_output/test_reasoning_structured_output.py +++ b/tests/v1/structured_output/test_reasoning_structured_output.py @@ -12,6 +12,12 @@ from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager +THINK_END_TOKEN = 99 +REASONING_TOKEN_A = 10 +REASONING_TOKEN_B = 11 +JSON_TOKEN_A = 20 +JSON_TOKEN_B = 21 + class TestReasoningStructuredOutput: """Test reasoning-aware structured output functionality.""" @@ -55,6 +61,11 @@ def mock_reasoning_parser(self): """Create a mock ReasoningParser.""" parser = Mock(spec=ReasoningParser) parser.is_reasoning_end = Mock(return_value=False) + + def mock_streaming(prefix, delta): + return THINK_END_TOKEN in delta + + parser.is_reasoning_end_streaming = mock_streaming return parser @pytest.fixture @@ -125,85 +136,278 @@ def test_should_fill_bitmask_no_reasoner( # Should default to True when no reasoner assert result is True - def test_should_advance_with_enable_in_reasoning( + def test_update_reasoning_ended_with_new_token_ids_mid_batch( self, mock_vllm_config, mock_request_with_structured_output, mock_reasoning_parser, ): - """Test should_advance when enable_in_reasoning is True.""" - # Enable enable_in_reasoning - mock_vllm_config.structured_outputs_config.enable_in_reasoning = True + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = False + + new_token_ids = [REASONING_TOKEN_A, THINK_END_TOKEN, JSON_TOKEN_A] + + manager.update_reasoning_ended( + mock_request_with_structured_output, + new_token_ids=new_token_ids, + ) + assert struct_req.reasoning_ended is True + + def test_update_reasoning_ended_no_end_found( + self, + mock_vllm_config, + 
mock_request_with_structured_output, + mock_reasoning_parser, + ): manager = StructuredOutputManager(mock_vllm_config) manager.reasoner = mock_reasoning_parser - # Should always return True when enable_in_reasoning is enabled - result = manager.should_advance(mock_request_with_structured_output) - assert result is True + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = False + + manager.update_reasoning_ended( + mock_request_with_structured_output, + new_token_ids=[REASONING_TOKEN_A, REASONING_TOKEN_B], + ) - def test_should_advance_reasoning_not_ended( + assert struct_req.reasoning_ended is False + + def test_identify_constrained_draft_tokens_reasoning_ends_mid_draft( self, mock_vllm_config, mock_request_with_structured_output, mock_reasoning_parser, ): - """Test should_advance when reasoning has not ended.""" manager = StructuredOutputManager(mock_vllm_config) manager.reasoner = mock_reasoning_parser - # Set reasoning as not ended - ( - mock_request_with_structured_output.structured_output_request - ).reasoning_ended = False - mock_reasoning_parser.is_reasoning_end.return_value = False + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = False + + draft_tokens = [ + REASONING_TOKEN_A, + THINK_END_TOKEN, + JSON_TOKEN_A, + JSON_TOKEN_B, + ] + unconstrained, constrained = manager.identify_constrained_draft_tokens( + mock_request_with_structured_output, draft_tokens + ) - result = manager.should_advance(mock_request_with_structured_output) + # Tokens up to and including THINK_END_TOKEN are unconstrained + assert unconstrained == [REASONING_TOKEN_A, THINK_END_TOKEN] + # Tokens after the marker are constrained + assert constrained == [JSON_TOKEN_A, JSON_TOKEN_B] - # Should return False since reasoning hasn't ended - assert result is False + def test_identify_constrained_draft_tokens_reasoning_already_ended( + self, + mock_vllm_config, + 
mock_request_with_structured_output, + mock_reasoning_parser, + ): + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = True + + # Include a spurious THINK_END_TOKEN — it should NOT cause a split + draft_tokens = [ + JSON_TOKEN_A, + THINK_END_TOKEN, + JSON_TOKEN_B, + ] + unconstrained, constrained = manager.identify_constrained_draft_tokens( + mock_request_with_structured_output, draft_tokens + ) - def test_should_advance_reasoning_just_ended( + # When reasoning_ended=True, ALL draft tokens should be constrained + assert unconstrained == [] + assert constrained == [JSON_TOKEN_A, THINK_END_TOKEN, JSON_TOKEN_B] + + def test_validate_tokens_reasoning_aware_reasoning_ended( self, mock_vllm_config, mock_request_with_structured_output, mock_reasoning_parser, ): - """Test should_advance when reasoning ends in current step.""" manager = StructuredOutputManager(mock_vllm_config) manager.reasoner = mock_reasoning_parser - # Set reasoning as not ended initially, but ends in this step - ( - mock_request_with_structured_output.structured_output_request - ).reasoning_ended = False - mock_reasoning_parser.is_reasoning_end.return_value = True + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = True + # Mock validate_tokens to return a subset (simulating rejection) + struct_req.grammar.validate_tokens = Mock(return_value=[JSON_TOKEN_A]) - result = manager.should_advance(mock_request_with_structured_output) + draft_tokens = [JSON_TOKEN_A, JSON_TOKEN_B] + result = manager.validate_tokens_reasoning_aware( + mock_request_with_structured_output, draft_tokens + ) - # Should set reasoning_ended to True but return False for this step - assert ( - mock_request_with_structured_output.structured_output_request.reasoning_ended - is True + # validate_tokens was called with all draft tokens (since 
reasoning_ended=True) + struct_req.grammar.validate_tokens.assert_called_once_with( + [JSON_TOKEN_A, JSON_TOKEN_B] ) - assert result is False + # result must not contain draft token JSON_TOKEN_B, which didn't match grammar + assert result == [JSON_TOKEN_A] - def test_should_advance_reasoning_already_ended( + def test_validate_tokens_reasoning_aware_reasoning_ends_mid_draft( self, mock_vllm_config, mock_request_with_structured_output, mock_reasoning_parser, ): - """Test should_advance when reasoning has already ended.""" manager = StructuredOutputManager(mock_vllm_config) manager.reasoner = mock_reasoning_parser - # Set reasoning as already ended - ( - mock_request_with_structured_output.structured_output_request - ).reasoning_ended = True + struct_req = mock_request_with_structured_output.structured_output_request + struct_req.reasoning_ended = False - result = manager.should_advance(mock_request_with_structured_output) + # Mock validate_tokens to return only the first valid token + struct_req.grammar.validate_tokens = Mock(return_value=[JSON_TOKEN_A]) - # Should return True since reasoning has ended - assert result is True + draft_tokens = [ + REASONING_TOKEN_A, + THINK_END_TOKEN, + JSON_TOKEN_A, + JSON_TOKEN_B, + ] + result = manager.validate_tokens_reasoning_aware( + mock_request_with_structured_output, draft_tokens + ) + + # Only post-reasoning tokens should be validated (since reasoning_ended=False) + struct_req.grammar.validate_tokens.assert_called_once_with( + [JSON_TOKEN_A, JSON_TOKEN_B] + ) + # Result: unconstrained prefix + validated suffix + assert result == [REASONING_TOKEN_A, THINK_END_TOKEN, JSON_TOKEN_A] + + def _make_manager_for_bitmask_test( + self, mock_vllm_config, mock_reasoning_parser, num_spec_tokens=5 + ): + """Helper: create a StructuredOutputManager wired up for bitmask + tests with a mock backend and pre-allocated bitmask tensor.""" + import torch + + mock_vllm_config.speculative_config = Mock() + 
mock_vllm_config.speculative_config.num_speculative_tokens = num_spec_tokens + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + max_entries = mock_vllm_config.scheduler_config.max_num_seqs * ( + 1 + num_spec_tokens + ) + manager._grammar_bitmask = torch.zeros(max_entries, 1, dtype=torch.int32) + manager.backend = Mock() + return manager + + def _make_grammar_mock(self): + """Helper: create a mock grammar that tracks calls.""" + grammar = Mock() + grammar.is_terminated.return_value = False + grammar.accept_tokens.return_value = True + grammar.fill_bitmask = Mock() + grammar.rollback = Mock() + return grammar + + def test_grammar_bitmask_reasoning_ends_mid_speculation( + self, + mock_vllm_config, + mock_reasoning_parser, + ): + """When reasoning_end appears mid-speculation, only post-reasoning + tokens should be grammar-constrained and accepted (then rolled + back).""" + manager = self._make_manager_for_bitmask_test( + mock_vllm_config, mock_reasoning_parser, num_spec_tokens=5 + ) + grammar = self._make_grammar_mock() + + mock_reasoning_parser.is_reasoning_end.return_value = False + + request = Mock(spec=Request) + request.structured_output_request = Mock() + request.structured_output_request.reasoning_ended = False + request.structured_output_request.grammar = grammar + request.use_structured_output = True + request.prompt_token_ids = [1, 2, 3] + + req_id = "req-mid-spec" + spec_tokens = [ + REASONING_TOKEN_A, + THINK_END_TOKEN, + JSON_TOKEN_A, + JSON_TOKEN_B, + ] + + result = manager.grammar_bitmask( + requests={req_id: request}, + structured_output_request_ids=[req_id], + scheduled_spec_decode_tokens={req_id: spec_tokens}, + ) + + assert result is not None + assert request.structured_output_request.reasoning_ended is False + + # Only post-reasoning tokens should have been accepted + accepted_tokens = [ + call[0][1][0] for call in grammar.accept_tokens.call_args_list + ] + assert accepted_tokens == [JSON_TOKEN_A, 
JSON_TOKEN_B] + + # Grammar advancements should be rolled back + if grammar.accept_tokens.call_count > 0: + grammar.rollback.assert_called_once_with(grammar.accept_tokens.call_count) + + def test_grammar_bitmask_all_constrained_when_reasoning_ended( + self, + mock_vllm_config, + mock_reasoning_parser, + ): + """After reasoning ended, ALL bitmask positions must be grammar-constrained""" + manager = self._make_manager_for_bitmask_test( + mock_vllm_config, mock_reasoning_parser, num_spec_tokens=5 + ) + grammar = self._make_grammar_mock() + + mock_reasoning_parser.is_reasoning_end.return_value = True + + request = Mock(spec=Request) + request.structured_output_request = Mock() + request.structured_output_request.reasoning_ended = True + request.structured_output_request.grammar = grammar + request.use_structured_output = True + request.prompt_token_ids = [1, 2, 3] + + req_id = "req-root-cause-bitmask" + # Spurious THINK_END_TOKEN at index 3 in the draft tokens + spec_tokens = [ + JSON_TOKEN_A, + JSON_TOKEN_B, + REASONING_TOKEN_A, + THINK_END_TOKEN, + JSON_TOKEN_A, + ] + + result = manager.grammar_bitmask( + requests={req_id: request}, + structured_output_request_ids=[req_id], + scheduled_spec_decode_tokens={req_id: spec_tokens}, + ) + + assert result is not None + + # ALL spec_tokens should have been accepted, since reasoning_ended=True + accepted_tokens = [ + call[0][1][0] for call in grammar.accept_tokens.call_args_list + ] + assert accepted_tokens == spec_tokens + + # All state advancements should be rolled back + if grammar.accept_tokens.call_count > 0: + grammar.rollback.assert_called_once_with(grammar.accept_tokens.call_count) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index cd9be17f6dac..217b8d5f8840 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1402,22 +1402,35 @@ def update_from_output( request.status = RequestStatus.FINISHED_STOPPED stopped = True - if new_token_ids and 
self.structured_output_manager.should_advance(request): - struct_output_request = request.structured_output_request - assert struct_output_request is not None - assert struct_output_request.grammar is not None - if not struct_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids - ): - logger.error( - "Unexpected: grammar rejected tokens %s for request %s. " - "Terminating request.", - new_token_ids, - req_id, + if new_token_ids and self.structured_output_manager is not None: + # When reasoning ends within this token batch (e.g. during + # speculative decoding), only the tokens after the + # reasoning_end marker should be fed to the grammar. + _, tokens_for_grammar = ( + self.structured_output_manager.identify_constrained_draft_tokens( + request, new_token_ids + ) + ) + if tokens_for_grammar: + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + ok = struct_output_request.grammar.accept_tokens( + req_id, tokens_for_grammar ) - request.status = RequestStatus.FINISHED_ERROR - request.resumable = False - stopped = True + if not ok: + logger.warning( + "Unexpected: grammar rejected tokens %s for request %s.", + tokens_for_grammar, + req_id, + ) + request.status = RequestStatus.FINISHED_ERROR + request.resumable = False + stopped = True + # Update reasoning_ended state based on accepted tokens. + self.structured_output_manager.update_reasoning_ended( + request, new_token_ids=new_token_ids + ) routed_experts = None finish_reason = None @@ -1684,9 +1697,12 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None: continue # Add newly generated spec token ids to the request. 
- if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr] + # Validate spec tokens against grammar if applicable. + spec_token_ids = ( + self.structured_output_manager.validate_tokens_reasoning_aware( + request, spec_token_ids + ) + ) request.spec_token_ids = spec_token_ids def update_draft_token_ids_in_output( @@ -1713,10 +1729,11 @@ def update_draft_token_ids_in_output( # (needed for chunked prefill case for example). del spec_token_ids[orig_num_spec_tokens:] # Filter out spec tokens which do not adhere to the grammar. - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - assert metadata is not None and metadata.grammar is not None - spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) + spec_token_ids = ( + self.structured_output_manager.validate_tokens_reasoning_aware( + request, spec_token_ids + ) + ) # Pad to original number of spec tokens. num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids) if num_invalid_tokens: diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 213b49f28d91..769dbb3457b1 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -257,16 +257,39 @@ def grammar_bitmask( grammar = structured_output_request.grammar apply_bitmask = self.should_fill_bitmask(request) - state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, ()) - for token in itertools.chain(req_tokens, (-1,)): - self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),)) + reasoning_end_idx: int | None = None + if req_tokens and not apply_bitmask: + # When reasoning hasn't already ended (apply_bitmask=False), + # only positions after reasoning_end must be constrained. 
+ reasoning_end_idx = self._find_reasoning_end_in_tokens( + list(req_tokens) + ) + + state_advancements = 0 + for tok_idx, token in enumerate(itertools.chain(req_tokens, (-1,))): + # Tokens up to and including reasoning_end are + # unconstrained; tokens after are grammar-constrained. + if reasoning_end_idx is not None: + is_post_reasoning = tok_idx > reasoning_end_idx + pos_apply_bitmask = is_post_reasoning + else: + pos_apply_bitmask = apply_bitmask + + self._fill_bitmasks( + ((grammar, cumulative_index, pos_apply_bitmask),) + ) if token == -1: - # Stop advancing the grammar once we hit a padding token. - apply_bitmask = False - if apply_bitmask and not grammar.is_terminated(): + # Stop advancing the grammar once we hit a + # padding token. + pos_apply_bitmask = False + if pos_apply_bitmask and not grammar.is_terminated(): accepted = grammar.accept_tokens(req_id, [token]) - assert accepted, (token, req_id, scheduled_spec_decode_tokens) + assert accepted, ( + token, + req_id, + scheduled_spec_decode_tokens, + ) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: @@ -300,42 +323,106 @@ def should_fill_bitmask(self, request: "Request") -> bool: return request.structured_output_request.reasoning_ended return True - def should_advance(self, request: "Request") -> bool: + def update_reasoning_ended( + self, + request: "Request", + new_token_ids: list[int], + ) -> None: + """Update the reasoning_ended flag based on accepted tokens.""" if not request.use_structured_output: - return False - - # To determine whether we can advance the FSM. - # Supports thinking usage where we skip the reasoning components. - if TYPE_CHECKING: - assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None - # by default, we should always advance - # for cases that don't use thinking mode. 
- if self.reasoner is None: - return True + return - # if the model needs structured in reasoning, we should advance - if self.enable_in_reasoning: - return True + if self.reasoner is None or self.enable_in_reasoning: + return structured_req = request.structured_output_request + assert structured_req is not None if structured_req.reasoning_ended: - return True + return - # Check if reasoning ends in *this* step - delta_from = request.num_computed_tokens - request.num_output_placeholders all_token_ids = request.all_token_ids - start = ( - delta_from if delta_from >= 0 else max(len(all_token_ids) + delta_from, 0) - ) - if self.reasoner.is_reasoning_end_streaming( - all_token_ids, itertools.islice(all_token_ids, start, None) - ): - # Reasoning just ended, so we shouldn't advance til - # next pass + if self.reasoner.is_reasoning_end_streaming(all_token_ids, new_token_ids): structured_req.reasoning_ended = True - return False + def validate_tokens_reasoning_aware( + self, request: "Request", spec_token_ids: list[int] + ) -> list[int]: + """Validate speculative tokens against the grammar, handling + reasoning-end markers. + """ + unconstrained_tokens, constrained_tokens = ( + self.identify_constrained_draft_tokens(request, spec_token_ids) + ) + if constrained_tokens: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None + grammar = request.structured_output_request.grammar + grammar_validated_tokens = grammar.validate_tokens(constrained_tokens) + return unconstrained_tokens + grammar_validated_tokens + return spec_token_ids + + def identify_constrained_draft_tokens( + self, request: "Request", spec_token_ids: list[int] + ) -> tuple[list[int], list[int]]: + """Identify which draft tokens need to be constrained by grammar, + taking mid-batch reasoning-end markers correctly into account. 
+ + Returns: + tuple of (unconstrained draft tokens, constrained draft tokens) + """ + if not request.use_structured_output: + unconstrained_tokens = spec_token_ids + return unconstrained_tokens, [] + + if self.reasoner is None: + constrained_tokens = spec_token_ids + return [], constrained_tokens + + if self.enable_in_reasoning: + constrained_tokens = spec_token_ids + return [], constrained_tokens + + structured_output_request = request.structured_output_request + assert structured_output_request is not None + assert structured_output_request.grammar is not None + + # When reasoning already ended, validate ALL draft tokens. + if structured_output_request.reasoning_ended: + constrained_tokens = spec_token_ids + return [], constrained_tokens + + # Reasoning hasn't ended yet — check if it ends mid-draft. + split_idx = self._find_reasoning_end_in_tokens(spec_token_ids) + if split_idx is None: + unconstrained_tokens = spec_token_ids + return unconstrained_tokens, [] + + # validate only tokens after reasoning_end marker; + # pass tokens up to reasoning_end through unvalidated + unconstrained_tokens = spec_token_ids[: split_idx + 1] + constrained_tokens = spec_token_ids[split_idx + 1 :] + return unconstrained_tokens, constrained_tokens + + def _find_reasoning_end_in_tokens(self, token_ids: list[int]) -> int | None: + """Find the index of the reasoning-end token within a token list. + + Uses is_reasoning_end_streaming to check progressively longer + prefixes, supporting multi-token end markers. + + Returns: + The index of the last token of the reasoning-end marker, + or None if not found. + """ + if self.reasoner is None or self.enable_in_reasoning: + return None + + for i, token in enumerate(token_ids): + # Check if reasoning ends at position i by testing the + # prefix up to and including this token. 
+ prefix = token_ids[: i + 1] + if self.reasoner.is_reasoning_end_streaming(prefix, [token]): + return i + return None def clear_backend(self) -> None: if self.backend is not None: