# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the ``hyperclovax`` reasoning parser.

HyperCLOVAX marks reasoning with a leading ``/think\\n`` and terminates it with
the chat-turn boundary ``<|im_end|>\\n<|im_start|>assistant`` (optionally
followed by the ``-> tool/function_call`` role suffix).
"""
import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import (
    StreamingReasoningReconstructor,
    run_reasoning_extraction,
    run_reasoning_extraction_streaming,
)
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.reasoning import ReasoningParser, ReasoningParserManager

PARSER_NAME = "hyperclovax"

TOKENIZER_NAME = "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B"


# Markers emitted by the model around reasoning / tool-call segments.
THINK_START = "/think\n"
THINK_END_BASE = "<|im_end|>\n<|im_start|>assistant"
FUNCTION_CALL_ROLE = " -> tool/function_call\n"


def _tool_payload(name: str = "search", args: str = '{"query":"weather"}') -> str:
    """Build a raw JSON tool-call payload string as the model would emit it."""
    return f'[{{"name":"{name}","arguments":{args}}}]'


@pytest.fixture(scope="module")
def hyperclovax_tokenizer():
    # Module-scoped: loading the tokenizer from the Hub is expensive.
    return AutoTokenizer.from_pretrained(TOKENIZER_NAME)


@pytest.fixture
def parser(hyperclovax_tokenizer) -> ReasoningParser:
    # Fresh parser per test: the parser keeps streaming state between calls.
    return ReasoningParserManager.get_reasoning_parser(PARSER_NAME)(
        hyperclovax_tokenizer
    )


def test_hyperclovax_reasoning_parser_creation(hyperclovax_tokenizer):
    parser_cls = ReasoningParserManager.get_reasoning_parser(PARSER_NAME)
    created = parser_cls(hyperclovax_tokenizer)
    assert isinstance(created, ReasoningParser)


@pytest.fixture
def request_auto() -> ChatCompletionRequest:
    return ChatCompletionRequest(messages=[], model="test-model", tool_choice=None)


# Expected (output, reasoning, content) triples for the non-streaming path.
REASONING_WITH_CONTENT = {
    "output": THINK_START
    + "This is reasoning.\n"
    + THINK_END_BASE
    + "\nThis is the answer.",
    "reasoning": "This is reasoning.\n",
    "content": "\nThis is the answer.",
}
REASONING_ONLY = {
    "output": THINK_START + "Still thinking...",
    "reasoning": "Still thinking...",
    "content": None,
}

EMPTY_THINKING_NONSTREAM = {
    "output": THINK_START + THINK_END_BASE + "\nAnswer.",
    "reasoning": None,
    "content": "\nAnswer.",
}

NO_THINKING_NONSTREAM = {
    "output": "\nDirect answer.",
    "reasoning": None,
    # Non-streaming with tool_choice="auto" strips the leading newline.
    "content": "Direct answer.",
    "tool_choice": "auto",
}

TOOL_CALL_AFTER_THINK_NONSTREAM = {
    "output": THINK_START
    + "Let me check.\n"
    + THINK_END_BASE
    + FUNCTION_CALL_ROLE
    + _tool_payload(),
    "reasoning": "Let me check.\n",
    "content": FUNCTION_CALL_ROLE + _tool_payload(),
}

DIRECT_TOOL_CALL_NONSTREAM = {
    "output": FUNCTION_CALL_ROLE + _tool_payload(),
    "reasoning": None,
    # With tool_choice="required" the role suffix is removed from content.
    "content": _tool_payload(),
    "tool_choice": "required",
}

MULTILINE_REASONING = {
    "output": THINK_START
    + "Line one.\nLine two.\n"
    + THINK_END_BASE
    + "\nFinal answer.",
    "reasoning": "Line one.\nLine two.\n",
    "content": "\nFinal answer.",
}

NON_STREAMING_TEST_CASES = [
    pytest.param(REASONING_WITH_CONTENT, id="reasoning_with_content"),
    pytest.param(REASONING_ONLY, id="reasoning_only"),
    pytest.param(EMPTY_THINKING_NONSTREAM, id="empty_thinking"),
    pytest.param(NO_THINKING_NONSTREAM, id="no_thinking"),
    pytest.param(TOOL_CALL_AFTER_THINK_NONSTREAM, id="tool_call_after_think"),
    pytest.param(DIRECT_TOOL_CALL_NONSTREAM, id="direct_tool_call"),
    pytest.param(MULTILINE_REASONING, id="multiline_reasoning"),
]


# Streaming variants: empty thinking yields "" (not None) for reasoning, and
# the tool-call role suffix is consumed rather than surfaced as content.
EMPTY_THINKING_STREAM = {
    "output": THINK_START + THINK_END_BASE + "\nAnswer.",
    "reasoning": "",
    "content": "\nAnswer.",
}

NO_THINKING_STREAM = {
    "output": "\nDirect answer.",
    "reasoning": None,
    "content": "\nDirect answer.",
}

TOOL_CALL_AFTER_THINK_STREAM = {
    "output": THINK_START
    + "Let me check.\n"
    + THINK_END_BASE
    + FUNCTION_CALL_ROLE
    + _tool_payload(),
    "reasoning": "Let me check.\n",
    "content": _tool_payload(),
}
STREAMING_TEST_CASES = [
    pytest.param(REASONING_WITH_CONTENT, id="reasoning_with_content"),
    pytest.param(REASONING_ONLY, id="reasoning_only"),
    pytest.param(EMPTY_THINKING_STREAM, id="empty_thinking"),
    pytest.param(NO_THINKING_STREAM, id="no_thinking"),
    pytest.param(TOOL_CALL_AFTER_THINK_STREAM, id="tool_call_after_think"),
    pytest.param(MULTILINE_REASONING, id="multiline_reasoning"),
]


def _make_request(tool_choice=None) -> ChatCompletionRequest:
    """Build a request; a dummy tool is attached whenever tool_choice is set."""
    if tool_choice in (None, "none"):
        return ChatCompletionRequest(messages=[], model="test-model")
    return ChatCompletionRequest(
        messages=[],
        model="test-model",
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "search",
                    "description": "test tool",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
        tool_choice=tool_choice,
    )


@pytest.mark.parametrize("param_dict", NON_STREAMING_TEST_CASES)
def test_extract_reasoning_nonstreaming(
    param_dict: dict,
    parser: ReasoningParser,
):
    tool_choice = param_dict.get("tool_choice", "none")
    request = _make_request(tool_choice=tool_choice)

    # Re-tokenize the output so extraction sees realistic token boundaries.
    output_tokens = [
        parser.model_tokenizer.convert_tokens_to_string([tok])
        for tok in parser.model_tokenizer.tokenize(param_dict["output"])
    ]
    reasoning, content = run_reasoning_extraction(
        parser, output_tokens, request=request, streaming=False
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]


@pytest.mark.parametrize("param_dict", STREAMING_TEST_CASES)
def test_extract_reasoning_streaming(
    param_dict: dict,
    hyperclovax_tokenizer,
):
    # A fresh parser per case: streaming extraction is stateful.
    fresh_parser = ReasoningParserManager.get_reasoning_parser(PARSER_NAME)(
        hyperclovax_tokenizer
    )

    output_tokens = [
        hyperclovax_tokenizer.convert_tokens_to_string([tok])
        for tok in hyperclovax_tokenizer.tokenize(param_dict["output"])
    ]
    reasoning, content = run_reasoning_extraction(
        fresh_parser, output_tokens, streaming=True
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]


def test_is_reasoning_end_true_with_newline_variant(parser: ReasoningParser):
    ids = parser.model_tokenizer.encode(THINK_START + "hello" + THINK_END_BASE + "\n")
    assert parser.is_reasoning_end(ids) is True


def test_is_reasoning_end_true_with_content_after_end(parser: ReasoningParser):
    # End marker followed by the tool-call role suffix still counts as ended.
    ids = parser.model_tokenizer.encode(
        THINK_START + "hello" + THINK_END_BASE + FUNCTION_CALL_ROLE
    )
    assert parser.is_reasoning_end(ids) is True


def test_is_reasoning_end_false_start_after_end(parser: ReasoningParser):
    # A new /think after the end marker means reasoning is ongoing again.
    ids = parser.model_tokenizer.encode(THINK_END_BASE + "\n" + THINK_START + "more")
    assert parser.is_reasoning_end(ids) is False


def test_is_reasoning_end_false_no_end_token(parser: ReasoningParser):
    regular_ids = parser.model_tokenizer.encode("hello world, still reasoning")
    assert parser.is_reasoning_end(regular_ids) is False


def test_is_reasoning_end_true_single_end_token(parser: ReasoningParser):
    assert parser.is_reasoning_end([parser.end_token_id]) is True


def test_is_reasoning_end_streaming_true_on_end_token_delta(parser: ReasoningParser):
    assert (
        parser.is_reasoning_end_streaming([parser.end_token_id], [parser.end_token_id])
        is True
    )


def test_is_reasoning_end_streaming_false_without_end_token_delta(
    parser: ReasoningParser,
):
    assert (
        parser.is_reasoning_end_streaming(
            [parser.end_token_id], [parser.end_token_id + 1]
        )
        is False
    )


def test_is_reasoning_end_false_empty_sequence(parser: ReasoningParser):
    assert parser.is_reasoning_end([]) is False


def test_extract_content_ids_after_end_token(parser: ReasoningParser):
    sep_text = THINK_START + "abc" + THINK_END_BASE + "hello"
    all_ids = parser.model_tokenizer.encode(sep_text)
    content_ids = parser.extract_content_ids(all_ids)

    decoded = parser.model_tokenizer.decode(content_ids, skip_special_tokens=False)
    assert "hello" in decoded
def test_extract_content_ids_no_end_token(parser: ReasoningParser):
    still_reasoning_ids = parser.model_tokenizer.encode("still thinking")
    assert parser.extract_content_ids(still_reasoning_ids) == []


# Each case: (list of streamed delta strings, expected reasoning, expected
# content) — exercises markers arriving whole, split, or fused with content.
MULTI_TOKEN_DELTA_CASES = [
    pytest.param(
        [THINK_START + "reasoning", THINK_END_BASE + "content"],
        "reasoning",
        "content",
        id="end_tag_and_content_in_one_delta",
    ),
    pytest.param(
        [THINK_START + "start of thinking", " more", THINK_END_BASE + "ok"],
        "start of thinking more",
        "ok",
        id="start_marker_with_reasoning",
    ),
    pytest.param(
        [THINK_START + "reasoning", "<|im_end|>", "\n<|im_start|>assistant", "result"],
        "reasoning",
        "result",
        id="end_tag_split_across_deltas",
    ),
    pytest.param(
        ["\ndirect content"],
        None,
        "\ndirect content",
        id="no_thinking_single_delta",
    ),
    pytest.param(
        [THINK_START + "think", THINK_END_BASE + FUNCTION_CALL_ROLE + _tool_payload()],
        "think",
        _tool_payload(),
        id="tool_call_after_reasoning",
    ),
]


@pytest.mark.parametrize(
    "deltas, expected_reasoning, expected_content",
    MULTI_TOKEN_DELTA_CASES,
)
def test_streaming_multi_token_deltas(
    deltas: list[str],
    expected_reasoning: str | None,
    expected_content: str | None,
    hyperclovax_tokenizer,
):
    fresh_parser = ReasoningParserManager.get_reasoning_parser(PARSER_NAME)(
        hyperclovax_tokenizer
    )
    reconstructor: StreamingReasoningReconstructor = run_reasoning_extraction_streaming(
        fresh_parser, deltas
    )

    assert reconstructor.reasoning == expected_reasoning
    # Normalize "" to None so empty content compares equal to the expectation.
    assert (reconstructor.other_content or None) == expected_content


def test_force_reasoning_treats_all_as_reasoning(parser: ReasoningParser):
    request = ChatCompletionRequest(
        messages=[],
        model="test-model",
        chat_template_kwargs={"force_reasoning": True},
    )
    reasoning, content = parser.extract_reasoning(
        "No think marker but forced.", request
    )
    assert reasoning == "No think marker but forced."
    assert content is None
def test_skip_reasoning_returns_all_as_content(parser: ReasoningParser):
    request = ChatCompletionRequest(
        messages=[],
        model="test-model",
        chat_template_kwargs={"skip_reasoning": True},
    )
    reasoning, content = parser.extract_reasoning(
        THINK_START + "This should be content.", request
    )
    assert reasoning is None
    # skip_reasoning keeps the output verbatim, /think marker included.
    assert content == THINK_START + "This should be content."


def test_force_reasoning_takes_priority_over_skip(parser: ReasoningParser):
    request = ChatCompletionRequest(
        messages=[],
        model="test-model",
        chat_template_kwargs={"force_reasoning": True, "skip_reasoning": True},
    )
    reasoning, content = parser.extract_reasoning("some output", request)
    assert reasoning == "some output"
    assert content is None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the ``hyperclovax`` tool-call parser.

The model emits tool calls as a JSON array after the role marker
``-> tool/function_call`` and closes the turn with ``<|im_end|>``.
"""
import json
import os

