Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 281 additions & 0 deletions tests/v1/structured_output/test_reasoning_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ def mock_vllm_config(self, mock_model_config, mock_scheduler_config):
def mock_request_with_structured_output(self):
"""Create a mock request with structured output."""
request = Mock(spec=Request)
request.request_id = "test-request-123"
request.structured_output_request = Mock()
request.structured_output_request.reasoning_ended = None
request.structured_output_request.grammar = Mock()
request.structured_output_request.reasoning_parser_kwargs = None
request.structured_output_request.reasoner = None
request.structured_output_request.bonus_requires_grammar = False
request.structured_output_request.grammar.is_terminated = Mock(
return_value=False
)
Expand Down Expand Up @@ -258,3 +260,282 @@ def test_should_advance_reasoning_already_ended(

# Should return True since reasoning has ended
assert result is True

def _make_detecting_reasoner(self, end_token_id: int = 99):
    """Build a ``MockReasoner`` that treats ``end_token_id`` as reasoning-end.

    ``is_reasoning_end_streaming`` is replaced with a ``Mock`` whose side
    effect appends the delta ids of each call (as a list) to
    ``reasoner.detected_tokens`` and returns ``True`` only when
    ``end_token_id`` is one of those ids.
    """
    detecting_reasoner = MockReasoner(tokenizer=Mock())
    detecting_reasoner.detected_tokens = []

    def _record_and_detect(input_ids, delta_ids):
        # A list argument is recorded as-is; any other iterable is copied
        # into a new list first.
        seen = delta_ids if isinstance(delta_ids, list) else list(delta_ids)
        detecting_reasoner.detected_tokens.append(seen)
        return end_token_id in seen

    detecting_reasoner.is_reasoning_end_streaming = Mock(
        side_effect=_record_and_detect
    )
    return detecting_reasoner

def _make_mock_grammar(self, accept_result: bool = True):
grammar = Mock()
grammar.is_terminated = Mock(return_value=False)
grammar.fill_bitmask = Mock()
grammar.accept_tokens = Mock(return_value=accept_result)
grammar.rollback = Mock()
return grammar

def _setup_manager_backend(self, manager):
import torch

manager.backend = Mock()
manager.backend.allocate_token_bitmask = Mock(
return_value=torch.zeros((10, 50000), dtype=torch.int32)
)
manager._full_mask = torch.tensor(-1, dtype=torch.int32)

# ------------------------------------------------------------------ #
# Test: reasoning ends at the LAST draft token #
# Draft: [10, 20, 30, 99] ← 99 = reasoning-end #
# Expect: bonus slot (idx 4) gets constrained bitmask #
# ------------------------------------------------------------------ #

def test_grammar_bitmask_reasoning_ends_mid_batch(
    self,
    manager_with_reasoner,
    mock_request_with_structured_output,
):
    """Reasoning-end as the final draft token constrains only the bonus slot."""
    request = mock_request_with_structured_output
    req_id = request.request_id

    # Arrange: reasoning is still active and the reasoner fires on token 99.
    structured_req = request.structured_output_request
    structured_req.reasoning_ended = False
    reasoner = self._make_detecting_reasoner(end_token_id=99)
    structured_req.reasoner = reasoner
    grammar = self._make_mock_grammar()
    structured_req.grammar = grammar
    self._setup_manager_backend(manager_with_reasoner)

    # Act: 99 is the last of four draft tokens.
    manager_with_reasoner.grammar_bitmask(
        {req_id: request},
        [req_id],
        {req_id: [10, 20, 30, 99]},
    )

    # The streaming end-detector must have been consulted.
    assert reasoner.is_reasoning_end_streaming.called

    # grammar_bitmask must leave reasoning_ended untouched; per the test's
    # intent that flag is flipped later (should_advance(), post-batch), so
    # the step whose output still contains the reasoning-end token is not
    # treated as already past reasoning.
    assert structured_req.reasoning_ended is False

    # The bonus token, however, must be flagged as needing the grammar so
    # update_from_output() can advance the grammar with it this step.
    assert structured_req.bonus_requires_grammar is True

    # Only the bonus slot (index 4) gets a grammar-constrained bitmask;
    # slots 0-3 take the unconstrained full mask and so never reach
    # grammar.fill_bitmask.
    fill_calls = grammar.fill_bitmask.call_args_list
    assert len(fill_calls) == 1, (
        f"Expected 1 fill_bitmask (bonus pos), got {len(fill_calls)}"
    )
    # Second positional argument = row index into the bitmask tensor.
    fill_index = fill_calls[0].args[1]
    assert fill_index == 4, f"Expected bonus position index 4, got {fill_index}"

    # The bonus slot's token is the -1 sentinel, which disables the bitmask
    # before the accept step — so the grammar must never be advanced.
    grammar.accept_tokens.assert_not_called()

# ------------------------------------------------------------------ #
# Test: reasoning ends MID-DRAFT (not at last position) #
# Draft: [10, 99, 30, 40] ← 99 = reasoning-end at pos 1 #
# Expect: positions 2, 3, and bonus (idx 4) constrained #
# NOTE: positions 2-3 are draft tokens generated WITHOUT the #
# constraint; accept_tokens is skipped to avoid xgrammar #
# state corruption risk. #
# ------------------------------------------------------------------ #

def test_grammar_bitmask_reasoning_ends_mid_draft(
    self,
    manager_with_reasoner,
    mock_request_with_structured_output,
):
    """Reasoning-end inside the draft constrains the later drafts and bonus."""
    request = mock_request_with_structured_output
    req_id = request.request_id

    # Arrange: reasoning active; the grammar would reject any token, so an
    # unexpected accept_tokens call cannot silently succeed.
    structured_req = request.structured_output_request
    structured_req.reasoning_ended = False
    reasoner = self._make_detecting_reasoner(end_token_id=99)
    structured_req.reasoner = reasoner
    grammar = self._make_mock_grammar(accept_result=False)
    structured_req.grammar = grammar
    self._setup_manager_backend(manager_with_reasoner)

    # Act: reasoning-end (99) sits at draft position 1, two drafts follow.
    manager_with_reasoner.grammar_bitmask(
        {req_id: request},
        [req_id],
        {req_id: [10, 99, 30, 40]},
    )

    assert reasoner.is_reasoning_end_streaming.called

    # reasoning_ended stays False here; it is meant to be updated after the
    # batch so reasoning-end tokens are not fed to the grammar this step.
    assert structured_req.reasoning_ended is False

    # The bonus token is flagged as requiring the grammar.
    assert structured_req.bonus_requires_grammar is True

    # Constrained rows: the two drafts after the end token (2, 3) and the
    # bonus slot (4).
    fill_calls = grammar.fill_bitmask.call_args_list
    fill_indices = sorted(call_.args[1] for call_ in fill_calls)
    assert fill_indices == [2, 3, 4], (
        f"Expected constrained indices [2, 3, 4], got {fill_indices}"
    )

    # Drafts after the end token were sampled without the constraint, so
    # they are never pushed through the grammar ...
    grammar.accept_tokens.assert_not_called()

    # ... and with zero state advancements there is nothing to roll back.
    grammar.rollback.assert_not_called()

# ------------------------------------------------------------------ #
# Test: NO reasoning-end token in the batch #
# Draft: [10, 20, 30, 40] #
# Expect: everything unconstrained (existing behaviour) #
# ------------------------------------------------------------------ #

def test_grammar_bitmask_no_reasoning_end(
    self,
    manager_with_reasoner,
    mock_request_with_structured_output,
):
    """Without a reasoning-end token in the draft nothing is constrained."""
    request = mock_request_with_structured_output
    req_id = request.request_id

    # Arrange: reasoning active, detector armed for token 99.
    structured_req = request.structured_output_request
    structured_req.reasoning_ended = False
    reasoner = self._make_detecting_reasoner(end_token_id=99)
    structured_req.reasoner = reasoner
    grammar = self._make_mock_grammar()
    structured_req.grammar = grammar
    self._setup_manager_backend(manager_with_reasoner)

    # Act: the draft does not contain 99.
    manager_with_reasoner.grammar_bitmask(
        {req_id: request},
        [req_id],
        {req_id: [10, 20, 30, 40]},
    )

    # The detector was consulted (and, given the draft, never matched).
    assert reasoner.is_reasoning_end_streaming.called

    # Reasoning is still considered in progress.
    assert structured_req.reasoning_ended is False

    # Every row took the full (unconstrained) mask, so the grammar was
    # never asked to fill a bitmask ...
    grammar.fill_bitmask.assert_not_called()

    # ... nor to accept any token.
    grammar.accept_tokens.assert_not_called()

# ------------------------------------------------------------------ #
# Test: reasoning already ended BEFORE this batch #
# Draft: [10, 20, 30, 40] with reasoning_ended=True #
# Expect: all positions constrained normally #
# ------------------------------------------------------------------ #

def test_grammar_bitmask_reasoning_already_ended(
    self,
    manager_with_reasoner,
    mock_request_with_structured_output,
):
    """With reasoning finished beforehand every position is constrained."""
    request = mock_request_with_structured_output
    req_id = request.request_id

    # Arrange: reasoning was completed before this batch was scheduled.
    structured_req = request.structured_output_request
    structured_req.reasoning_ended = True
    grammar = self._make_mock_grammar()
    structured_req.grammar = grammar
    self._setup_manager_backend(manager_with_reasoner)

    # Act
    manager_with_reasoner.grammar_bitmask(
        {req_id: request},
        [req_id],
        {req_id: [10, 20, 30, 40]},
    )

    # One constrained bitmask per position: 4 drafts + 1 bonus slot.
    fill_calls = grammar.fill_bitmask.call_args_list
    assert len(fill_calls) == 5, (
        f"Expected 5 fill_bitmask calls, got {len(fill_calls)}"
    )

    # Only the four real draft tokens advance the grammar; the bonus slot
    # carries the -1 sentinel, which turns the bitmask off before the
    # accept step.
    assert grammar.accept_tokens.call_count == 4, (
        f"Expected 4 accept_tokens calls, got {grammar.accept_tokens.call_count}"
    )

    # Those four advancements are undone with a single rollback.
    grammar.rollback.assert_called_once()
15 changes: 15 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,21 @@ def update_from_output(
request.status = RequestStatus.FINISHED_ERROR
request.resumable = False
stopped = True
elif (
request.structured_output_request is not None
and request.structured_output_request.bonus_requires_grammar
and generated_token_ids
):
struct_req = request.structured_output_request
assert struct_req.grammar is not None
struct_req.grammar.suppress_accept_errors = True
accepted = struct_req.grammar.accept_tokens(
req_id, [generated_token_ids[-1]]
)
struct_req.grammar.suppress_accept_errors = False
struct_req.bonus_requires_grammar = False
if accepted:
struct_req.reasoning_ended = True

routed_experts = None
if (
Expand Down
38 changes: 31 additions & 7 deletions vllm/v1/structured_output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,6 @@ def _fill_bitmasks(
if apply_bitmask and not grammar.is_terminated():
grammar.fill_bitmask(self._grammar_bitmask, index)
else:
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for consequent
# requests here.
self._grammar_bitmask[index].fill_(self._full_mask)

def _async_submit_fill_bitmask(
Expand Down Expand Up @@ -275,17 +272,44 @@ def grammar_bitmask(
grammar = structured_output_request.grammar
apply_bitmask = self.should_fill_bitmask(request)

# Per-token reasoning-end detection: when apply_bitmask is
# False because reasoning is active, we need to watch for the
# end-of-reasoning token within the draft batch. When found,
# flip apply_bitmask so subsequent positions (including the
# bonus token) get the constrained bitmask.
reasoner = (
self._get_reasoner(request)
if not apply_bitmask and not self.enable_in_reasoning
else None
)
reasoning_ended_in_batch = False

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, ())
for token in itertools.chain(req_tokens, (-1,)):
self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),))
if token == -1:
# Stop advancing the grammar once we hit a padding token.
apply_bitmask = False
if apply_bitmask and not grammar.is_terminated():
accepted = grammar.accept_tokens(req_id, [token])
assert accepted, (token, req_id, scheduled_spec_decode_tokens)
state_advancements += 1
if not reasoning_ended_in_batch:
accepted = grammar.accept_tokens(req_id, [token])
assert accepted, (
token,
req_id,
scheduled_spec_decode_tokens,
)
state_advancements += 1
elif (
token != -1
and reasoner is not None
and reasoner.is_reasoning_end_streaming(
request.all_token_ids, [token]
)
):
apply_bitmask = True
reasoning_ended_in_batch = True
reasoner = None
structured_output_request.bonus_requires_grammar = True
cumulative_index += 1
if state_advancements > 0:
grammar.rollback(state_advancements)
Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/structured_output/backend_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def reset(self):
Resets the state of the structured output grammar.
"""

@property
def suppress_accept_errors(self) -> bool:
    """Report the accept-error suppression flag.

    This default implementation always returns ``False``.
    NOTE(review): presumably concrete grammar backends override this to
    expose a real flag — confirm against the backend implementations.
    """
    return False

@suppress_accept_errors.setter
def suppress_accept_errors(self, value: bool) -> None:
    """Accept an assignment to the flag and discard it.

    The default implementation stores nothing, so the getter keeps
    returning ``False`` regardless of ``value``.
    """


@dataclass
class StructuredOutputBackend(ABC):
Expand Down
Loading
Loading