diff --git a/tests/reasoning/test_step3p5_reasoning_parser.py b/tests/reasoning/test_step3p5_reasoning_parser.py new file mode 100644 index 000000000000..718aeefb1743 --- /dev/null +++ b/tests/reasoning/test_step3p5_reasoning_parser.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "step3p5" +start_token = "" +end_token = "" + +REASONING_MODEL_NAME = "stepfun-ai/Step-3.5-Flash" + + +@pytest.fixture(scope="module") +def step3p5_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +SIMPLE_REASONING = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +# need to get into parser again to remove newline after +COMPLETE_REASONING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +NO_CONTENT = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING_STREAMING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +REASONING_WITH_THINK = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING_WITH_THINK = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES_WITH_THINK = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING_WITH_THINK = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +THINK_NO_END = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +EMPTY = { + "output": "", + "reasoning_content": None, + "content": None, + "is_reasoning_end": False, +} +EMPTY_STREAMING = { + "output": "", + "reasoning_content": None, + "content": None, + "is_reasoning_end": False, +} +NEW_LINE = { + "output": "\nThis is a reasoning section\nThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} + +NEW_LINE_STREAMING = { + "output": "\nThis is a reasoning section\n\nThis is the rest", + "reasoning_content": "\nThis is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} + +NEW_LINE_STREAMING_COMPLEX_CONTENT = { + "output": "\n This is a \n reasoning section\n\n\n\n\nThis is the rest", + "reasoning_content": "\n This is a \n reasoning section\n\n", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} + +MULTI_TURN_PROMPT_CONTENT = { + "output": " This is last turn's reasoning section hello ", + "reasoning_content": "", + "content": "", + "is_reasoning_end": False, +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + False, + NO_CONTENT, + id="no_content_token", + ), + pytest.param( + True, + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + SHORTEST_REASONING, + id="shortest", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING, + id="shortest_streaming", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="shortest_with_think", + ), + pytest.param( + True, + SHORTEST_REASONING_WITH_THINK, + id="shortest_with_think_streaming", + ), + pytest.param( + False, + THINK_NO_END, + id="think_no_end", + ), + pytest.param( + True, + THINK_NO_END, + id="think_no_end_streaming", + ), + pytest.param( + False, + EMPTY, + id="empty", + ), + pytest.param( + True, + EMPTY_STREAMING, + id="empty_streaming", + ), + pytest.param( + False, + NEW_LINE, + id="new_line", + ), + pytest.param( + True, + NEW_LINE_STREAMING, + id="new_line_streaming", + ), + pytest.param( + True, + NEW_LINE_STREAMING_COMPLEX_CONTENT, + id="new_line_streaming_complex_content", + ), + pytest.param( + True, + MULTI_TURN_PROMPT_CONTENT, + id="multi_turn_prompt_content", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + step3p5_tokenizer, + request, +): + output = step3p5_tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: list[str] = [ + step3p5_tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + step3p5_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + print(f"reasoning: {reasoning}") + print(f"content: {content}") + test_id = request.node.callspec.id if hasattr(request.node, "callspec") else None + if request.node.callspec.id != "multi_turn_prompt_content": + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] + + # Test is_reasoning_end + output_ids = step3p5_tokenizer.convert_tokens_to_ids(output) + if streaming: + is_reasoning_end = parser.is_reasoning_end(output_ids) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + # Test extract_content + if param_dict["content"] is not None: + content = parser.extract_content_ids(output_ids) + # Fixed expected token ids for specific test cases + test_id = ( + request.node.callspec.id if hasattr(request.node, "callspec") else None + ) + # Match most specific first + if test_id not in [ + "new_line_streaming_complex_content", + "new_line_streaming", + "new_line", + "multi_turn_prompt_content", + ]: + expected_content_ids = step3p5_tokenizer.convert_tokens_to_ids( + step3p5_tokenizer.tokenize(param_dict["content"]) + ) + assert content == expected_content_ids + else: + content = parser.extract_content_ids(output) + assert content == [] + + +def test_step3p5_streaming_drops_leading_newline(step3p5_tokenizer): + parser_cls = ReasoningParserManager.get_reasoning_parser("step3p5") + parser = parser_cls(step3p5_tokenizer) + output = "calc\nAnswer" + tokens = step3p5_tokenizer.tokenize(output) + output_tokens = [ + step3p5_tokenizer.convert_tokens_to_string([token]) for token in tokens + ] + + _, content = run_reasoning_extraction(parser, output_tokens, streaming=True) + assert content == "Answer" diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index af9aa4b4141b..25e9cdb997f6 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -39,24 +39,59 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): # whether it is immediately before . self._pending_reasoning_newline = False - # Used to delay the reasoning end detection. - # This is necessary to remove the newline appears immediately after , - # which may cause the end detection to be delayed by one round. - self.end_offset = 1 + # Tracks whether we've seen but are still waiting for one more + # token to confirm the end. + self._end_token_pending = False def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: - if self.end_token_id in input_ids and self.end_offset > 0: - self.end_offset -= 1 - return False - return self.end_offset < 1 + return self._is_reasoning_end_from_ids(input_ids) def is_reasoning_end_streaming( self, input_ids: Sequence[int], delta_ids: Iterable[int] ) -> bool: - if self.end_token_id in input_ids and self.end_offset > 0: - self.end_offset -= 1 + # Only examine newly generated tokens; they may contain multiple ids. + return self._is_reasoning_end_from_ids(delta_ids) + + def _is_reasoning_end_from_ids(self, input_ids: Sequence[int]) -> bool: + # Scan backwards to find the last special token, or . + last_special = None + last_idx = -1 + for i in range(len(input_ids) - 1, -1, -1): + token_id = input_ids[i] + if token_id == self.start_token_id: + last_special = "start" + last_idx = i + break + if token_id == self.end_token_id: + last_special = "end" + last_idx = i + break + + if last_special == "start": + # If we're already waiting for one token after , do not + # clear the pending state just because the prompt contains . + # Streaming deltas should not include for this model. + if self._end_token_pending: + return False + # A start token after any end token means reasoning is ongoing. + self._end_token_pending = False + return False + + if last_special == "end": + # Require at least one token after before ending. + if last_idx < len(input_ids) - 1: + self._end_token_pending = False + return True + self._end_token_pending = True return False - return self.end_offset < 1 + + # No special tokens in this input. If we were waiting for one token + # after , any new token completes the end. + if self._end_token_pending and input_ids: + self._end_token_pending = False + return True + + return False def extract_reasoning( self, @@ -136,9 +171,6 @@ def extract_reasoning_streaming( # Content: handle the newline immediately after . if content_to_output is not None: - # No need to get into parser again to remove newline after . - self.end_offset -= 1 - # If we have content, reasoning must have ended. self._pending_reasoning_newline = False