import pytest
from transformers import AutoTokenizer

from tests.tool_parsers.common_tests import ToolParserTestConfig, ToolParserTests
from tests.tool_parsers.utils import run_tool_extraction
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hyperclovax_tool_parser import HyperCLOVAXToolParser

TOKENIZER_NAME = "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B"

# Markers used by the model around a tool-call segment.
TOOL_CALL_START = " -> tool/function_call\n"
TOOL_CALL_END = "<|im_end|>"
ASSISTANT_TOOL_PREFIX = "<|im_end|>\n<|im_start|>assistant -> tool/function_call\n"


def _call(name: str, arguments: dict | None = None) -> dict:
    """Build a single tool-call dict in the model's wire format."""
    return {"name": name, "arguments": arguments or {}}


def _tool_call_output(
    calls: list[dict],
    *,
    with_end: bool = True,
    with_assistant_prefix: bool = False,
    leading_content: str = "",
) -> str:
    """Render *calls* exactly as the model would emit them.

    ``with_assistant_prefix`` selects the full turn-boundary prefix (used when
    content precedes the tool call); ``with_end`` appends ``<|im_end|>``.
    """
    prefix = ASSISTANT_TOOL_PREFIX if with_assistant_prefix else TOOL_CALL_START
    payload = json.dumps(calls, ensure_ascii=False)
    suffix = TOOL_CALL_END if with_end else ""
    return leading_content + prefix + payload + suffix
