diff --git a/tests/tool_parsers/test_lfm2_tool_parser.py b/tests/tool_parsers/test_lfm2_tool_parser.py new file mode 100644 index 000000000000..9cb5b195f1a7 --- /dev/null +++ b/tests/tool_parsers/test_lfm2_tool_parser.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest +from transformers import AutoTokenizer + +from tests.tool_parsers.utils import ( + run_tool_extraction, + run_tool_extraction_streaming, +) +from vllm.entrypoints.openai.engine.protocol import FunctionCall +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager + +TOOL_CALL_START = "<|tool_call_start|>" +TOOL_CALL_END = "<|tool_call_end|>" + +SIMPLE_FUNCTION_OUTPUT = "get_candidate_status(candidate_id='12345')" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_candidate_status", + arguments='{"candidate_id": "12345"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ( + "register_user(name='John Doe', " + "age=37, " + "address={'city': 'San Francisco', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])" +) +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "John Doe", ' + '"age": 37, ' + '"address": {"city": "San Francisco", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments="{}", +) +EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')" +) +ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) +DOTTED_NAME_FUNCTION_OUTPUT = ( + "grocery.orderIngredients(" + "ingredientList=[{'name': 'Lasagna noodles', 'amount': 250, 'unit': 'g'}], " + "deliveryAddress='845 Willow Lane, Springfield, IL 62704')" +) +DOTTED_NAME_FUNCTION_CALL = FunctionCall( + name="grocery.orderIngredients", + arguments=( + '{"ingredientList": [' + '{"name": "Lasagna noodles", "amount": 250, "unit": "g"}], ' + '"deliveryAddress": "845 Willow Lane, Springfield, IL 62704"}' + ), +) + + +@pytest.fixture(scope="module") +def lfm2_tokenizer() -> TokenizerLike: + return AutoTokenizer.from_pretrained("LiquidAI/LFM2.5-1.2B-Instruct") + + +def _wrap(tool_text: str, content_after: str = "") -> str: + """Wrap pythonic tool call in LFM2.5 sentinel tokens.""" + result = f"{TOOL_CALL_START}[{tool_text}]{TOOL_CALL_END}" + if content_after: + result += f"\n{content_after}" + return result + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool, lfm2_tokenizer: TokenizerLike): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + model_output = "How can I help you today?" + + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + + assert content == model_output + assert len(tool_calls) == 0 + + +TEST_CASES = [ + pytest.param( + True, + _wrap(SIMPLE_FUNCTION_OUTPUT), + [SIMPLE_FUNCTION_CALL], + None, + id="simple_streaming", + ), + pytest.param( + False, + _wrap(SIMPLE_FUNCTION_OUTPUT), + [SIMPLE_FUNCTION_CALL], + None, + id="simple_nonstreaming", + ), + pytest.param( + True, + _wrap(MORE_TYPES_FUNCTION_OUTPUT), + [MORE_TYPES_FUNCTION_CALL], + None, + id="more_types_streaming", + ), + pytest.param( + False, + _wrap(MORE_TYPES_FUNCTION_OUTPUT), + [MORE_TYPES_FUNCTION_CALL], + None, + id="more_types_nonstreaming", + ), + pytest.param( + True, + _wrap(PARAMETERLESS_FUNCTION_OUTPUT), + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_streaming", + ), + pytest.param( + False, + _wrap(PARAMETERLESS_FUNCTION_OUTPUT), + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_nonstreaming", + ), + pytest.param( + True, + _wrap(EMPTY_DICT_FUNCTION_OUTPUT), + [EMPTY_DICT_FUNCTION_CALL], + None, + id="empty_dict_streaming", + ), + pytest.param( + False, + _wrap(EMPTY_DICT_FUNCTION_OUTPUT), + [EMPTY_DICT_FUNCTION_CALL], + None, + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + _wrap(EMPTY_LIST_FUNCTION_OUTPUT), + [EMPTY_LIST_FUNCTION_CALL], + None, + id="empty_list_streaming", + ), + pytest.param( + False, + _wrap(EMPTY_LIST_FUNCTION_OUTPUT), + [EMPTY_LIST_FUNCTION_CALL], + None, + id="empty_list_nonstreaming", + ), + pytest.param( + True, + _wrap(ESCAPED_STRING_FUNCTION_OUTPUT), + [ESCAPED_STRING_FUNCTION_CALL], + None, + id="escaped_string_streaming", + ), + pytest.param( + False, + _wrap(ESCAPED_STRING_FUNCTION_OUTPUT), + [ESCAPED_STRING_FUNCTION_CALL], + None, + id="escaped_string_nonstreaming", + ), + pytest.param( + True, + _wrap(f"{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}"), + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + None, + id="parallel_calls_streaming", + ), + pytest.param( + False, + _wrap(f"{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}"), + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + None, + id="parallel_calls_nonstreaming", + ), + # LFM2.5 specific: content AFTER tool call + pytest.param( + False, + _wrap( + SIMPLE_FUNCTION_OUTPUT, + content_after="Checking the current status of candidate ID 12345.", + ), + [SIMPLE_FUNCTION_CALL], + "Checking the current status of candidate ID 12345.", + id="content_after_tool_call_nonstreaming", + ), + # Dotted / class-method function names: grocery.orderIngredients(...) + pytest.param( + True, + _wrap(DOTTED_NAME_FUNCTION_OUTPUT), + [DOTTED_NAME_FUNCTION_CALL], + None, + id="dotted_name_streaming", + ), + pytest.param( + False, + _wrap(DOTTED_NAME_FUNCTION_OUTPUT), + [DOTTED_NAME_FUNCTION_CALL], + None, + id="dotted_name_nonstreaming", + ), +] + + +@pytest.mark.parametrize( + "streaming, model_output, expected_tool_calls, expected_content", + TEST_CASES, +) +def test_tool_call( + streaming: bool, + model_output: str, + expected_tool_calls: list[FunctionCall], + expected_content: str | None, + lfm2_tokenizer: TokenizerLike, +): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + + if expected_content and not streaming: + assert content == expected_content + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(lfm2_tokenizer: TokenizerLike): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + model_output_deltas = [ + f"{TOOL_CALL_START}[get_candidate_status(candidate_id='12345'), " + f"{PARAMETERLESS_FUNCTION_OUTPUT}, " + f"{EMPTY_LIST_FUNCTION_OUTPUT}]{TOOL_CALL_END}", + ] + + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) + + assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +def test_streaming_full_block_and_trailing_in_single_delta( + lfm2_tokenizer: TokenizerLike, +): + """The entire <|tool_call_start|>[...]<|tool_call_end|> block plus + trailing assistant text arrive in one delta. Trailing content must + still be emitted — not silently dropped.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + full_text = f"{TOOL_CALL_START}[{SIMPLE_FUNCTION_OUTPUT}]{TOOL_CALL_END}\nDone." + + reconstructor = run_tool_extraction_streaming(tool_parser, [full_text]) + + assert len(reconstructor.tool_calls) == 1 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert "Done." in reconstructor.other_content + + +def test_streaming_leading_content_and_full_block_in_single_delta( + lfm2_tokenizer: TokenizerLike, +): + """Leading assistant text plus the entire tool block arrive in one + delta. Leading content must be emitted — not silently dropped.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + full_text = ( + f"Let me check. {TOOL_CALL_START}[{SIMPLE_FUNCTION_OUTPUT}]{TOOL_CALL_END}" + ) + + reconstructor = run_tool_extraction_streaming(tool_parser, [full_text]) + + assert len(reconstructor.tool_calls) == 1 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert "Let me check." in reconstructor.other_content + + +def test_streaming_leading_block_and_trailing_in_single_delta( + lfm2_tokenizer: TokenizerLike, +): + """Leading text + complete tool block + trailing text in one delta. + Both leading and trailing content must be preserved.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + full_text = ( + "Let me check. " + f"{TOOL_CALL_START}[{SIMPLE_FUNCTION_OUTPUT}]{TOOL_CALL_END}\nDone." + ) + + reconstructor = run_tool_extraction_streaming(tool_parser, [full_text]) + + assert len(reconstructor.tool_calls) == 1 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert "Let me check." in reconstructor.other_content + assert "Done." in reconstructor.other_content + + +def test_echoed_tool_call_body_not_leaked_to_content( + lfm2_tokenizer: TokenizerLike, +): + """LFM2 sometimes emits the tool call body again after the first + <|tool_call_end|>, capped with a second <|tool_call_end|>. The + echoed body must not surface as assistant content — neither in + streaming nor non-streaming paths.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + body = ( + "[grocery.orderIngredients(" + "ingredientList=[{'name': 'apple', 'quantity': '2'}], " + "deliveryAddress='123 Main St')]" + ) + model_output = f"{TOOL_CALL_START}{body}{TOOL_CALL_END}{body}{TOOL_CALL_END}" + + # Non-streaming + content_ns, tool_calls_ns = run_tool_extraction( + tool_parser, model_output, streaming=False + ) + assert len(tool_calls_ns) == 1 + assert tool_calls_ns[0].function.name == "grocery.orderIngredients" + assert content_ns in (None, "") + + # Streaming: re-fetch a fresh parser since state was mutated above. + tool_parser2: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + content_s, tool_calls_s = run_tool_extraction( + tool_parser2, model_output, streaming=True + ) + assert len(tool_calls_s) == 1 + assert tool_calls_s[0].function.name == "grocery.orderIngredients" + # Echoed body must not leak as content. + assert content_s in (None, "") + assert "grocery.orderIngredients" not in (content_s or "") + assert TOOL_CALL_END not in (content_s or "") + + +def test_streaming_char_by_char_multi_dict_list(lfm2_tokenizer: TokenizerLike): + """Stream a tool call containing a list of multiple dicts one + character at a time. Every prefix lands in some partial-parse state + (mid-key, mid-value, open quote inside dict, empty dict, etc.). The + parser must not raise — incomplete prefixes should silently wait for + more text instead of logging exceptions.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + full_text = ( + f"{TOOL_CALL_START}[grocery.orderIngredients(" + "ingredientList=[" + '{"name": "apple", "quantity": "2"}, ' + '{"name": "bread", "quantity": "1"}' + f"])]{TOOL_CALL_END}" + ) + deltas = [c for c in full_text] + + reconstructor = run_tool_extraction_streaming( + tool_parser, deltas, assert_one_tool_per_delta=False + ) + + assert len(reconstructor.tool_calls) == 1 + assert reconstructor.tool_calls[0].function.name == "grocery.orderIngredients" + import json + + args = json.loads(reconstructor.tool_calls[0].function.arguments) + assert args == { + "ingredientList": [ + {"name": "apple", "quantity": "2"}, + {"name": "bread", "quantity": "1"}, + ] + } + + +def test_streaming_dotted_name_in_single_delta(lfm2_tokenizer: TokenizerLike): + """A pythonic call with a dotted/attribute function name (e.g. + ``domain.method(arg=...)``) must be parsed correctly in streaming mode + just as in non-streaming mode.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + full_text = f"{TOOL_CALL_START}[{DOTTED_NAME_FUNCTION_OUTPUT}]{TOOL_CALL_END}" + + reconstructor = run_tool_extraction_streaming(tool_parser, [full_text]) + + assert len(reconstructor.tool_calls) == 1 + assert reconstructor.tool_calls[0].function == DOTTED_NAME_FUNCTION_CALL + + +def test_adjust_request_disables_skip_special_tokens( + lfm2_tokenizer: TokenizerLike, +): + """When tools are present, the parser must force + ``skip_special_tokens=False`` so the engine does not strip the + <|tool_call_start|>/<|tool_call_end|> sentinels before they reach the + parser.""" + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + + request_with_tools = ChatCompletionRequest( + messages=[], + model="test-model", + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + ], + ) + assert request_with_tools.skip_special_tokens is True + adjusted = tool_parser.adjust_request(request_with_tools) + assert adjusted.skip_special_tokens is False + + # No tools → no override; default behaviour preserved. + request_no_tools = ChatCompletionRequest(messages=[], model="test-model") + assert request_no_tools.skip_special_tokens is True + adjusted_no_tools = tool_parser.adjust_request(request_no_tools) + assert adjusted_no_tools.skip_special_tokens is True + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool, lfm2_tokenizer: TokenizerLike): + """Test regex timeout is handled gracefully.""" + tool_parser: ToolParser = ToolParserManager.get_tool_parser("lfm2")(lfm2_tokenizer) + + fake_input = f"{TOOL_CALL_START}[A(A=" + "\t)A(A=,\t" * 2 + fake_input += f"]{TOOL_CALL_END}" + + mock_regex = MagicMock() + mock_regex.match.side_effect = TimeoutError("Regex timeout") + + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_input, streaming=streaming + ) + + assert content == fake_input + assert len(tool_calls) == 0 + mock_regex.match.assert_called_once() diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py index f64209e535b7..7c5f45d2022e 100644 --- a/vllm/tool_parsers/__init__.py +++ b/vllm/tool_parsers/__init__.py @@ -94,6 +94,10 @@ "jamba_tool_parser", "JambaToolParser", ), + "lfm2": ( + "lfm2_tool_parser", + "Lfm2ToolParser", + ), "kimi_k2": ( "kimi_k2_tool_parser", "KimiK2ToolParser", diff --git a/vllm/tool_parsers/lfm2_tool_parser.py b/vllm/tool_parsers/lfm2_tool_parser.py new file mode 100644 index 000000000000..ee92d060fbea --- /dev/null +++ b/vllm/tool_parsers/lfm2_tool_parser.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +from collections.abc import Sequence + +import regex as re + +import vllm.envs as envs +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaMessage, + ExtractedToolCallInformation, +) +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + Tool, + ToolParser, +) +from vllm.tool_parsers.utils import ( + UnexpectedAstError, + compute_tool_delta, + handle_single_tool, + make_valid_python, +) + +logger = init_logger(__name__) + +TOOL_CALL_START = "<|tool_call_start|>" +TOOL_CALL_END = "<|tool_call_end|>" + + +class Lfm2ToolParser(ToolParser): + """ + Tool call parser for LiquidAI LFM2/LFM2.5 models that produce pythonic + tool calls wrapped in <|tool_call_start|> and <|tool_call_end|> tokens. + + Example model output: + <|tool_call_start|>[get_weather(location="Paris")]<|tool_call_end|> + The weather in Paris is sunny. + + Used when --enable-auto-tool-choice --tool-call-parser lfm2 are all set. + """ + + TOOL_CALL_REGEX = re.compile(r"\[.*\]$", re.DOTALL) + + def __init__( + self, + tokenizer: TokenizerLike, + tools: list[Tool] | None = None, + ): + super().__init__(tokenizer, tools) + + self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START) + self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END) + + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: + raise RuntimeError( + "LFM2 tool parser could not locate " + "<|tool_call_start|>/<|tool_call_end|> tokens in the " + "tokenizer!" + ) + + # Trailing content already emitted to the client. Used by the + # streaming path to suppress LFM2's frequent echo of the tool + # call body after the first <|tool_call_end|> while still + # allowing legitimate post-call prose through. + self._trailing_emitted: str = "" + + def adjust_request( + self, request: ChatCompletionRequest | ResponsesRequest + ) -> ChatCompletionRequest | ResponsesRequest: + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # The <|tool_call_start|>/<|tool_call_end|> sentinels are + # registered as special tokens in the LFM2/LFM2.5 tokenizer. + # With the default ``skip_special_tokens=True`` they are + # stripped from the decoded text before reaching this parser, + # so the tool block becomes invisible. Force the engine to + # preserve them when tool calling is enabled. + request.skip_special_tokens = False + return request + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + @staticmethod + def _strip_echo(raw_after: str) -> str: + """Drop any orphan <|tool_call_end|> (and the preceding text) from + trailing content. LFM2 occasionally echoes the call body after the + first end token and caps it with a second end token; everything + through the last such orphan is model garbage, not user content.""" + last_orphan = raw_after.rfind(TOOL_CALL_END) + if last_orphan != -1: + return raw_after[last_orphan + len(TOOL_CALL_END) :] + return raw_after + + @classmethod + def _extract_tool_call_text( + cls, model_output: str + ) -> tuple[str | None, str | None]: + """Extract the pythonic call text and surrounding content. + + Returns (tool_text, content) where tool_text is the text between + the sentinel tokens and content is everything outside them. + """ + start_idx = model_output.find(TOOL_CALL_START) + if start_idx == -1: + return None, model_output + + end_idx = model_output.find(TOOL_CALL_END, start_idx) + if end_idx == -1: + # Incomplete — treat entire text after start as tool call + tool_text = model_output[start_idx + len(TOOL_CALL_START) :] + content_before = model_output[:start_idx].strip() + content = content_before or None + return tool_text, content + + tool_text = model_output[start_idx + len(TOOL_CALL_START) : end_idx] + content_before = model_output[:start_idx].strip() + content_after = cls._strip_echo( + model_output[end_idx + len(TOOL_CALL_END) :] + ).strip() + + content_parts = [] + if content_before: + content_parts.append(content_before) + if content_after: + content_parts.append(content_after) + content = "\n".join(content_parts) if content_parts else None + + return tool_text, content + + def extract_tool_calls( + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: + tool_text, content = self._extract_tool_call_text(model_output) + + if tool_text is None: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + tool_text = tool_text.strip() + + is_tool_call_pattern = False + try: + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + tool_text, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS, + ) + is not None + ) + except TimeoutError: + logger.warning("Regex timeout occurred when matching tool call pattern.") + + if not is_tool_call_pattern: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + module = ast.parse(tool_text) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts + ): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=content, + ) + else: + raise UnexpectedAstError("Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + # If the tool call start token hasn't appeared yet, stream as content. + if TOOL_CALL_START not in current_text: + return DeltaMessage(content=delta_text) + + # Compute leading content (before <|tool_call_start|>) that arrived + # in this delta and hasn't been streamed yet. Without this, when the + # prefix and the start token land in the same delta the prefix is + # silently dropped — token-by-token streaming masked the bug because + # the prefix tokens always arrived in earlier deltas. + leading_content = "" + if TOOL_CALL_START not in previous_text: + start_idx = current_text.find(TOOL_CALL_START) + # previous_text contained no start token, so it has already been + # streamed via the no-start-token branch above. + leading_content = current_text[len(previous_text) : start_idx] + + has_end_in_current = TOOL_CALL_END in current_text + has_end_in_previous = TOOL_CALL_END in previous_text + + # Compute trailing content (after <|tool_call_end|>) not yet + # streamed. LFM2 frequently echoes the tool call body again + # after the first end token, capped with a second end token. + # Suppress that echo: + # - If a second <|tool_call_end|> has appeared, treat + # everything through the last one as garbage. + # - If the trailing starts with `[` or `<` (potential echo + # body or another sentinel) and no second end token has + # arrived yet, buffer it instead of emitting. + trailing_content = "" + if has_end_in_current: + end_idx = current_text.find(TOOL_CALL_END) + len(TOOL_CALL_END) + full_trailing = current_text[end_idx:] + stripped_trailing = self._strip_echo(full_trailing) + if stripped_trailing == full_trailing: + # No second end token yet — possibly mid-echo. + lstripped = full_trailing.lstrip() + if lstripped.startswith("[") or lstripped.startswith("<"): + # Suspect echo; hold off until resolved. + final_trailing = self._trailing_emitted + else: + final_trailing = full_trailing + else: + final_trailing = stripped_trailing + if final_trailing.startswith(self._trailing_emitted): + trailing_content = final_trailing[len(self._trailing_emitted) :] + self._trailing_emitted = final_trailing + + # If tools were already parsed in a prior delta, just stream any + # newly arrived trailing content. + if has_end_in_current and self.prev_tool_call_arr and has_end_in_previous: + if trailing_content: + return DeltaMessage(content=trailing_content) + return DeltaMessage(content="") + + # Extract the pythonic text between start and end tokens. + tool_text = current_text.split(TOOL_CALL_START, 1)[1] + # Strip the end token if present (entire call arrived at once). + if TOOL_CALL_END in tool_text: + tool_text = tool_text.split(TOOL_CALL_END, 1)[0] + + def _content_only_or_none() -> DeltaMessage | None: + """Return a content-only delta if any content arrived in this + chunk, otherwise None. Used on incremental-parse failure paths + so leading/trailing content is never silently dropped. + """ + combined = leading_content + trailing_content + return DeltaMessage(content=combined) if combined else None + + try: + valid_and_added_text = make_valid_python(tool_text) + if valid_and_added_text is None: + return _content_only_or_none() + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts + ): + raise UnexpectedAstError("Tool output must be a list of function calls") + tool_calls = [ + handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = ( + index < len(tool_calls) - 1 or ")]" not in added_text + ) + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = added_text[:-2] if not new_call_complete else "" + if not new_call_complete and added_text[-2] == ")": + withheld_suffix = withheld_suffix + "}" + withheld_suffix = withheld_suffix.replace("'", '"') + delta = compute_tool_delta( + self.streamed_args_for_tool[index], + new_call, + index, + withheld_suffix, + ) + + if delta is not None: + tool_deltas.append(delta) + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments + + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + combined_content = leading_content + trailing_content + + if tool_deltas or combined_content: + return DeltaMessage( + content=combined_content if combined_content else None, + tool_calls=tool_deltas, + ) + elif not added_text and self.current_tool_id > 0: + return DeltaMessage(content="") + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction error" + ) + return _content_only_or_none() diff --git a/vllm/tool_parsers/utils.py b/vllm/tool_parsers/utils.py index 439441690d04..464ed40f948b 100644 --- a/vllm/tool_parsers/utils.py +++ b/vllm/tool_parsers/utils.py @@ -308,20 +308,43 @@ def get_parameter_value(val: ast.expr) -> Any: raise UnexpectedAstError("Tool call arguments must be literals") +def _ast_callable_dotted_name(node: ast.expr) -> str: + """Return the dotted name for a call target, walking ``ast.Attribute`` + chains so ``a.b.c(...)`` becomes ``"a.b.c"``. + + Raises: + UnexpectedAstError: If the chain does not bottom out in an + ``ast.Name`` (e.g. subscript or call expression as receiver). + """ + parts: list[str] = [] + current: ast.expr = node + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if not isinstance(current, ast.Name): + raise UnexpectedAstError("Invalid tool call name") + parts.append(current.id) + return ".".join(reversed(parts)) + + def handle_single_tool(call: ast.Call) -> ToolCall: """Convert a single AST function call node into a ToolCall object. + Accepts both bare names (``foo(...)``) and dotted attribute chains + (``a.b.c(...)``); the resulting tool call ``name`` field preserves the + dotted form. + Raises: - UnexpectedAstError: If the call node does not have a simple - function name (e.g. it's an attribute access or subscript). + UnexpectedAstError: If the call target is neither a simple name + nor a chain of attribute accesses bottoming out in a name. """ - if not isinstance(call.func, ast.Name): + if not isinstance(call.func, (ast.Name, ast.Attribute)): logger.warning( "Tool call has non-simple function name: %s", ast.dump(call.func), ) raise UnexpectedAstError("Invalid tool call name") - function_name = call.func.id + function_name = _ast_callable_dotted_name(call.func) arguments = {} for keyword in call.keywords: arguments[keyword.arg] = get_parameter_value(keyword.value) @@ -403,7 +426,28 @@ def make_valid_python(text: str) -> tuple[str, str] | None: for char in reversed(bracket_stack): added_text += _CLOSING[char] - return text + added_text, added_text + candidate = text + added_text + + # Streaming partial text can land in shapes the bracket-counting + # heuristics above don't catch. Two failure modes: + # 1. Mid-key inside a dict (`..., "k`) closes to `..., "k"}` — a + # syntactically invalid mixed dict/set. + # 2. A bare string inside a dict (`{"k`) closes to `{"k"}` — valid + # Python but a *set* literal, which downstream tool-call AST + # handling rejects. + # Validate the candidate parses, has a body, and contains no Set + # nodes (pythonic tool calls always use dicts for `{...}`). + try: + module = ast.parse(candidate) + except SyntaxError: + return None + if not module.body: + return None + for node in ast.walk(module): + if isinstance(node, ast.Set): + return None + + return candidate, added_text def compute_tool_delta(