From 555c67157e7515c71ab9abf4464346b2bb8710b3 Mon Sep 17 00:00:00 2001 From: SvenLorenz Date: Tue, 26 May 2026 18:22:49 +0200 Subject: [PATCH] [SpecDec + Reasoning] Fix race condition when reasoning-end token appears as a rejected draft token When speculative decoding generates the reasoning-end token as a draft token that gets rejected, the old code unconditionally set reasoning_ended=True and force-fed the unconstrained bonus token to the grammar, corrupting its state. All subsequent outputs were then constrained by a corrupted grammar. Fix: - Detect mid-draft-batch in grammar_bitmask() and set bonus_requires_grammar=True to flag that the bonus slot should get a constrained bitmask and the grammar needs advancing - In update_from_output(), only mark reasoning as ended and advance the grammar when the bonus token is actually accepted (meaning the reasoning-end draft token was accepted by spec decode) - When the bonus token is rejected, leave reasoning_ended=False so the model continues generating reasoning text naturally - Add suppress_accept_errors flag to avoid ERROR-level log spam from expected grammar rejections in this path Signed-off-by: SvenLorenz --- .../test_reasoning_structured_output.py | 281 ++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 15 + vllm/v1/structured_output/__init__.py | 38 ++- vllm/v1/structured_output/backend_types.py | 8 + vllm/v1/structured_output/backend_xgrammar.py | 20 +- vllm/v1/structured_output/request.py | 5 + 6 files changed, 354 insertions(+), 13 deletions(-) diff --git a/tests/v1/structured_output/test_reasoning_structured_output.py b/tests/v1/structured_output/test_reasoning_structured_output.py index 861e919c102a..c945e5de46e9 100644 --- a/tests/v1/structured_output/test_reasoning_structured_output.py +++ b/tests/v1/structured_output/test_reasoning_structured_output.py @@ -60,11 +60,13 @@ def mock_vllm_config(self, mock_model_config, mock_scheduler_config): def mock_request_with_structured_output(self): """Create a mock request with structured output.""" request = Mock(spec=Request) + request.request_id = "test-request-123" request.structured_output_request = Mock() request.structured_output_request.reasoning_ended = None request.structured_output_request.grammar = Mock() request.structured_output_request.reasoning_parser_kwargs = None request.structured_output_request.reasoner = None + request.structured_output_request.bonus_requires_grammar = False request.structured_output_request.grammar.is_terminated = Mock( return_value=False ) @@ -258,3 +260,282 @@ def test_should_advance_reasoning_already_ended( # Should return True since reasoning has ended assert result is True + + def _make_detecting_reasoner(self, end_token_id: int = 99): + """Return a reasoner whose is_reasoning_end_streaming detects + ``end_token_id`` and records every call in ``detected_tokens``.""" + reasoner = MockReasoner(tokenizer=Mock()) + reasoner.detected_tokens = [] + + def side_effect(input_ids, delta_ids): + delta_list = ( + list(delta_ids) if not isinstance(delta_ids, list) else delta_ids + ) + reasoner.detected_tokens.append(delta_list) + return end_token_id in delta_list + + reasoner.is_reasoning_end_streaming = Mock(side_effect=side_effect) + return reasoner + + def _make_mock_grammar(self, accept_result: bool = True): + grammar = Mock() + grammar.is_terminated = Mock(return_value=False) + grammar.fill_bitmask = Mock() + grammar.accept_tokens = Mock(return_value=accept_result) + grammar.rollback = Mock() + return grammar + + def _setup_manager_backend(self, manager): + import torch + + manager.backend = Mock() + manager.backend.allocate_token_bitmask = Mock( + return_value=torch.zeros((10, 50000), dtype=torch.int32) + ) + manager._full_mask = torch.tensor(-1, dtype=torch.int32) + + # ------------------------------------------------------------------ # + # Test: reasoning ends at the LAST draft token # + # Draft: [10, 20, 30, 99] ← 99 = reasoning-end # + # Expect: bonus slot (idx 4) gets constrained bitmask # + # ------------------------------------------------------------------ # + + def test_grammar_bitmask_reasoning_ends_mid_batch( + self, + manager_with_reasoner, + mock_request_with_structured_output, + ): + """Grammar_bitmask constrains bonus token when reasoning-end is + the last draft token.""" + + structured_req = mock_request_with_structured_output.structured_output_request + structured_req.reasoning_ended = False + + reasoner = self._make_detecting_reasoner(end_token_id=99) + structured_req.reasoner = reasoner + + grammar = self._make_mock_grammar() + structured_req.grammar = grammar + + self._setup_manager_backend(manager_with_reasoner) + + requests = { + mock_request_with_structured_output.request_id: mock_request_with_structured_output + } + scheduled_spec_decode_tokens = { + mock_request_with_structured_output.request_id: [10, 20, 30, 99] + } + + manager_with_reasoner.grammar_bitmask( + requests, + [mock_request_with_structured_output.request_id], + scheduled_spec_decode_tokens, + ) + + # --- assertions --- + + # is_reasoning_end_streaming was called per draft token + assert reasoner.is_reasoning_end_streaming.called + + # reasoning_ended flag is NOT set by grammar_bitmask — that is + # managed by should_advance() post-batch to avoid prematurely + # claiming reasoning is done while the current step's output + # still contains the reasoning-end token. + assert structured_req.reasoning_ended is False + + # bonus_requires_grammar flag IS set — tells update_from_output() + # to advance the grammar with the bonus token even though + # should_advance() will return False this step. + assert structured_req.bonus_requires_grammar is True + + # fill_bitmask should have been called exactly once — for the + # bonus token position (index 4). Positions 0-3 were filled + # with the full mask (no constraint). + fill_calls = grammar.fill_bitmask.call_args_list + assert len(fill_calls) == 1, ( + f"Expected 1 fill_bitmask (bonus pos), got {len(fill_calls)}" + ) + # The second argument is the index into the bitmask tensor + fill_index = fill_calls[0][0][1] + assert fill_index == 4, f"Expected bonus position index 4, got {fill_index}" + + # accept_tokens should NOT have been called (only the bonus + # position had the bitmask enabled, and its token is the + # sentinel -1 which is handled before the accept call). + grammar.accept_tokens.assert_not_called() + + # ------------------------------------------------------------------ # + # Test: reasoning ends MID-DRAFT (not at last position) # + # Draft: [10, 99, 30, 40] ← 99 = reasoning-end at pos 1 # + # Expect: positions 2, 3, and bonus (idx 4) constrained # + # NOTE: positions 2-3 are draft tokens generated WITHOUT the # + # constraint; accept_tokens is skipped to avoid xgrammar # + # state corruption risk. # + # ------------------------------------------------------------------ # + + def test_grammar_bitmask_reasoning_ends_mid_draft( + self, + manager_with_reasoner, + mock_request_with_structured_output, + ): + """Grammar_bitmask constrains remaining draft + bonus when + reasoning-end appears mid-draft (not at the last position).""" + + structured_req = mock_request_with_structured_output.structured_output_request + structured_req.reasoning_ended = False + + reasoner = self._make_detecting_reasoner(end_token_id=99) + structured_req.reasoner = reasoner + + grammar = self._make_mock_grammar(accept_result=False) + structured_req.grammar = grammar + + self._setup_manager_backend(manager_with_reasoner) + + requests = { + mock_request_with_structured_output.request_id: mock_request_with_structured_output + } + # reasoning-end (99) at position 1, two more draft tokens follow + scheduled_spec_decode_tokens = { + mock_request_with_structured_output.request_id: [10, 99, 30, 40] + } + + manager_with_reasoner.grammar_bitmask( + requests, + [mock_request_with_structured_output.request_id], + scheduled_spec_decode_tokens, + ) + + # --- assertions --- + + assert reasoner.is_reasoning_end_streaming.called + # reasoning_ended is NOT set by grammar_bitmask — done by + # should_advance() post-batch to avoid feeding reasoning-end + # tokens into the grammar FSM in the same step. + assert structured_req.reasoning_ended is False + + # bonus_requires_grammar IS set + assert structured_req.bonus_requires_grammar is True + + # fill_bitmask for constrained positions: indices 2, 3, 4 + # (draft after end, last draft after end, bonus) + fill_calls = grammar.fill_bitmask.call_args_list + fill_indices = sorted(c[0][1] for c in fill_calls) + assert fill_indices == [2, 3, 4], ( + f"Expected constrained indices [2, 3, 4], got {fill_indices}" + ) + + # accept_tokens is NOT called for any position — draft tokens + # after reasoning-end were generated without constraint and + # are skipped to avoid xgrammar state corruption risk. + grammar.accept_tokens.assert_not_called() + + # rollback is not called (state_advancements stayed at 0) + grammar.rollback.assert_not_called() + + # ------------------------------------------------------------------ # + # Test: NO reasoning-end token in the batch # + # Draft: [10, 20, 30, 40] # + # Expect: everything unconstrained (existing behaviour) # + # ------------------------------------------------------------------ # + + def test_grammar_bitmask_no_reasoning_end( + self, + manager_with_reasoner, + mock_request_with_structured_output, + ): + """Grammar_bitmask leaves everything unconstrained when no + reasoning-end token appears in the draft tokens.""" + + structured_req = mock_request_with_structured_output.structured_output_request + structured_req.reasoning_ended = False + + reasoner = self._make_detecting_reasoner(end_token_id=99) + structured_req.reasoner = reasoner + + grammar = self._make_mock_grammar() + structured_req.grammar = grammar + + self._setup_manager_backend(manager_with_reasoner) + + requests = { + mock_request_with_structured_output.request_id: mock_request_with_structured_output + } + # No reasoning-end token (99) in draft + scheduled_spec_decode_tokens = { + mock_request_with_structured_output.request_id: [10, 20, 30, 40] + } + + manager_with_reasoner.grammar_bitmask( + requests, + [mock_request_with_structured_output.request_id], + scheduled_spec_decode_tokens, + ) + + # --- assertions --- + + # is_reasoning_end_streaming was called but never returned True + assert reasoner.is_reasoning_end_streaming.called + + # reasoning_ended flag is still False + assert structured_req.reasoning_ended is False + + # fill_bitmask was never called — all positions stayed + # unconstrained (fill with full mask instead) + grammar.fill_bitmask.assert_not_called() + + # accept_tokens was never called + grammar.accept_tokens.assert_not_called() + + # ------------------------------------------------------------------ # + # Test: reasoning already ended BEFORE this batch # + # Draft: [10, 20, 30, 40] with reasoning_ended=True # + # Expect: all positions constrained normally # + # ------------------------------------------------------------------ # + + def test_grammar_bitmask_reasoning_already_ended( + self, + manager_with_reasoner, + mock_request_with_structured_output, + ): + """Grammar_bitmask constrains everything when reasoning already + ended before this batch.""" + + structured_req = mock_request_with_structured_output.structured_output_request + structured_req.reasoning_ended = True # already ended + + grammar = self._make_mock_grammar() + structured_req.grammar = grammar + + self._setup_manager_backend(manager_with_reasoner) + + requests = { + mock_request_with_structured_output.request_id: mock_request_with_structured_output + } + scheduled_spec_decode_tokens = { + mock_request_with_structured_output.request_id: [10, 20, 30, 40] + } + + manager_with_reasoner.grammar_bitmask( + requests, + [mock_request_with_structured_output.request_id], + scheduled_spec_decode_tokens, + ) + + # --- assertions --- + + # fill_bitmask was called for every position (4 draft + 1 bonus) + fill_calls = grammar.fill_bitmask.call_args_list + assert len(fill_calls) == 5, ( + f"Expected 5 fill_bitmask calls, got {len(fill_calls)}" + ) + + # accept_tokens was called for each draft token only (4). The + # bonus position uses sentinel -1 which sets apply_bitmask=False + # before the accept check, so no accept call for the bonus. + assert grammar.accept_tokens.call_count == 4, ( + f"Expected 4 accept_tokens calls, got {grammar.accept_tokens.call_count}" + ) + + # rollback was called (state_advancements == 4 > 0) + grammar.rollback.assert_called_once() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c69c9a8119ab..4ff27065c1ed 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1429,6 +1429,21 @@ def update_from_output( request.status = RequestStatus.FINISHED_ERROR request.resumable = False stopped = True + elif ( + request.structured_output_request is not None + and request.structured_output_request.bonus_requires_grammar + and generated_token_ids + ): + struct_req = request.structured_output_request + assert struct_req.grammar is not None + struct_req.grammar.suppress_accept_errors = True + accepted = struct_req.grammar.accept_tokens( + req_id, [generated_token_ids[-1]] + ) + struct_req.grammar.suppress_accept_errors = False + struct_req.bonus_requires_grammar = False + if accepted: + struct_req.reasoning_ended = True routed_experts = None if ( diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 6a4fcbb629ff..ee2d7decb099 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -191,9 +191,6 @@ def _fill_bitmasks( if apply_bitmask and not grammar.is_terminated(): grammar.fill_bitmask(self._grammar_bitmask, index) else: - # Note that for thinking support, we will need to - # reset the relevant part of the bitmask for consequent - # requests here. self._grammar_bitmask[index].fill_(self._full_mask) def _async_submit_fill_bitmask( @@ -275,17 +272,44 @@ def grammar_bitmask( grammar = structured_output_request.grammar apply_bitmask = self.should_fill_bitmask(request) + # Per-token reasoning-end detection: when apply_bitmask is + # False because reasoning is active, we need to watch for the + # end-of-reasoning token within the draft batch. When found, + # flip apply_bitmask so subsequent positions (including the + # bonus token) get the constrained bitmask. + reasoner = ( + self._get_reasoner(request) + if not apply_bitmask and not self.enable_in_reasoning + else None + ) + reasoning_ended_in_batch = False + state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, ()) for token in itertools.chain(req_tokens, (-1,)): self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),)) if token == -1: - # Stop advancing the grammar once we hit a padding token. apply_bitmask = False if apply_bitmask and not grammar.is_terminated(): - accepted = grammar.accept_tokens(req_id, [token]) - assert accepted, (token, req_id, scheduled_spec_decode_tokens) - state_advancements += 1 + if not reasoning_ended_in_batch: + accepted = grammar.accept_tokens(req_id, [token]) + assert accepted, ( + token, + req_id, + scheduled_spec_decode_tokens, + ) + state_advancements += 1 + elif ( + token != -1 + and reasoner is not None + and reasoner.is_reasoning_end_streaming( + request.all_token_ids, [token] + ) + ): + apply_bitmask = True + reasoning_ended_in_batch = True + reasoner = None + structured_output_request.bonus_requires_grammar = True cumulative_index += 1 if state_advancements > 0: grammar.rollback(state_advancements) diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 5c09b7b0634f..0536f6fd3bd8 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -94,6 +94,14 @@ def reset(self): Resets the state of the structured output grammar. """ + @property + def suppress_accept_errors(self) -> bool: + return False + + @suppress_accept_errors.setter + def suppress_accept_errors(self, value: bool) -> None: + pass + @dataclass class StructuredOutputBackend(ABC): diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index a92be3d44320..a86be932ea8e 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -144,6 +144,13 @@ class XgrammarGrammar(StructuredOutputGrammar): default_factory=lambda: 0, repr=False, hash=False, init=False ) _is_terminated: bool = field(default=False, repr=False, hash=False) + # When True, accept_tokens() will not log errors on rejection. + # Used to suppress expected failures e.g. when a bonus token generated + # without grammar constraint is force-fed in the bonus_requires_grammar + # path after the spec-dec draft token was rejected. + suppress_accept_errors: bool = field( + default=False, repr=False, hash=False, init=False + ) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """Accepts a list of tokens and advances the FSM. @@ -155,12 +162,13 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: return False for token in tokens: if not self.matcher.accept_token(token): - logger.error( - "Failed to advance FSM for request %s " - "for tokens %s. Please file an issue.", - request_id, - token, - ) + if not self.suppress_accept_errors: + logger.error( + "Failed to advance FSM for request %s " + "for tokens %s. Please file an issue.", + request_id, + token, + ) return False self.num_processed_tokens += 1 self._is_terminated = self.matcher.is_terminated() diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index dfa8c7efcae4..e97a36976310 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -27,6 +27,11 @@ class StructuredOutputRequest: # Cached per request; do not share reasoning parsers across requests because # their behavior can depend on reasoning_parser_kwargs. reasoner: "ReasoningParser | None" = None + # When reasoning ends mid-draft-batch during speculative decoding, the bonus + # token is constrained by the grammar bitmask but should_advance() returns + # False in the same step (because reasoning just ended). This flag tells + # update_from_output() to advance the grammar with the bonus token anyway. + bonus_requires_grammar: bool = False @staticmethod def from_sampling_params(