@pytest.fixture(scope="module")
def hcx_tokenizer() -> TokenizerLike:
    """Load the HyperCLOVAX tokenizer once per module.

    Prefers a local checkout named by the ``HCX_TOKENIZER_PATH`` environment
    variable (offline CI), otherwise falls back to the public Hub checkpoint —
    matching the sibling reasoning-parser tests.  The previous version
    hard-coded a developer's personal filesystem path and skipped everywhere
    else; that path must never be committed, and ``TOKENIZER_NAME`` was dead.
    """
    local_path = os.environ.get("HCX_TOKENIZER_PATH")
    if local_path and os.path.isdir(local_path):
        # Offline mode: never touch the network when a local copy exists.
        return AutoTokenizer.from_pretrained(local_path, local_files_only=True)
    return AutoTokenizer.from_pretrained(TOKENIZER_NAME)


class TestHyperCLOVAXToolParser(ToolParserTests):
    """Shared tool-parser conformance suite plus HyperCLOVAX-specific cases."""

    @pytest.fixture
    def tokenizer(self, hcx_tokenizer: TokenizerLike) -> TokenizerLike:
        return hcx_tokenizer

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        # Canonical inputs/expectations consumed by the shared suite.
        return ToolParserTestConfig(
            parser_name="hyperclovax",
            no_tool_calls_output="This is a plain response without any tools.",
            single_tool_call_output=_tool_call_output(
                [_call("get_weather", {"city": "Tokyo"})]
            ),
            parallel_tool_calls_output=_tool_call_output(
                [
                    _call("get_weather", {"city": "Tokyo"}),
                    _call("get_time", {"timezone": "Asia/Tokyo"}),
                ]
            ),
            various_data_types_output=(
                _tool_call_output(
                    [
                        _call(
                            "test_function",
                            {
                                "string_field": "hello",
                                "int_field": 42,
                                "float_field": 3.14,
                                "bool_field": True,
                                "null_field": None,
                                "array_field": ["a", "b", "c"],
                                "object_field": {"nested": "value"},
                                "empty_array": [],
                                "empty_object": {},
                            },
                        )
                    ]
                )
            ),
            empty_arguments_output=_tool_call_output([_call("refresh", {})]),
            surrounding_text_output=(
                _tool_call_output(
                    [_call("get_weather", {"city": "Tokyo"})],
                    with_assistant_prefix=True,
                    leading_content="I will call a tool.\n",
                )
            ),
            escaped_strings_output=(
                _tool_call_output(
                    [
                        _call(
                            "test_function",
                            {
                                "quoted": 'He said "hello"',
                                "path": "C:\\Users\\file.txt",
                                "newline": "line1\nline2",
                            },
                        )
                    ]
                )
            ),
            malformed_input_outputs=[
                TOOL_CALL_START + "[",
                TOOL_CALL_START + "not-json" + TOOL_CALL_END,
            ],
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
        )

    @pytest.mark.parametrize("streaming", [True, False])
    def test_tool_call_after_assistant_separator(
        self, hcx_tokenizer: TokenizerLike, streaming: bool
    ):
        model_output = _tool_call_output(
            [_call("get_weather", {"city": "Seoul"})],
            with_assistant_prefix=True,
            leading_content="Let me check.\n",
        )
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        content, tool_calls = run_tool_extraction(
            parser, model_output, streaming=streaming
        )

        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "get_weather"
        assert json.loads(tool_calls[0].function.arguments) == {"city": "Seoul"}
        if not streaming:
            # Non-streaming extraction surfaces the leading content verbatim.
            assert content == "Let me check.\n"

    @pytest.mark.parametrize("streaming", [True, False])
    def test_no_tool_call_plain_text(
        self, hcx_tokenizer: TokenizerLike, streaming: bool
    ):
        model_output = "\nThis is a regular assistant reply."
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        content, tool_calls = run_tool_extraction(
            parser, model_output, streaming=streaming
        )

        assert len(tool_calls) == 0
        assert content == model_output
    @pytest.mark.parametrize("streaming", [True, False])
    def test_multiple_tool_calls_extracted_in_order(
        self, hcx_tokenizer: TokenizerLike, streaming: bool
    ):
        model_output = _tool_call_output(
            [
                _call("alpha", {"x": 1}),
                _call("beta", {"y": 2}),
                _call("gamma", {"z": 3}),
            ]
        )
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        content, tool_calls = run_tool_extraction(
            parser, model_output, streaming=streaming
        )

        # Order must match the order of objects in the JSON array.
        assert [tc.function.name for tc in tool_calls] == ["alpha", "beta", "gamma"]
        assert json.loads(tool_calls[0].function.arguments) == {"x": 1}
        assert json.loads(tool_calls[1].function.arguments) == {"y": 2}
        assert json.loads(tool_calls[2].function.arguments) == {"z": 3}

    @pytest.mark.parametrize("streaming", [True, False])
    def test_nested_arguments_preserved(
        self, hcx_tokenizer: TokenizerLike, streaming: bool
    ):
        model_output = _tool_call_output(
            [
                _call(
                    "create_event",
                    {
                        "title": "Meeting",
                        "location": {"city": "Seoul", "room": "A1"},
                        "attendees": ["alice", "bob"],
                    },
                )
            ]
        )
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        content, tool_calls = run_tool_extraction(
            parser, model_output, streaming=streaming
        )

        assert len(tool_calls) == 1
        args = json.loads(tool_calls[0].function.arguments)
        assert args["location"] == {"city": "Seoul", "room": "A1"}
        assert args["attendees"] == ["alice", "bob"]

    @pytest.mark.parametrize("streaming", [True, False])
    def test_unicode_in_arguments_preserved(
        self, hcx_tokenizer: TokenizerLike, streaming: bool
    ):
        model_output = _tool_call_output([_call("greet", {"message": "hello"})])
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        content, tool_calls = run_tool_extraction(
            parser, model_output, streaming=streaming
        )

        assert len(tool_calls) == 1
        args = json.loads(tool_calls[0].function.arguments)
        assert args["message"] == "hello"

    def test_each_streaming_tool_call_has_unique_id(self, hcx_tokenizer: TokenizerLike):
        model_output = _tool_call_output([_call("func_a", {}), _call("func_b", {})])
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        _, tool_calls = run_tool_extraction(parser, model_output, streaming=True)

        assert len(tool_calls) == 2
        ids = [tc.id for tc in tool_calls]
        assert all(ids), "All tool call IDs must be non-empty"
        assert len(set(ids)) == len(ids), "Tool call IDs must be unique"

    def test_malformed_json_returns_no_tool_calls(self, hcx_tokenizer: TokenizerLike):
        model_output = TOOL_CALL_START + "[{not valid json}]" + TOOL_CALL_END
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        result = parser.extract_tool_calls(model_output, request=None)
        # Malformed payloads degrade to plain content, not an exception.
        assert not result.tools_called
        assert result.tool_calls == []
        assert result.content == model_output

    def test_missing_end_token_still_parsed(self, hcx_tokenizer: TokenizerLike):
        # Generation may stop before <|im_end|>; the trailing ']' is enough.
        model_output = _tool_call_output(
            [_call("search", {"top_k": 3})], with_end=False
        )
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        result = parser.extract_tool_calls(model_output, request=None)

        assert result.tools_called
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "search"
        args = json.loads(result.tool_calls[0].function.arguments)
        assert args == {"top_k": 3}

    def test_no_marker_returns_content_unchanged(self, hcx_tokenizer: TokenizerLike):
        model_output = "\nHello, how can I help you?"
        parser = HyperCLOVAXToolParser(hcx_tokenizer)
        result = parser.extract_tool_calls(model_output, request=None)
        assert not result.tools_called
        assert result.content == model_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: registered in vllm/reasoning/__init__.py under the name "hyperclovax"
# as ("hyperclovax_reasoning_parser", "HyperCLOVAXReasoningParser").
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

import regex as re

from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike

if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import (
        ChatCompletionRequest,
    )
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest

logger = init_logger(__name__)


class HyperCLOVAXReasoningParser(ReasoningParser):
    """Reasoning parser for HyperCLOVAX 'think' outputs.

    Reasoning starts with ``/think\\n`` and ends at the turn boundary
    ``<|im_end|>\\n<|im_start|>assistant`` followed by either ``\\n`` (plain
    content) or `` -> tool/function_call\\n`` (tool call).
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        self.think_start_token = "/think\n"
        self.think_end_string_base = "<|im_end|>\n<|im_start|>assistant"
        self.function_call_role = " -> tool/function_call\n"

        # May be None if the tokenizer vocab lacks <|im_end|>.
        self.end_token_id = self.vocab.get("<|im_end|>")
        # First token of "\n": an output starting with it means no reasoning.
        self.non_reasoning_mode_start_token = tokenizer.encode("\n")[0]
        self.no_reasoning_content = False

        # The two full end-marker variants, as strings and as token lists,
        # used for suffix matching in is_reasoning_end().
        self.exact_think_end_strings = [
            self.think_end_string_base + "\n",
            self.think_end_string_base + self.function_call_role,
        ]
        self.think_end_tokens = [
            tokenizer.encode(think_end_string)
            for think_end_string in self.exact_think_end_strings
        ]

        # Streaming state: text held back until marker boundaries resolve.
        self.buffer_string = ""
        self.special_strings = [
            self.think_start_token,
            self.think_end_string_base,
            self.function_call_role,
        ]
        self.escaped_special_strings = [re.escape(ss) for ss in self.special_strings]

    def _remove_special_string(self) -> tuple[str, str]:
        """Split the buffer at the LAST special-string match.

        Returns ``(to_stream, remaining)``: text before the last match with all
        special strings stripped, and text after it.  Must only be called when
        at least one special string is present (indexes the last match).
        """
        positions: list[tuple[int, int]] = []
        for ss in self.escaped_special_strings:
            positions += [
                (m.start(), m.end()) for m in re.finditer(ss, self.buffer_string)
            ]

        sorted_positions = sorted(positions, key=lambda x: x[0])
        to_stream = self.buffer_string[: sorted_positions[-1][0]]
        remaining = self.buffer_string[sorted_positions[-1][1] :]
        # Earlier matches inside the streamed prefix are stripped too.
        for ss in self.special_strings:
            to_stream = to_stream.replace(ss, "")

        return to_stream, remaining

    def _check_is_special_string(self) -> bool:
        # True when any complete marker is present in the buffer.
        return any(ss in self.buffer_string for ss in self.special_strings)

    def _check_is_part_of_special_string(self) -> bool:
        # True when the buffer ENDS with a proper prefix of any marker, i.e.
        # a marker may still be arriving and the tail must be held back.
        for ss in self.special_strings:
            min_len = min(len(self.buffer_string), len(ss))
            for ln in range(min_len, 0, -1):
                if self.buffer_string[-ln:] == ss[:ln]:
                    return True
        return False

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        if len(input_ids) > 1:
            # Multi-token sequences: reasoning has ended only when the ids end
            # with one of the full end-marker token sequences.
            for think_end_tokens in self.think_end_tokens:
                think_end_len = len(think_end_tokens)
                if (
                    len(input_ids) >= think_end_len
                    and input_ids[-think_end_len:] == think_end_tokens
                ):
                    return True
            return False

        # Zero/one-token sequences: a bare <|im_end|> counts, as does the
        # sticky no-reasoning flag set during streaming.
        return self.no_reasoning_content or (
            self.end_token_id is not None and self.end_token_id in input_ids
        )

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        if self.end_token_id is None:
            return False
        return self.end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return token ids after the FIRST <|im_end|>; [] if none precedes
        the final position (the last id alone cannot start content)."""
        if self.end_token_id is None or self.end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.end_token_id) + 1 :]

    def extract_reasoning(
        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> tuple[str | None, str | None]:
        """Non-streaming split of *model_output* into (reasoning, content).

        ``force_reasoning`` treats marker-less output as reasoning and takes
        priority over ``skip_reasoning``, which returns everything as content.
        """
        chat_template_kwargs = getattr(request, "chat_template_kwargs", None) or {}
        tool_choice = getattr(request, "tool_choice", None)

        is_reasoning = False
        if chat_template_kwargs.get("force_reasoning", False):
            is_reasoning = True
        elif chat_template_kwargs.get("skip_reasoning", False):
            return None, model_output

        if model_output.startswith(self.think_start_token):
            is_reasoning = True
            _, _, model_output = model_output.partition(self.think_start_token)

        if self.think_end_string_base not in model_output:
            # No end marker: either everything is reasoning, or everything is
            # content (normalized per tool_choice).
            if is_reasoning:
                return model_output, None

            if tool_choice in ("auto", None):
                if model_output.startswith("\n"):
                    model_output = model_output[1:]
                return None, model_output

            # Forced/required tool calls: drop the role marker from content.
            return None, model_output.replace(self.function_call_role, "")

        reasoning_content, _, content = model_output.partition(
            self.think_end_string_base
        )
        # Empty strings are normalized to None.
        return reasoning_content or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # An output whose first token is "\n" never contains reasoning.
        if current_token_ids and (
            current_token_ids[0] == self.non_reasoning_mode_start_token
        ):
            self.no_reasoning_content = True

        if len(current_text) == 0:
            return None

        if self.no_reasoning_content:
            return DeltaMessage(content=delta_text)

        self.buffer_string += delta_text

        if self._check_is_special_string():
            # Tool-call role right at the start: everything is content.
            if current_text.startswith(self.function_call_role):
                self.no_reasoning_content = True
                delta_text = self.buffer_string
                self.buffer_string = ""
                return DeltaMessage(content=delta_text)

            buffered_content, delta_text = self._remove_special_string()
            self.buffer_string = delta_text

            if buffered_content:
                if self._check_is_part_of_special_string():
                    # Tail may be a half-arrived marker: hold it back.
                    return DeltaMessage(reasoning=buffered_content)
                self.buffer_string = ""
                return DeltaMessage(reasoning=buffered_content, content=delta_text)

        if self._check_is_part_of_special_string():
            # Buffer ends with a possible marker prefix; wait for more tokens.
            return None

        delta_text = self.buffer_string
        self.buffer_string = ""

        # Past the end marker everything is content; before it, reasoning.
        if self.think_end_string_base in current_text:
            return DeltaMessage(content=delta_text)
        return DeltaMessage(reasoning=delta_text)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: registered in vllm/tool_parsers/__init__.py under the name
# "hyperclovax" as ("hyperclovax_tool_parser", "HyperCLOVAXToolParser").
import json
from collections.abc import Sequence

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser

logger = init_logger(__name__)


class HyperCLOVAXToolParser(ToolParser):
    """Tool parser for HyperCLOVAX ``-> tool/function_call`` outputs.

    The model emits a JSON array of ``{"name": ..., "arguments": ...}``
    objects after the role marker, terminated by ``<|im_end|>`` (or, when
    generation stops early, by the array's closing ``]``).
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)

        self.tool_call_start_token: str = " -> tool/function_call\n"
        self.tool_call_end_token: str = "<|im_end|>"
        # First alternative: payload closed by <|im_end|>; second: no end
        # token — group(2) greedily stops before the final ']' which
        # extract_tool_calls() re-appends before json.loads().
        self.tool_call_regex = re.compile(
            r"-> tool/function_call\n(.*?)<\|im_end\|>|"
            r"-> tool/function_call\n(.*)]",
            re.DOTALL,
        )

        # Streaming state:
        # tool_call_offset — chars of the payload already consumed.
        # _buffer / _sent_content_len — accumulated text and how much of it
        # has been streamed out as plain content.
        # NOTE(review): _pending_messages is popped in
        # extract_tool_calls_streaming() but never appended to anywhere in
        # this module — looks like dead state; confirm before removing.
        self.tool_call_offset = 0
        self._buffer = ""
        self._sent_content_len = 0
        self._pending_messages: list[DeltaMessage] = []

    @staticmethod
    def _partial_tag_overlap(text: str, tag: str) -> int:
        """Length of the longest PROPER prefix of *tag* that *text* ends with.

        Used to withhold a possibly half-arrived start marker from the
        content stream.  Returns 0 when there is no overlap.
        """
        max_check = min(len(tag) - 1, len(text))
        for k in range(max_check, 0, -1):
            if text.endswith(tag[:k]):
                return k
        return 0

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Non-streaming extraction; malformed payloads degrade to content."""
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            tool_call_match = self.tool_call_regex.search(model_output)
            if tool_call_match is None:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            if tool_call_match.group(1) is not None:
                raw_function_calls = json.loads(tool_call_match.group(1))
            else:
                # Second alternative excluded the final ']'; restore it.
                raw_function_calls = json.loads(tool_call_match.group(2) + "]")

            # Accept a single object as a one-element list.
            if isinstance(raw_function_calls, dict):
                raw_function_calls = [raw_function_calls]
            if not isinstance(raw_function_calls, list):
                raise ValueError("tool calls payload must be object or list")

            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        arguments=json.dumps(
                            function_call["arguments"], ensure_ascii=False
                        ),
                    ),
                )
                for function_call in raw_function_calls
            ]

            # When the tool call follows a full turn boundary, everything
            # before the boundary is user-visible content.
            prefix = "<|im_end|>\n<|im_start|>assistant -> tool/function_call\n"
            if prefix in model_output:
                content = model_output.split(prefix)[0]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )

        except Exception:
            logger.exception("Error extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Streaming extraction: emits content before the marker, then one
        DeltaToolCall per complete JSON object found after it."""
        if self._pending_messages:
            return self._pending_messages.pop(0)

        self._buffer += delta_text

        if self.tool_call_start_token not in self._buffer:
            # Stream content, but withhold any tail that could be the start
            # of the marker arriving across deltas.
            overlap = self._partial_tag_overlap(
                self._buffer, self.tool_call_start_token
            )
            safe_len = len(self._buffer) - overlap
            if safe_len > self._sent_content_len:
                content = self._buffer[self._sent_content_len : safe_len]
                self._sent_content_len = safe_len
                return DeltaMessage(content=content)
            return None

        # Flush any content that precedes the marker and was not sent yet.
        marker_idx = self._buffer.find(self.tool_call_start_token)
        if self._sent_content_len < marker_idx:
            content = self._buffer[self._sent_content_len : marker_idx]
            self._sent_content_len = marker_idx
            if content:
                return DeltaMessage(content=content)

        if marker_idx + len(self.tool_call_start_token) > len(self._buffer):
            return None

        # Payload after the marker, minus the part already parsed.
        function_call_text = self._buffer[
            marker_idx + len(self.tool_call_start_token) :
        ]
        function_call_text = function_call_text[self.tool_call_offset :]

        opening_brace_index = None
        for idx, ch in enumerate(function_call_text):
            if ch == "{":
                opening_brace_index = idx
                break

        if opening_brace_index is None:
            return None

        closing_brace_indices = [
            idx for idx, ch in enumerate(function_call_text) if ch == "}"
        ]
        if not closing_brace_indices:
            return None

        # Try each '}' as a candidate object end; nested braces simply fail
        # json.loads() until the matching one is reached.
        for closing_brace_index in closing_brace_indices:
            candidate = function_call_text[
                opening_brace_index : closing_brace_index + 1
            ]
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue

            if not isinstance(parsed, dict):
                continue

            self.current_tool_id += 1
            # Advance past the consumed object for the next call.
            self.tool_call_offset += closing_brace_index + 1
            self.prev_tool_call_arr.append(parsed)
            self.streamed_args_for_tool.append(candidate)

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=parsed.get("name", ""),
                            arguments=json.dumps(
                                parsed.get("arguments", ""), ensure_ascii=False
                            ),
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        return None