From 238b91e51015cf64736508f2986b055b174aa0f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 19 May 2026 07:14:40 -0700 Subject: [PATCH 01/17] studio: tool calling for Llama-3, Mistral, Gemma 4 on safetensors + MLX (#5615) Adds tool calling for Llama-3, Mistral (pre-v11 + v11+ + [ARGS]), and Gemma 4 to the safetensors / transformers and MLX backends. Parser patched against llama.cpp / vLLM / SGLang per-family parsers and normalises to OpenAI shape. 96 targeted unit tests + cross-OS staging CI (ubuntu / macos-14 / windows) green on the multi-format probe. --- .../core/inference/tool_call_parser.py | 820 +++++++++++++++--- studio/backend/routes/inference.py | 45 +- .../test_safetensors_capability_advertise.py | 83 +- .../tests/test_safetensors_tool_loop.py | 351 ++++++++ 4 files changed, 1167 insertions(+), 132 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index a0ab8a2a53..7b0c9b1f62 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -2,32 +2,72 @@ # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ -Backend-neutral tool-call XML parser shared by GGUF and safetensors. -Tolerates missing closing tags in either ``{json}`` -or ``v...`` shape. +Backend-neutral tool-call parser shared by GGUF, safetensors, and MLX. 
+ +Covers the emission formats so the safetensors + MLX agentic loop sees +the same call shape llama-server normalises for GGUF: + + - ``{json}`` (Qwen / Hermes) + - ``v`` (Qwen3.5 xml) + - ``<|python_tag|>NAME.call(k="v", ...)`` (Llama-3 built-in tools) + - ``<|python_tag|>{"name":..., "parameters":...}`` (Llama-3 custom) + - ``{"name":..., "parameters":...}`` (Llama-3.2 bare JSON) + - ``[TOOL_CALLS] [{...}, ...]`` (Mistral v0.3 / Nemo / Small) + - ``[TOOL_CALLS]name{json}`` (Mistral v11+ / Magistral) + - ``[TOOL_CALLS]name[ARGS]{json}`` (Ministral / Mistral Large 3) + - ``<|tool_call>call:NAME{k:<|"|>v<|"|>}`` (Gemma 4) + +Closing tags / brackets are tolerated when missing because models +frequently truncate them mid-stream. """ import json import re +from typing import Any + + +# ── Streaming-buffer signal markers ───────────────────────────────── + + +# Prefixes the safetensors / MLX streaming buffer watches for to gate +# in-progress text. When ANY of these appear in the cumulative text, +# the state machine switches from STREAMING to DRAINING so we don't +# leak partial markup to the user before we can parse it. +TOOL_XML_SIGNALS = ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", +) -# _TOOL_CLOSED_PATS: closed pairs only. _TOOL_ALL_PATS: also trailing -# unclosed runs so truncated tails don't leak markup. +# ── Strip patterns for ``strip_tool_markup`` ──────────────────────── + + +# _TOOL_CLOSED_PATS: closed pairs only (used during streaming so +# in-progress XML stays buffered). _TOOL_ALL_PATS: also matches trailing +# unclosed runs so truncated tails don't leak markup at end-of-turn. _TOOL_CLOSED_PATS = [ re.compile(r".*?", re.DOTALL), re.compile(r".*?", re.DOTALL), + re.compile(r"<\|tool_call>.*?", re.DOTALL), + re.compile(r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?", re.DOTALL), + # Mistral v11+ ``[TOOL_CALLS]name{json}`` (may chain), close at ``}``. 
+ re.compile(r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), re.compile(r".*$", re.DOTALL), + re.compile(r"<\|tool_call>.*$", re.DOTALL), + re.compile(r"\[TOOL_CALLS\].*$", re.DOTALL), + re.compile(r"<\|python_tag\|>.*$", re.DOTALL), ] -# Prefixes the streaming buffer watches for to gate in-progress text. -TOOL_XML_SIGNALS = ("", "{json} _TC_JSON_START_RE = re.compile(r"\s*\{") -_TC_FUNC_START_RE = re.compile(r"\s*") +# Qwen3.5 / Hermes XML form v +_TC_FUNC_START_RE = re.compile(r"\s*") _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") -_TC_PARAM_START_RE = re.compile(r"\s*") +_TC_PARAM_START_RE = re.compile(r"\s*") _TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") +# Llama-3 <|python_tag|>NAME.call(...) +_LLAMA3_PYTHON_TAG = "<|python_tag|>" +_LLAMA3_PY_CALL_RE = re.compile( + r"<\|python_tag\|>\s*([\w\.\-]+)\s*\.\s*call\s*\(", +) +_LLAMA3_KV_RE = re.compile( + r"""(\w+)\s*=\s*(?:"((?:\\.|[^"\\])*)"|(-?\d+(?:\.\d+)?)|(true|false|null))""", + re.VERBOSE, +) + +# Mistral [TOOL_CALLS] trigger. v11+ chains multiple triggers, each +# followed by a bare name then either ``{json}`` (Magistral) or +# ``[ARGS]{json}`` (Ministral / Mistral Large 3). +_MISTRAL_TRIGGER = "[TOOL_CALLS]" +_MISTRAL_ARGS_MARKER = "[ARGS]" +_MISTRAL_V11_NAME_RE = re.compile(r"\s*([\w\.\-]+)\s*") + +# Gemma 4 <|tool_call>call:NAME{...}. ``<|"|>`` wraps strings. +_GEMMA_TC_RE = re.compile(r"<\|tool_call>\s*call\s*:\s*([\w\.\-]+)\s*\{") +_GEMMA_STR_BEGIN = '<|"|>' +_GEMMA_STR_END = '<|"|>' +_GEMMA_TC_END = "" + + +# ── Public API ────────────────────────────────────────────────────── + def strip_tool_markup(text: str, *, final: bool = False) -> str: - """Strip tool-call XML from streamed text. + """Strip tool-call markup from streamed text. - ``final=False`` only removes closed pairs (used during streaming so - in-progress XML stays buffered). 
``final=True`` also removes a - trailing unclosed run and trims the result. + ``final=False`` only removes closed pairs so in-progress markup + stays buffered. ``final=True`` also removes trailing unclosed runs + and trims the result. """ pats = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS for pat in pats: @@ -80,125 +150,651 @@ def strip_tool_markup(text: str, *, final: bool = False) -> str: return text.strip() if final else text +def has_tool_signal(text: str) -> bool: + """True if ``text`` contains any known tool-call signal.""" + return any(s in text for s in TOOL_XML_SIGNALS) + + def parse_tool_calls_from_text(content: str, *, id_offset: int = 0) -> list[dict]: """Parse OpenAI-format ``tool_calls`` from model text. - Returns a list of ``{"id", "type", "function": {"name", "arguments"}}`` - dicts. ``arguments`` is always a JSON string so callers can hand it - straight back into an OpenAI-style response. + Returns ``[{"id", "type", "function": {"name", "arguments"}}]`` + where ``arguments`` is always a JSON string. Tries each known + emission format in turn; returns as soon as one yields calls so + we never double-count. + """ + # Qwen / Hermes {json} + calls = _parse_tool_call_json(content, id_offset = id_offset) + if calls: + return calls - Handles two shapes: + # Qwen3.5 / Hermes v + calls = _parse_function_xml(content, id_offset = id_offset) + if calls: + return calls - - JSON inside ```` tags: - ``{"name":"web_search","arguments":{"query":"..."}}`` - - XML-style function blocks: - ``v`` + # Llama-3 <|python_tag|>... + calls = _parse_llama3_python_tag(content, id_offset = id_offset) + if calls: + return calls + + # Mistral [TOOL_CALLS]... + calls = _parse_mistral_tool_calls(content, id_offset = id_offset) + if calls: + return calls + + # Gemma 4 <|tool_call>... + calls = _parse_gemma_tool_calls(content, id_offset = id_offset) + if calls: + return calls + + # Llama-3.2 bare JSON ``{"name":..., "parameters":...}`` (no tag). 
+ # Strict: only fires when stripped content STARTS with ``{`` and + # parses as ``{name: str, parameters|arguments: dict}``. Keeps + # plain assistant prose unaffected. + return _parse_llama3_bare_json(content, id_offset = id_offset) - Closing tags (````, ````, ````) - are all optional since models frequently omit them. - """ - tool_calls: list[dict] = [] - # Pattern 1: {json}. Balanced-brace scan that skips - # braces inside JSON strings. +# ── Per-format parsers ────────────────────────────────────────────── + + +def _parse_tool_call_json(content: str, *, id_offset: int) -> list[dict]: + out: list[dict] = [] for m in _TC_JSON_START_RE.finditer(content): - brace_start = m.end() - 1 # position of the opening { - depth, i = 0, brace_start + brace_start = m.end() - 1 + end = _balanced_brace_end(content, brace_start) + if end is None: + continue + try: + obj = json.loads(content[brace_start : end + 1]) + except (json.JSONDecodeError, ValueError): + continue + name = obj.get("name", "") + args = obj.get("arguments", {}) + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if not name: + continue + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + return out + + +def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: + out: list[dict] = [] + func_starts = list(_TC_FUNC_START_RE.finditer(content)) + for idx, fm in enumerate(func_starts): + func_name = fm.group(1) + body_start = fm.end() + next_func = ( + func_starts[idx + 1].start() if idx + 1 < len(func_starts) else len(content) + ) + end_tag = _TC_END_TAG_RE.search(content[body_start:]) + if end_tag: + body_end = body_start + end_tag.start() + else: + body_end = len(content) + body_end = min(body_end, next_func) + body = _TC_FUNC_CLOSE_RE.sub("", content[body_start:body_end]) + + args: dict = {} + 
param_starts = list(_TC_PARAM_START_RE.finditer(body)) + if len(param_starts) == 1: + pm = param_starts[0] + val = _TC_PARAM_CLOSE_RE.sub("", body[pm.end() :]) + args[pm.group(1)] = val.strip() + else: + for pidx, pm in enumerate(param_starts): + val_start = pm.end() + next_param = ( + param_starts[pidx + 1].start() + if pidx + 1 < len(param_starts) + else len(body) + ) + val = _TC_PARAM_CLOSE_RE.sub("", body[val_start:next_param]) + args[pm.group(1)] = val.strip() + + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": func_name, "arguments": json.dumps(args)}, + } + ) + return out + + +def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: + """Llama-3 emission shapes: + <|python_tag|>NAME.call(arg="v", ...) (built-in tools) + <|python_tag|>{"name":"NAME", "parameters":{...}} (custom tools) + <|python_tag|>{"name":...}; {"name":...} (multi-call, ``; `` sep) + Accepts both ``parameters`` and ``arguments`` keys per Llama 3.1/3.2. + """ + out: list[dict] = [] + if _LLAMA3_PYTHON_TAG not in content: + return out + + # 1. NAME.call(...) built-in form. 
+ for m in _LLAMA3_PY_CALL_RE.finditer(content): + name = m.group(1) + i = m.end() + depth = 1 in_string = False - while i < len(content): + esc = False + while i < len(content) and depth > 0: ch = content[i] if in_string: - if ch == "\\" and i + 1 < len(content): - i += 2 - continue - if ch == '"': + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': in_string = False + else: + if ch == '"': + in_string = True + elif ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + break + i += 1 + body = content[m.end() : i] + args: dict[str, Any] = {} + for kv in _LLAMA3_KV_RE.finditer(body): + k = kv.group(1) + if kv.group(2) is not None: + try: + args[k] = bytes(kv.group(2), "utf-8").decode("unicode_escape") + except (UnicodeDecodeError, ValueError): + args[k] = kv.group(2) + elif kv.group(3) is not None: + v = kv.group(3) + args[k] = float(v) if "." in v else int(v) + elif kv.group(4) is not None: + args[k] = {"true": True, "false": False, "null": None}[kv.group(4)] + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ) + + # 2. <|python_tag|>{"name":..., "parameters":...} JSON form. Use a + # streaming JSON decoder (raw_decode) so we can peel multiple + # objects out of the same emission (separated by ``; `` per + # Llama 3 template). + if not out: + decoder = json.JSONDecoder() + idx = content.find(_LLAMA3_PYTHON_TAG) + while idx >= 0: + search_from = idx + len(_LLAMA3_PYTHON_TAG) + # Scan all `{` from this trigger; raw_decode jumps the + # cursor past each parsed object, but if a `{` falls + # inside an already-decoded object we skip it. + cursor = search_from + while cursor < len(content): + brace = content.find("{", cursor) + if brace < 0: + break + # Stop if we've hit the next <|python_tag|>. 
+ next_tag = content.find(_LLAMA3_PYTHON_TAG, search_from, brace) + if next_tag >= 0: + break + try: + obj, end_offset = decoder.raw_decode(content[brace:]) + except (json.JSONDecodeError, ValueError): + cursor = brace + 1 + continue + if not isinstance(obj, dict): + cursor = brace + end_offset + continue + name = obj.get("name") or obj.get("function") or "" + args = ( + obj.get("parameters") + if "parameters" in obj + else obj.get("arguments", {}) + ) + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if name: + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + cursor = brace + end_offset + idx = content.find(_LLAMA3_PYTHON_TAG, cursor) + return out + + +def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: + """Llama-3.2 ``custom_tools`` shape -- bare JSON ``{"name":..., + "parameters":{...}}`` emitted directly, no ``<|python_tag|>``. + + Strict to avoid firing on tool-message echoes: + + * Content must start with ``{`` once whitespace and any leading + ``<|begin_of_text|>`` / ``<|eot_id|>`` etc. sentinels are stripped. + * Object must have ``name`` (non-empty str) plus a dict in + ``parameters`` or ``arguments``. + * Loops via ``raw_decode`` to peel multiple ``;``-separated calls. + """ + out: list[dict] = [] + stripped = content.lstrip() + # Strip leading Llama-3 sentinel tokens that sometimes precede the + # JSON (``<|eot_id|>`` from the prior turn, ``<|start_header_id|>``). 
+ for sentinel in ( + "<|begin_of_text|>", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", + ): + stripped = stripped.lstrip() + if stripped.startswith(sentinel): + stripped = stripped[len(sentinel) :] + stripped = stripped.lstrip() + if not stripped.startswith("{"): + return out + + decoder = json.JSONDecoder() + cursor = 0 + n = len(stripped) + while cursor < n: + # Skip whitespace and Llama 3 inter-call separator ``;``. + while cursor < n and stripped[cursor] in " \t\n\r;": + cursor += 1 + if cursor >= n or stripped[cursor] != "{": + break + try: + obj, end_offset = decoder.raw_decode(stripped[cursor:]) + except (json.JSONDecodeError, ValueError): + break + if not isinstance(obj, dict): + break + name = obj.get("name") or obj.get("function") or "" + if not isinstance(name, str) or not name: + break + if "parameters" in obj: + args = obj.get("parameters") + elif "arguments" in obj: + args = obj.get("arguments") + else: + break + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + break + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + cursor += end_offset + return out + + +def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: + """Mistral emissions covered: + Pre-v11 array: ``[TOOL_CALLS] [{"name":..., "arguments":...}, ...]`` + Pre-v11 single: ``[TOOL_CALLS]{"name":..., "arguments":...}`` + v11+ single: ``[TOOL_CALLS]name{json_args}`` + v11+ parallel: ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}`` + v11+ w/ [ARGS]: ``[TOOL_CALLS]name[ARGS]{json_args}`` (Ministral / Large 3) + """ + out: list[dict] = [] + idx = content.find(_MISTRAL_TRIGGER) + if idx < 0: + return out + + # Decide whether the FIRST occurrence is array / single-object + # (pre-v11) or v11+ bare-name. Skip whitespace, peek at next char. 
+ j = idx + len(_MISTRAL_TRIGGER) + k = j + while k < len(content) and content[k] in " \t\n\r": + k += 1 + if k >= len(content): + return out + + if content[k] == "[": + return _parse_mistral_array(content, k, id_offset) + + if content[k] == "{": + # Could be pre-v11 single object ``{"name": ...}`` or a JSON + # blob immediately following the trigger (rare). Try parsing + # as an object that exposes ``name``; if not, fall through to + # v11+ handling so we don't drop emission silently. + end = _balanced_brace_end(content, k) + if end is not None: + try: + obj = json.loads(content[k : end + 1]) + if isinstance(obj, dict) and obj.get("name"): + _consume_mistral_call(content[k : end + 1], out, id_offset) + return out + except (json.JSONDecodeError, ValueError): + pass + + # v11+ path: walk every ``[TOOL_CALLS]`` and parse ``name{json}`` + # or ``name[ARGS]{json}`` after each trigger. + pos = idx + while pos >= 0: + cur = pos + len(_MISTRAL_TRIGGER) + nm = _MISTRAL_V11_NAME_RE.match(content, cur) + if not nm: + pos = content.find(_MISTRAL_TRIGGER, cur) + continue + name = nm.group(1) + after_name = nm.end() + # Optional ``[ARGS]`` marker. 
+ if content.startswith(_MISTRAL_ARGS_MARKER, after_name): + after_name += len(_MISTRAL_ARGS_MARKER) + while after_name < len(content) and content[after_name] in " \t\n\r": + after_name += 1 + if after_name >= len(content) or content[after_name] != "{": + pos = content.find(_MISTRAL_TRIGGER, cur) + continue + end = _balanced_brace_end(content, after_name) + if end is None: + break + try: + args = json.loads(content[after_name : end + 1]) + except (json.JSONDecodeError, ValueError): + pos = content.find(_MISTRAL_TRIGGER, end + 1) + continue + if not isinstance(args, dict): + pos = content.find(_MISTRAL_TRIGGER, end + 1) + continue + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args), + }, + } + ) + pos = content.find(_MISTRAL_TRIGGER, end + 1) + return out + + +def _parse_mistral_array(content: str, start: int, id_offset: int) -> list[dict]: + """Parse pre-v11 ``[TOOL_CALLS] [{...}, ...]`` JSON array form.""" + out: list[dict] = [] + j = start + depth = 0 + in_string = False + esc = False + while j < len(content): + ch = content[j] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True elif ch == '"': + in_string = False + else: + if ch == '"': in_string = True - elif ch == "{": + elif ch == "[": depth += 1 - elif ch == "}": + elif ch == "]": depth -= 1 if depth == 0: break - i += 1 - if depth == 0: - json_str = content[brace_start : i + 1] - try: - obj = json.loads(json_str) - tc = { - "id": f"call_{id_offset + len(tool_calls)}", - "type": "function", - "function": { - "name": obj.get("name", ""), - "arguments": obj.get("arguments", {}), - }, - } - if isinstance(tc["function"]["arguments"], dict): - tc["function"]["arguments"] = json.dumps( - tc["function"]["arguments"] - ) - tool_calls.append(tc) - except (json.JSONDecodeError, ValueError): - pass + j += 1 + body = content[start : j + 1] if depth == 0 else content[start:] - # Pattern 2: v... 
-- closing tags - # optional; don't use as body boundary because code - # values can contain that literal. - if not tool_calls: - func_starts = list(_TC_FUNC_START_RE.finditer(content)) - for idx, fm in enumerate(func_starts): - func_name = fm.group(1) - body_start = fm.end() - next_func = ( - func_starts[idx + 1].start() - if idx + 1 < len(func_starts) - else len(content) - ) - end_tag = _TC_END_TAG_RE.search(content[body_start:]) - if end_tag: - body_end = body_start + end_tag.start() - else: - body_end = len(content) - body_end = min(body_end, next_func) - body = content[body_start:body_end] - body = _TC_FUNC_CLOSE_RE.sub("", body) - - arguments: dict = {} - param_starts = list(_TC_PARAM_START_RE.finditer(body)) - if len(param_starts) == 1: - # Single param: take everything to body end so - # embedded in code strings is preserved. - pm = param_starts[0] - val = body[pm.end() :] - val = _TC_PARAM_CLOSE_RE.sub("", val) - arguments[pm.group(1)] = val.strip() - else: - for pidx, pm in enumerate(param_starts): - param_name = pm.group(1) - val_start = pm.end() - next_param = ( - param_starts[pidx + 1].start() - if pidx + 1 < len(param_starts) - else len(body) - ) - val = body[val_start:next_param] - val = _TC_PARAM_CLOSE_RE.sub("", val) - arguments[param_name] = val.strip() + try: + arr = json.loads(body) + if isinstance(arr, list): + for obj in arr: + if isinstance(obj, dict): + _consume_mistral_call(json.dumps(obj), out, id_offset) + return out + except (json.JSONDecodeError, ValueError): + pass - tc = { - "id": f"call_{id_offset + len(tool_calls)}", + # Healing path: walk objects manually for unclosed array. 
+ for m in re.finditer(r"\{", body): + end = _balanced_brace_end(body, m.start()) + if end is None: + continue + _consume_mistral_call(body[m.start() : end + 1], out, id_offset) + return out + + +def _consume_mistral_call(obj_text: str, out: list[dict], id_offset: int) -> None: + try: + obj = json.loads(obj_text) + except (json.JSONDecodeError, ValueError): + return + if not isinstance(obj, dict): + return + name = obj.get("name") or "" + args = obj.get("arguments") or {} + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if name: + out.append( + { + "id": obj.get("id") or f"call_{id_offset + len(out)}", "type": "function", - "function": { - "name": func_name, - "arguments": json.dumps(arguments), - }, + "function": {"name": name, "arguments": args_str}, } - tool_calls.append(tc) + ) - return tool_calls +def _parse_gemma_tool_calls(content: str, *, id_offset: int) -> list[dict]: + """Gemma 4: <|tool_call>call:NAME{k:<|"|>v<|"|>, ...}.""" + out: list[dict] = [] + for m in _GEMMA_TC_RE.finditer(content): + name = m.group(1) + body_start = m.end() - 1 + end_marker = content.find(_GEMMA_TC_END, body_start) + scan_end = end_marker if end_marker >= 0 else len(content) + end = _gemma_balanced_brace_end(content, body_start, scan_end) + if end is None: + continue + body = content[body_start + 1 : end] + try: + args = _gemma_parse_mapping_body(body) + except Exception: + args = {} + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ) + return out -def has_tool_signal(text: str) -> bool: - """Return True if ``text`` contains any tool-call XML signal.""" - return any(s in text for s in TOOL_XML_SIGNALS) + +# ── Brace-balancing helpers ───────────────────────────────────────── + + +def _balanced_brace_end(text: str, brace_pos: int) -> int | None: + """Index of `}` matching `{` at 
``brace_pos`` -- ignores `{` `}` + inside JSON strings. Returns None if unmatched.""" + if brace_pos >= len(text) or text[brace_pos] != "{": + return None + depth = 0 + in_string = False + esc = False + i = brace_pos + while i < len(text): + ch = text[i] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + elif ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _gemma_balanced_brace_end(text: str, brace_pos: int, hard_stop: int) -> int | None: + """Same as ``_balanced_brace_end`` but respects Gemma ``<|"|>`` + string runs and matches `{`/`[` symmetrically.""" + if brace_pos >= len(text) or text[brace_pos] != "{": + return None + depth = 0 + i = brace_pos + while i < hard_stop: + if text.startswith(_GEMMA_STR_BEGIN, i): + close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + return None + i = close + len(_GEMMA_STR_END) + continue + ch = text[i] + if ch == "{" or ch == "[": + depth += 1 + elif ch == "}" or ch == "]": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _gemma_parse_value(text: str, i: int): + """Parse one Gemma argument value starting at ``i``. 
Returns + ``(value, next_index)``.""" + if text.startswith(_GEMMA_STR_BEGIN, i): + close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + return text[i + len(_GEMMA_STR_BEGIN) :], len(text) + return text[i + len(_GEMMA_STR_BEGIN) : close], close + len(_GEMMA_STR_END) + if text[i] == "{": + end = _gemma_balanced_brace_end(text, i, len(text)) + if end is None: + return {}, len(text) + return _gemma_parse_mapping_body(text[i + 1 : end]), end + 1 + if text[i] == "[": + j, depth = i, 0 + while j < len(text): + if text.startswith(_GEMMA_STR_BEGIN, j): + k = text.find(_GEMMA_STR_END, j + len(_GEMMA_STR_BEGIN)) + if k < 0: + j = len(text) + break + j = k + len(_GEMMA_STR_END) + continue + ch = text[j] + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + if depth == 0: + break + j += 1 + body = text[i + 1 : j] + items: list[Any] = [] + k = 0 + while k < len(body): + if body[k] in " \t\n\r,": + k += 1 + continue + v, k = _gemma_parse_value(body, k) + items.append(v) + return items, j + 1 + # Primitive: number, true/false/null, or bare identifier (rare). 
+ end = i + while ( + end < len(text) + and text[end] not in ",}]" + and not text.startswith(_GEMMA_STR_BEGIN, end) + ): + end += 1 + raw = text[i:end].strip() + if raw == "true": + return True, end + if raw == "false": + return False, end + if raw == "null": + return None, end + try: + return int(raw), end + except ValueError: + pass + try: + return float(raw), end + except ValueError: + pass + return raw, end + + +def _gemma_parse_mapping_body(body: str) -> dict[str, Any]: + """Parse content between `{` and `}` for a Gemma argument mapping.""" + out: dict[str, Any] = {} + i = 0 + n = len(body) + while i < n: + while i < n and body[i] in " \t\n\r,": + i += 1 + if i >= n: + break + if body.startswith(_GEMMA_STR_BEGIN, i): + close = body.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + break + key = body[i + len(_GEMMA_STR_BEGIN) : close] + i = close + len(_GEMMA_STR_END) + else: + kstart = i + while i < n and body[i] != ":": + i += 1 + key = body[kstart:i].strip() + while i < n and body[i] in " \t\n\r": + i += 1 + if i < n and body[i] == ":": + i += 1 + while i < n and body[i] in " \t\n\r": + i += 1 + if i >= n: + out[key] = None + break + v, i = _gemma_parse_value(body, i) + out[key] = v + return out diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 1b4e7051b0..39f2004fa5 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -256,16 +256,29 @@ def _detect_safetensors_features(backend, chat_template: Optional[str]) -> dict: "supports_tools": False, } ) - # Our safetensors loop only parses {json} - # and .... Llama uses <|python_tag|>, - # Mistral uses [TOOL_CALLS]; advertising tools for those would - # enable a pill the parser cannot honour. GGUF is unaffected -- - # llama-server normalises every format into structured deltas. 
+ # The safetensors / MLX loop parses these emission formats: + # Qwen ``{json}``, Qwen3.5 ``...``, + # Llama-3 ``<|python_tag|>``, Llama-3.2 bare JSON ``{"name":..., + # "parameters":...}``, Mistral ``[TOOL_CALLS]`` (pre-v11 array + + # v11+ ``name{json}``), and Gemma 4 ``<|tool_call>...``. If the + # template advertises tools but does NOT use any of these markers, + # the parser cannot honour the emission - drop the pill. ``{"name":`` + # catches Llama-3.2's ``custom_tools`` shape whose template instructs + # the model to "Respond in the format {\"name\": ..., \"parameters\": + # ...}" without a ``<|python_tag|>`` prefix. + _PARSER_MARKERS = ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", + '{"name":', + '{\\"name\\":', + ) if ( flags.get("supports_tools") and chat_template - and "" not in chat_template - and " None: " Do NOT output code blocks -- use the python tool instead." ) -# Regex for stripping leaked tool-call XML from assistant messages/stream +# Regex for stripping leaked tool-call markup from assistant messages / +# stream. Covers every emission format the shared parser handles +# (Qwen / Hermes ````, Qwen3.5 ````, Llama-3 +# ``<|python_tag|>``, Mistral ``[TOOL_CALLS]`` pre-v11 array and v11+ +# ``name{json}``, Gemma 4 ``<|tool_call>...``). Closed +# pairs only so in-progress markup stays buffered upstream. 
_TOOL_XML_RE = _re.compile( - r".*?|.*?", + "|".join( + [ + r".*?", + r".*?", + r"<\|tool_call>.*?", + r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?", + r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}", + r"<\|python_tag\|>[^\n<]*", + ] + ), _re.DOTALL, ) logger = get_logger(__name__) diff --git a/studio/backend/tests/test_safetensors_capability_advertise.py b/studio/backend/tests/test_safetensors_capability_advertise.py index c3ee5b9ff1..b63e835d2a 100644 --- a/studio/backend/tests/test_safetensors_capability_advertise.py +++ b/studio/backend/tests/test_safetensors_capability_advertise.py @@ -129,11 +129,11 @@ def test_detect_safetensors_features_gptoss_disables_tools(): assert flags["supports_tools"] is False -# Llama-3 / Mistral templates advertise tool handling but the model emits -# tool calls in <|python_tag|> / [TOOL_CALLS] format -- not the -# / , [TOOL_CALLS], and +# <|tool_call>). The route helper must surface supports_tools=True for +# all of them so the UI enables the pill. Only templates whose tool +# format is NONE of the five known markers should be suppressed. LLAMA3_TEMPLATE = """ {%- if tools %} @@ -165,27 +165,88 @@ def test_detect_safetensors_features_gptoss_disables_tools(): {%- endfor %} """ +GEMMA4_TEMPLATE = """ +{%- if tools %} + {{- 'Tools available. 
Emit calls as ' }} + {{- '<|tool_call>call:NAME{key:<|"|>val<|"|>}' }} + {%- for tool in tools %} + {{- tool | tojson }} + {%- endfor %} +{%- endif %} +""" + -def test_detect_safetensors_features_llama3_template_suppresses_tools(): - """Llama-3 emits <|python_tag|>; safetensors loop cannot parse it.""" +def test_detect_safetensors_features_llama3_template_keeps_tools_on(): + """Llama-3 emits <|python_tag|>; parser now supports it.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct") flags = _detect_safetensors_features(backend, LLAMA3_TEMPLATE) - assert flags["supports_tools"] is False + assert flags["supports_tools"] is True -def test_detect_safetensors_features_mistral_template_suppresses_tools(): - """Mistral emits [TOOL_CALLS]; safetensors loop cannot parse it.""" +def test_detect_safetensors_features_mistral_template_keeps_tools_on(): + """Mistral emits [TOOL_CALLS]; parser now supports it.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/mistral-7b-instruct-v0.3") flags = _detect_safetensors_features(backend, MISTRAL_TEMPLATE) + assert flags["supports_tools"] is True + + +def test_detect_safetensors_features_gemma4_template_keeps_tools_on(): + """Gemma 4 emits <|tool_call>; parser now supports it.""" + from routes.inference import _detect_safetensors_features + + backend = SimpleNamespace(active_model_name = "unsloth/gemma-4-E2B-it-UD-MLX-4bit") + flags = _detect_safetensors_features(backend, GEMMA4_TEMPLATE) + assert flags["supports_tools"] is True + + +LLAMA3_2_BARE_JSON_TEMPLATE = """ +{%- if tools %} + {{- 'Given the following functions, respond with JSON for a function call.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary}.' 
}} + {%- for tool in tools %} + {{- tool | tojson }} + {%- endfor %} +{%- endif %} +{%- for message in messages %} + {%- if 'tool_calls' in message %} + {{- '{"name": "' + message.tool_calls[0].function.name + '", '}} + {{- '"parameters": ' + (message.tool_calls[0].function.arguments | tojson) + '}' }} + {%- endif %} +{%- endfor %} +""" + + +def test_detect_safetensors_features_llama3_2_bare_json_keeps_tools_on(): + """Llama-3.2 emits bare JSON ``{"name":..., "parameters":...}`` -- the + parser now handles that path, so the pill must stay enabled.""" + from routes.inference import _detect_safetensors_features + + backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct") + flags = _detect_safetensors_features(backend, LLAMA3_2_BARE_JSON_TEMPLATE) + assert flags["supports_tools"] is True + + +def test_detect_safetensors_features_unknown_format_suppresses_tools(): + """A template that advertises tools but uses no known marker must + be suppressed so the UI does not enable an unsupported pill.""" + from routes.inference import _detect_safetensors_features + + tpl = ( + "{%- if tools %}<|im_start|>system\n" + "Emit tool calls as JSON-RPC notifications inside the response." 
+ "<|im_end|>{%- endif %}" + ) + backend = SimpleNamespace(active_model_name = "custom/unknown-tool-format") + flags = _detect_safetensors_features(backend, tpl) assert flags["supports_tools"] is False def test_detect_safetensors_features_qwen_tool_call_keeps_tools_on(): - """Sanity check: gate only suppresses non-Qwen formats.""" + """Sanity check: Qwen marker still flips supports_tools.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/Qwen3-0.6B") diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index 923af87c4f..c838cab72d 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -130,6 +130,292 @@ def test_strip_markup_unclosed_final(self): assert "partial" in strip_tool_markup(text) +class TestParserMultiFormat: + """Parser coverage for Llama-3 / Mistral / Gemma 4 emission formats. + + Each model family upstream of GGUF emits a different tool-call + shape. The shared parser must turn all of them into the same + OpenAI ``{name, arguments}`` shape so the safetensors / MLX + agentic loop is family-agnostic. + """ + + # ── Llama-3 ──────────────────────────────────────────────────── + + def test_llama3_python_tag_dot_call(self): + # Llama-3 built-in tools: <|python_tag|>NAME.call(k="v", ...). 
+ import json + + text = '<|python_tag|>brave_search.call(query="weather in Tokyo")' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "brave_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "weather in Tokyo"} + + def test_llama3_python_tag_dot_call_multi_arg(self): + import json + + text = ( + "<|python_tag|>get_weather.call(" + 'location="Tokyo", units="celsius", days=5)' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"location": "Tokyo", "units": "celsius", "days": 5} + + def test_llama3_python_tag_json_form(self): + import json + + text = ( + '<|python_tag|>{"name":"web_search",' '"parameters":{"query":"hi","n":5}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "hi", "n": 5} + + def test_llama3_python_tag_json_form_with_eom(self): + # Llama-3 emits ``<|eom_id|>`` after the JSON; must not break parsing. + import json + + text = ( + '<|python_tag|>{"name":"python",' + '"parameters":{"code":"print(2+2)"}}<|eom_id|>' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"code": "print(2+2)"} + + def test_llama3_strip_markup_final(self): + text = '<|python_tag|>brave_search.call(query="x")' + assert strip_tool_markup(text, final = True) == "" + + # ── Llama-3.2 bare JSON ``custom_tools`` ───────────────────── + + def test_llama3_2_bare_json_parameters(self): + # Llama-3.2-Instruct emits bare JSON directly as content; no + # <|python_tag|> prefix per its training template. 
+ import json + + text = '{"name":"web_search","parameters":{"query":"Tokyo weather"}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "Tokyo weather"} + + def test_llama3_2_bare_json_arguments_key(self): + import json + + text = '{"name":"add","arguments":{"a":1,"b":2}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"a": 1, "b": 2} + + def test_llama3_2_bare_json_multi_call(self): + # Llama-3 may chain calls with ``; `` per training template. + text = '{"name":"a","parameters":{}}; ' '{"name":"b","parameters":{}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_llama3_2_bare_json_with_eom_sentinel(self): + text = '{"name":"x","parameters":{"y":1}}<|eom_id|>' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "x" + + def test_llama3_2_bare_json_leading_sentinel_skipped(self): + # Sometimes prior <|eot_id|> leaks into the next turn. + text = '<|eot_id|>{"name":"x","parameters":{}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "x" + + def test_llama3_2_bare_json_plain_prose_does_not_fire(self): + # Defensive: must NOT fire on plain assistant prose. + text = "Hello world, how are you today?" + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_embedded_in_prose_does_not_fire(self): + # Defensive: JSON embedded in prose must NOT fire (parser is + # strict about content STARTING with `{`). 
+ text = 'The tool result was: {"name":"foo"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_missing_name_does_not_fire(self): + text = '{"result":"ok","data":[1,2,3]}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_missing_args_does_not_fire(self): + text = '{"name":"x"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_args_not_dict_does_not_fire(self): + text = '{"name":"x","parameters":42}' + assert parse_tool_calls_from_text(text) == [] + + # ── Mistral pre-v11 ─────────────────────────────────────────── + + def test_mistral_pre_v11_array(self): + import json + + text = ( + '[TOOL_CALLS] [{"name":"web_search",' + '"arguments":{"query":"hello"},"id":"abc"}]' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + # Mistral provides its own id; preserve it. + assert result[0]["id"] == "abc" + assert json.loads(result[0]["function"]["arguments"]) == {"query": "hello"} + + def test_mistral_pre_v11_array_multi(self): + text = ( + '[TOOL_CALLS] [{"name":"a","arguments":{"x":1},"id":"id1"},' + '{"name":"b","arguments":{"y":2},"id":"id2"}]' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_mistral_pre_v11_unclosed_array(self): + # Closing ``]`` truncated -- parser must heal off individual objects. + text = '[TOOL_CALLS] [{"name":"web_search","arguments":{"q":"x"},"id":"id"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + + # ── Mistral v11+ ─────────────────────────────────────────────── + + def test_mistral_v11_single(self): + # Magistral / Mistral Small 3.1: bare ``name{json}`` after trigger. 
+ import json + + text = '[TOOL_CALLS]add{"a":3.5,"b":4}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "add" + assert json.loads(result[0]["function"]["arguments"]) == {"a": 3.5, "b": 4} + + def test_mistral_v11_parallel(self): + # v11+ parallel: ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}``. + text = '[TOOL_CALLS]add{"a":1}[TOOL_CALLS]sub{"b":2}' + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "add" + assert result[1]["function"]["name"] == "sub" + + def test_mistral_v11_with_args_marker(self): + # Ministral / Mistral Large 3: ``[TOOL_CALLS]name[ARGS]{json}``. + import json + + text = '[TOOL_CALLS]add[ARGS]{"a":1,"b":2}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "add" + assert json.loads(result[0]["function"]["arguments"]) == {"a": 1, "b": 2} + + def test_mistral_strip_markup_v11(self): + text = '[TOOL_CALLS]add{"a":1}' + assert strip_tool_markup(text, final = True) == "" + + # ── Gemma 4 ─────────────────────────────────────────────────── + + def test_gemma4_simple_call(self): + import json + + text = ( + "<|tool_call>call:get_weather{" + 'location:<|"|>Tokyo<|"|>,units:<|"|>celsius<|"|>}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"location": "Tokyo", "units": "celsius"} + + def test_gemma4_with_primitives(self): + import json + + text = ( + "<|tool_call>call:set_pref{" + "enabled:true,attempts:5,threshold:1.5,nickname:null}" + ) + result = parse_tool_calls_from_text(text) + args = json.loads(result[0]["function"]["arguments"]) + assert args == { + "enabled": True, + "attempts": 5, + "threshold": 1.5, + "nickname": None, + } + + def test_gemma4_nested_args(self): + # Gemma 4 nests dicts / lists with bare keys 
and ``<|"|>`` strings. + import json + + text = ( + "<|tool_call>call:search{" + 'query:<|"|>foo<|"|>,filters:{site:<|"|>example.com<|"|>,recent:true},' + 'tags:[<|"|>a<|"|>,<|"|>b<|"|>]}' + ) + result = parse_tool_calls_from_text(text) + args = json.loads(result[0]["function"]["arguments"]) + assert args["query"] == "foo" + assert args["filters"] == {"site": "example.com", "recent": True} + assert args["tags"] == ["a", "b"] + + def test_gemma4_multi_call(self): + text = ( + "<|tool_call>call:a{x:1}" "<|tool_call>call:b{y:2}" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_gemma4_unclosed_does_not_raise(self): + # Truncated mid-stream; must not raise. + text = '<|tool_call>call:foo{x:<|"|>bar<|"|>' + result = parse_tool_calls_from_text(text) + assert isinstance(result, list) + + def test_gemma4_strip_markup_final(self): + text = "<|tool_call>call:foo{x:1}" + assert strip_tool_markup(text, final = True) == "" + + # ── Cross-format sentinels ──────────────────────────────────── + + def test_all_markers_in_tool_xml_signals(self): + # Streaming buffer wakes up on every emission marker. 
+ from core.inference.tool_call_parser import TOOL_XML_SIGNALS + + for marker in ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", + ): + assert ( + marker in TOOL_XML_SIGNALS + ), f"streaming loop would not wake on {marker!r}" + + def test_has_tool_signal_for_all_formats(self): + assert has_tool_signal('<|python_tag|>brave_search.call(q="x")') + assert has_tool_signal('[TOOL_CALLS] [{"name":"x"}]') + assert has_tool_signal('[TOOL_CALLS]add{"a":1}') + assert has_tool_signal("<|tool_call>call:foo{}") + + # ──────────────────────────────────────────────────────────────────── # run_safetensors_tool_loop # ──────────────────────────────────────────────────────────────────── @@ -280,6 +566,71 @@ def test_function_xml_form(self): contents = [e for e in events if e["type"] == "content"] assert "Result: 1" in contents[-1]["text"] + def test_llama3_python_tag_form(self): + # The agentic loop must recognise Llama-3's <|python_tag|> + # marker, drain the rest of the turn, and execute the call. + loop, exec_fn = _make_loop( + turns = [ + [ + "<|python_tag|>web_search.call(", + 'query="weather in Tokyo"', + ")", + ], + ["The weather is sunny."], + ], + exec_results = ["Sunny, 22C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather in Tokyo"})] + contents = [e for e in events if e["type"] == "content"] + assert "sunny" in contents[-1]["text"].lower() + + def test_mistral_pre_v11_form(self): + # Pre-v11 Mistral emission: ``[TOOL_CALLS] [{...}]``. + loop, exec_fn = _make_loop( + turns = [ + [ + '[TOOL_CALLS] [{"name":"web_search",', + '"arguments":{"query":"hi"},"id":"abc"}]', + ], + ["done"], + ], + exec_results = ["ok"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hi"})] + # Mistral-provided ids must propagate to tool_start events. 
+ tool_start = next(e for e in events if e["type"] == "tool_start") + assert tool_start["tool_call_id"] == "abc" + + def test_mistral_v11_form(self): + # v11+ Mistral emission: bare ``name{json}`` after the trigger. + loop, exec_fn = _make_loop( + turns = [ + ['[TOOL_CALLS]web_search{"query":"hi"}'], + ["done"], + ], + exec_results = ["ok"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hi"})] + + def test_gemma4_form(self): + # Gemma 4 emission: ``<|tool_call>call:NAME{...}``. + loop, exec_fn = _make_loop( + turns = [ + [ + "<|tool_call>call:web_search{", + 'query:<|"|>weather<|"|>', + "}", + ], + ["sunny"], + ], + exec_results = ["Sunny, 22C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather"})] + def test_truncated_unclosed_tool_call(self): loop, exec_fn = _make_loop( turns = [ From a0a0c97473ccc37f384321ef994d05ce61545c36 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Tue, 19 May 2026 14:28:21 +0000 Subject: [PATCH 02/17] studio: tool-call healing parity between safetensors / MLX and GGUF After the multi-format parser landed in #5615, the safetensors / MLX agentic loop and the GGUF loop still differed on healing behaviour. This commit closes the gaps in both directions so the two backends react the same way to identical model output. Changes: 1. core/inference/llama_cpp.py -- the GGUF BUFFERING state machine now wakes on every emission marker the shared parser knows. Was ("", " / Mistral [TOOL_CALLS] / Gemma 4 <|tool_call>). Stream cleanup is delegated to the same shared strip_tool_markup so leaked markup from any family is removed from assistant content. 2. core/inference/llama_cpp.py -- per-tool canonical heal key. When a tool arguments field is a bare string and JSON parsing fails, the GGUF path now heals to {"code": raw_args} for python, {"command": raw_args} for terminal, and {"query": raw_args} for everything else. 
Was hard-coded to {"query": raw_args}, which silently routed every python / terminal emission through web_search. Mirrors safetensors_agentic._CANONICAL_HEAL_ARG. 3. core/inference/safetensors_agentic.py -- re-prompt on plan- without-action. When the model emits a short forward-looking intent ("I'll search for that", "Let me check", "First, I will...") and no tool call, the loop nudges the model to act instead of silently returning a plan-only answer. Up to _MAX_REPROMPTS=3 (matches GGUF). The intent regex, character cap, and instruction text are byte-identical to the GGUF path. The buffer-end fall-through is unified so a buffered intent emission that never exits the BUFFERING state still triggers the re-prompt. 4. core/inference/safetensors_agentic.py -- extra iteration slots for re-prompts. The loop now budgets max_tool_iterations + _MAX_REPROMPTS + 1 total iterations and tracks the tool-call count separately, so a stalling model can be nudged 3x without eating the caller's tool-call budget. Mirrors the _extra slot reservation in the GGUF path. Tests (14 new safetensors-side units; 5 GGUF parity pins): TestLoopRePrompt -- intent-trigger, plain-answer, no-tools, cap-at-three, budget preserved, buffer-end intent. TestLoopCanonicalHealKey -- python / terminal / unknown. TestGGUFSafetensorsHealingParity -- shared markers used, shared strip used, canonical heal keys identical, intent regex matches same phrases, _MAX_REPROMPTS equal on both backends. All 110 targeted tests pass locally; the broader tool / inference / model-config / sandbox / anthropic / mlx suites stay green. Why this matters Without this parity, Llama-3.2 / Mistral / Gemma 4 emissions on Mac (MLX) and Linux-safetensors stop the agentic loop as soon as the model says "Let me...", because the GGUF re-prompt logic never existed on these backends. The two-marker GGUF BUFFERING tuple also let non-Qwen tool emissions stream out as plain prose when llama-server's structured channel did not pick them up. 
Both paths now drain the same way, heal the same way, and re-prompt the same way -- so a tool call that works on GGUF works identically on safetensors / MLX. --- studio/backend/core/inference/llama_cpp.py | 34 ++- .../core/inference/safetensors_agentic.py | 84 +++++- .../tests/test_safetensors_tool_loop.py | 268 ++++++++++++++++++ 3 files changed, 371 insertions(+), 15 deletions(-) diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 260e675a73..8eff36bc16 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -38,7 +38,15 @@ _TOOL_ALL_PATS, _TOOL_CLOSED_PATS, parse_tool_calls_from_text, - strip_tool_call_markup, +) +# Stripping and signal-marker constants come from the multi-format +# parser so Llama-3 / Mistral / Gemma 4 emissions are also detected +# in the BUFFERING state machine and stripped from the assistant +# stream. Pre-PR-5615 we used the legacy two-format helper which +# only knew / str: if not auto_heal_tool_calls: return text - return strip_tool_call_markup(text, final = final) + return _shared_strip_tool_markup(text, final = final) - # XML prefixes that signal a tool call in content. - # Empty when auto_heal is disabled so the buffer never - # speculatively holds content for XML detection. + # Markers the BUFFERING state machine watches for. Empty when + # auto-heal is off so the buffer never speculatively holds + # content. Covers all five emission formats the shared parser + # understands: Qwen , Qwen3.5 , Mistral [TOOL_CALLS], Gemma 4 <|tool_call>. _TOOL_XML_SIGNALS = ( - ("", " str: arguments = json.loads(raw_args) except (json.JSONDecodeError, ValueError): if auto_heal_tool_calls: - arguments = {"query": raw_args} + # Per-tool canonical heal key so a bare + # string emission still runs the right + # tool: ``code`` for python, ``command`` + # for terminal, ``query`` for everything + # else (e.g. web_search). 
Mirrors + # safetensors_agentic._CANONICAL_HEAL_ARG. + _heal_key = { + "python": "code", + "terminal": "command", + }.get(tool_name, "query") + arguments = {_heal_key: raw_args} else: arguments = {"raw": raw_args} else: diff --git a/studio/backend/core/inference/safetensors_agentic.py b/studio/backend/core/inference/safetensors_agentic.py index 73bb3d090a..f70421b584 100644 --- a/studio/backend/core/inference/safetensors_agentic.py +++ b/studio/backend/core/inference/safetensors_agentic.py @@ -18,6 +18,7 @@ """ import json +import re import threading from typing import Callable, Generator, Optional from urllib.parse import urlparse @@ -42,6 +43,30 @@ # Buffer cap while waiting to disambiguate a possible tool-call prefix. _MAX_BUFFER_CHARS = 32 +# Forward-looking intent signals that indicate the model is describing +# what it *will* do rather than giving a final answer. Mirrors the GGUF +# path so safetensors / MLX nudge the model to act when it stalls on +# planning instead of calling a tool. Excludes "I can", "I should", +# "I want to", "let's" which appear in direct answers / explanations. +_INTENT_SIGNAL = re.compile( + r"(?i)(" + # Direct intent: "I'll", "I will", "Let me", "I am going to". + r"\b(i['’](ll|m going to|m gonna)|i am (going to|gonna)|i will|i shall|let me|allow me)\b" + r"|" + # Step / plan framing: "First", "Step 1:", "Here's my plan". + r"\b(?:first\b|step \d+:?|here['’]?s (?:my |the |a )?(?:plan|approach))" + r"|" + # "Now I" / "Next I" patterns. + r"\b(?:now i|next i)\b" + r")" +) +_MAX_REPROMPTS = 3 +_REPROMPT_MAX_CHARS = 2000 +_REPROMPT_INSTRUCTION = ( + "STOP. Do NOT write code or explain. You MUST call a tool NOW. " + "Call web_search or python immediately." 
+) + def _status_for_tool(tool_name: str, arguments: dict) -> str: """Return a human-readable status line matching the GGUF path.""" @@ -142,6 +167,7 @@ def run_safetensors_tool_loop( if (tool.get("function") or {}).get("name") } next_call_id = 0 + reprompt_count = 0 if max_tool_iterations <= 0: # 0 = disabled (same contract as the GGUF loop). @@ -152,7 +178,10 @@ def run_safetensors_tool_loop( _state_streaming = 1 _state_draining = 2 - for iteration in range(max_tool_iterations + 1): + # Reserve extra iterations for re-prompts so they do not eat the + # caller's tool-call budget. Mirrors GGUF (_MAX_REPROMPTS slots). + _extra_iters = _MAX_REPROMPTS if max_tool_iterations > 0 else 0 + for iteration in range(max_tool_iterations + _extra_iters + 1): if cancel_event is not None and cancel_event.is_set(): return @@ -242,14 +271,18 @@ def run_safetensors_tool_loop( if stripped and has_tool_signal(stripped): detect_state = _state_draining else: + # Emit the buffered content, then fall through to the + # STREAMING block so the intent re-prompt + safety-net + # parser still get a chance. Without this, a short + # intent emission like "Let me search." that never + # exits BUFFERING would silently terminate the loop. if content_buffer: cumulative_display += content_buffer - yield { - "type": "content", - "text": strip_tool_markup(cumulative_display, final = True), - } - yield {"type": "status", "text": ""} - return + cleaned = strip_tool_markup(cumulative_display, final = True) + if len(cleaned) > len(last_emitted): + last_emitted = cleaned + yield {"type": "content", "text": cleaned} + detect_state = _state_streaming if detect_state == _state_streaming: # No tool detected mid-stream -- check for late tool XML. @@ -260,6 +293,36 @@ def run_safetensors_tool_loop( id_offset = next_call_id, ) if not safety_tc: + # Re-prompt on plan-without-action: if the model + # described what it intends to do but did not call a + # tool, nudge it to act. Mirrors the GGUF path. 
Only + # fires on responses that signal intent / planning -- + # direct answers like "4" or "Hello!" don't trigger. + _stripped = content_accum.strip() + if ( + tools + and reprompt_count < _MAX_REPROMPTS + and 0 < len(_stripped) < _REPROMPT_MAX_CHARS + and _INTENT_SIGNAL.search(_stripped) + and not final_attempt_done + ): + reprompt_count += 1 + logger.info( + "Safetensors re-prompt %d/%d: model planned without " + "calling tools (%d chars)", + reprompt_count, + _MAX_REPROMPTS, + len(_stripped), + ) + conversation.append( + {"role": "assistant", "content": _stripped} + ) + conversation.append( + {"role": "user", "content": _REPROMPT_INSTRUCTION} + ) + yield {"type": "status", "text": ""} + continue + # Final answer: streaming already emitted content. # Skip a final=True re-strip so literal "" # in prose survives when no real tool call parsed. @@ -379,7 +442,12 @@ def run_safetensors_tool_loop( # Clear the status badge before the next turn. yield {"type": "status", "text": ""} - if iteration + 1 >= max_tool_iterations and not final_attempt_done: + # Budget tracked against the caller-requested cap, ignoring + # the re-prompt slots so a stalling model still gets a final + # answer attempt. Tool-call iterations executed = iteration - + # reprompt_count. + _tool_iters_done = iteration + 1 - reprompt_count + if _tool_iters_done >= max_tool_iterations and not final_attempt_done: # Budget exhausted; nudge a final plain answer. final_attempt_done = True conversation.append( diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index c838cab72d..ae5af37dde 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -806,6 +806,274 @@ def test_exception_in_executor_does_not_raise(self): assert "boom" in tool_end["result"] +class TestLoopRePrompt: + """Re-prompt-on-plan-without-action parity with the GGUF path. 
+ + When the model emits forward-looking intent ("Let me search for + that") without actually calling a tool, the loop must nudge it to + act instead of silently terminating. Up to ``_MAX_REPROMPTS`` (3) + re-prompts per request, drawn from extra iteration slots so the + caller's tool-call budget is preserved. + """ + + def test_intent_signal_triggers_reprompt(self): + # Turn 1: intent signal, no tool call. + # Turn 2 (re-prompt): proper tool call -> executes. + # Turn 3: final answer. + loop, exec_fn = _make_loop( + turns = [ + ["Let me search for that."], + [ + '{"name":"web_search","arguments":' + '{"query":"sky color"}}' + ], + ["The sky is blue."], + ], + exec_results = ["Blue (Rayleigh scattering)"], + ) + events = _collect_events(loop) + # web_search must have been called once (after the re-prompt). + assert exec_fn.calls == [("web_search", {"query": "sky color"})] + contents = [e for e in events if e["type"] == "content"] + assert contents and "blue" in contents[-1]["text"].lower() + + def test_intent_signal_without_tools_does_not_reprompt(self): + # Same intent signal but no tools enabled -- must NOT re-prompt. + loop, exec_fn = _make_loop( + turns = [["Let me think about that for a moment."]], + exec_results = [], + ) + # _make_loop hard-codes three tools; rebuild without tools. + from core.inference.safetensors_agentic import run_safetensors_tool_loop + + def _gen(_messages): + yield "Let me think about that for a moment." + + exec_fn = FakeExecuteTool([]) + events = _collect_events( + run_safetensors_tool_loop( + single_turn = _gen, + messages = [{"role": "user", "content": "hi"}], + tools = [], + execute_tool = exec_fn, + ) + ) + assert exec_fn.calls == [] + contents = [e for e in events if e["type"] == "content"] + assert contents and "think" in contents[-1]["text"].lower() + + def test_direct_answer_does_not_trigger_reprompt(self): + # Plain answer with no intent words: do NOT re-prompt. 
+ loop, exec_fn = _make_loop( + turns = [["4"]], + exec_results = [], + ) + events = _collect_events(loop) + assert exec_fn.calls == [] + contents = [e for e in events if e["type"] == "content"] + assert contents and contents[-1]["text"].strip() == "4" + + def test_max_reprompts_capped_at_three(self): + # Model keeps stalling with intent -- after 3 re-prompts the + # loop must give up rather than burn forever. + turns = [["Let me search for that."]] * 6 # well over the cap + loop, exec_fn = _make_loop( + turns = turns, + exec_results = [], + ) + events = _collect_events(loop, max_events = 500) + # No tool ever ran, but the loop terminated cleanly. + assert exec_fn.calls == [] + statuses = [e for e in events if e["type"] == "status"] + assert statuses and statuses[-1]["text"] == "" + + def test_short_intent_below_buffer_threshold_triggers_reprompt(self): + # Short emission that never exits BUFFERING (< 32 chars + no + # marker prefix). The unified buffer-end path must still + # trigger the intent re-prompt, not silently terminate. + loop, exec_fn = _make_loop( + turns = [ + ["Let me check."], + [ + '{"name":"web_search","arguments":' + '{"query":"x"}}' + ], + ["found"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "x"})] + + def test_reprompt_does_not_consume_tool_budget(self): + # max_tool_iterations=1: one re-prompt, then one real tool call, + # then the budget-exhausted final answer must still fire. If the + # re-prompt ate the slot the tool call would never run. + loop, exec_fn = _make_loop( + turns = [ + # 1. Intent stall (re-prompt 1/3). + ["Let me search for that."], + # 2. Real tool call (uses the budget slot). + [ + '{"name":"web_search","arguments":' + '{"query":"weather"}}' + ], + # 3. Budget exhausted -> nudged final answer. 
+ ["Final: it is sunny"], + ], + exec_results = ["sunny"], + max_tool_iterations = 1, + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather"})] + contents = [e for e in events if e["type"] == "content"] + assert contents and "sunny" in contents[-1]["text"].lower() + + +class TestLoopCanonicalHealKey: + """Per-tool canonical heal key (``code`` for python, ``command`` for + terminal, ``query`` for everything else). Mirrors GGUF after the + PR-5615 follow-up that ported this mapping over.""" + + def test_python_bare_string_heals_to_code(self): + loop, exec_fn = _make_loop( + turns = [ + [ + '{"name":"python","arguments":"print(1)"}' + "" + ], + ["done"], + ], + exec_results = ["1\n"], + ) + events = _collect_events(loop) + # The bare string must heal to {"code": "print(1)"}, not + # {"query": ...}, so the python sandbox actually executes it. + assert exec_fn.calls == [("python", {"code": "print(1)"})] + + def test_terminal_bare_string_heals_to_command(self): + loop, exec_fn = _make_loop( + turns = [ + [ + '{"name":"terminal","arguments":"ls -la"}' + "" + ], + ["done"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("terminal", {"command": "ls -la"})] + + def test_unknown_tool_bare_string_heals_to_query(self): + loop, exec_fn = _make_loop( + turns = [ + [ + '{"name":"web_search","arguments":"hello"}' + "" + ], + ["ok"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hello"})] + + +class TestGGUFSafetensorsHealingParity: + """Pin parity between the GGUF agentic loop and the safetensors / + MLX loop so a regression on either side breaks CI.""" + + def test_gguf_imports_shared_signal_markers(self): + # The GGUF BUFFERING state machine must wake on every emission + # marker the shared parser knows -- otherwise Llama-3 / Mistral + # / Gemma 4 emissions slip past as plain prose when the + # llama-server 
structured channel fails. + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource( + LlamaCppBackend.generate_chat_completion_with_tools + ) + assert "_SHARED_TOOL_XML_SIGNALS" in src, ( + "GGUF agentic loop must reuse the shared TOOL_XML_SIGNALS " + "tuple so it wakes on all five emission formats" + ) + + def test_gguf_uses_shared_strip_helper(self): + # The GGUF stream-cleanup function must delegate to the shared + # strip_tool_markup so closed-pair markup is removed for every + # emission family (Llama-3 <|python_tag|>, Mistral [TOOL_CALLS], + # Gemma 4 <|tool_call>...). + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource( + LlamaCppBackend.generate_chat_completion_with_tools + ) + assert "_shared_strip_tool_markup" in src, ( + "GGUF stream cleanup must delegate to the shared " + "strip_tool_markup helper" + ) + + def test_gguf_uses_canonical_heal_keys(self): + # GGUF must heal a bare-string ``arguments`` to the same per-tool + # canonical key as safetensors -- ``code`` for python, ``command`` + # for terminal, ``query`` for everything else. + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource( + LlamaCppBackend.generate_chat_completion_with_tools + ) + # The canonical key dict literal must be present in the heal + # path so a Llama-3 / Mistral / Gemma 4 bare-string emission + # for python doesn't get routed as {"query": "print(1)"}. + assert '"python": "code"' in src + assert '"terminal": "command"' in src + + def test_intent_regex_matches_same_phrases_as_gguf(self): + # The intent re-prompt regex must match the SAME forward-looking + # phrases on both backends so behaviour is the same on Mac (MLX + # / safetensors) and on Linux (GGUF). 
+ from core.inference.llama_cpp import _INTENT_SIGNAL as gguf_re + from core.inference.safetensors_agentic import ( + _INTENT_SIGNAL as sf_re, + ) + + for phrase in ( + "I'll search for that", + "I will look it up", + "Let me check", + "I am going to call the tool", + "First, I will explore", + "Here's my plan", + "Now I need to call web_search", + ): + assert gguf_re.search(phrase), f"GGUF missed {phrase!r}" + assert sf_re.search(phrase), f"safetensors missed {phrase!r}" + + for plain in ( + "4", + "Hello!", + "The sky is blue.", + "I can help with that.", + "I should mention", + "Let's go.", + ): + assert not gguf_re.search(plain), f"GGUF wrongly fired on {plain!r}" + assert not sf_re.search(plain), f"safetensors wrongly fired on {plain!r}" + + def test_max_reprompts_equal_on_both_backends(self): + from core.inference.llama_cpp import _MAX_REPROMPTS as gguf_cap + from core.inference.safetensors_agentic import _MAX_REPROMPTS as sf_cap + + assert gguf_cap == sf_cap == 3 + + class TestLoopControl: def test_cancel_event_breaks_loop(self): cancel = threading.Event() From f79068046e87a49c8231bf496b9f6773f1afb448 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 14:28:56 +0000 Subject: [PATCH 03/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/core/inference/llama_cpp.py | 5 ++-- .../core/inference/safetensors_agentic.py | 4 +-- .../tests/test_safetensors_tool_loop.py | 27 +++++-------------- 3 files changed, 9 insertions(+), 27 deletions(-) diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 8eff36bc16..3275df0240 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -39,6 +39,7 @@ _TOOL_CLOSED_PATS, parse_tool_calls_from_text, ) + # Stripping and signal-marker constants come from the multi-format # parser so 
Llama-3 / Mistral / Gemma 4 emissions are also detected # in the BUFFERING state machine and stripped from the assistant @@ -4306,9 +4307,7 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: # content. Covers all five emission formats the shared parser # understands: Qwen , Qwen3.5 , Mistral [TOOL_CALLS], Gemma 4 <|tool_call>. - _TOOL_XML_SIGNALS = ( - _SHARED_TOOL_XML_SIGNALS if auto_heal_tool_calls else () - ) + _TOOL_XML_SIGNALS = _SHARED_TOOL_XML_SIGNALS if auto_heal_tool_calls else () _MAX_BUFFER_CHARS = 32 # ── Duplicate tool-call detection ──────────────────────── diff --git a/studio/backend/core/inference/safetensors_agentic.py b/studio/backend/core/inference/safetensors_agentic.py index f70421b584..e1b9c0bc8c 100644 --- a/studio/backend/core/inference/safetensors_agentic.py +++ b/studio/backend/core/inference/safetensors_agentic.py @@ -314,9 +314,7 @@ def run_safetensors_tool_loop( _MAX_REPROMPTS, len(_stripped), ) - conversation.append( - {"role": "assistant", "content": _stripped} - ) + conversation.append({"role": "assistant", "content": _stripped}) conversation.append( {"role": "user", "content": _REPROMPT_INSTRUCTION} ) diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index ae5af37dde..842ce54e49 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -938,10 +938,7 @@ class TestLoopCanonicalHealKey: def test_python_bare_string_heals_to_code(self): loop, exec_fn = _make_loop( turns = [ - [ - '{"name":"python","arguments":"print(1)"}' - "" - ], + ['{"name":"python","arguments":"print(1)"}' ""], ["done"], ], exec_results = ["1\n"], @@ -954,10 +951,7 @@ def test_python_bare_string_heals_to_code(self): def test_terminal_bare_string_heals_to_command(self): loop, exec_fn = _make_loop( turns = [ - [ - '{"name":"terminal","arguments":"ls -la"}' - "" - ], + ['{"name":"terminal","arguments":"ls -la"}' ""], 
["done"], ], exec_results = ["..."], @@ -968,10 +962,7 @@ def test_terminal_bare_string_heals_to_command(self): def test_unknown_tool_bare_string_heals_to_query(self): loop, exec_fn = _make_loop( turns = [ - [ - '{"name":"web_search","arguments":"hello"}' - "" - ], + ['{"name":"web_search","arguments":"hello"}' ""], ["ok"], ], exec_results = ["..."], @@ -993,9 +984,7 @@ def test_gguf_imports_shared_signal_markers(self): from core.inference.llama_cpp import LlamaCppBackend - src = inspect.getsource( - LlamaCppBackend.generate_chat_completion_with_tools - ) + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) assert "_SHARED_TOOL_XML_SIGNALS" in src, ( "GGUF agentic loop must reuse the shared TOOL_XML_SIGNALS " "tuple so it wakes on all five emission formats" @@ -1010,9 +999,7 @@ def test_gguf_uses_shared_strip_helper(self): from core.inference.llama_cpp import LlamaCppBackend - src = inspect.getsource( - LlamaCppBackend.generate_chat_completion_with_tools - ) + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) assert "_shared_strip_tool_markup" in src, ( "GGUF stream cleanup must delegate to the shared " "strip_tool_markup helper" @@ -1026,9 +1013,7 @@ def test_gguf_uses_canonical_heal_keys(self): from core.inference.llama_cpp import LlamaCppBackend - src = inspect.getsource( - LlamaCppBackend.generate_chat_completion_with_tools - ) + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) # The canonical key dict literal must be present in the heal # path so a Llama-3 / Mistral / Gemma 4 bare-string emission # for python doesn't get routed as {"query": "print(1)"}. 
From e9b4d3f7e41ca02195e44cdca814c995493ff25f Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Tue, 19 May 2026 14:56:20 +0000 Subject: [PATCH 04/17] studio: fix tool-call parser bugs from gemini review on #5620 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three high-priority gemini findings on the tool-call parsing additions: 1. unicode_escape on UTF-8 bytes corrupts non-ASCII literals (e.g. ✨ becomes â\x9c¨). Replace with json.loads on a quoted string -- preserves emoji / CJK / RTL while still handling \n \t \uXXXX escapes. 2. Llama-3 sentinel stripping is order-dependent. A leading `<|eot_id|><|begin_of_text|>` left `<|begin_of_text|>` behind because the loop had already passed that sentinel. Loop until no sentinel matches at the start. 3. Mistral v11+ `[TOOL_CALLS] name { json }` regex uses non-greedy `\{.*?\}` which truncates at the first `}` of a nested JSON argument, leaking the tail (e.g. `}}`) into user-visible streamed text. Same problem for the v0.3 array pattern with nested brackets. Strip those with balanced brace/bracket scanning via a new `_strip_mistral_closed_calls` helper called from `strip_tool_markup`. Also fix the inference routes' parallel `_TOOL_XML_RE`: - Same nested-JSON truncation in the Mistral patterns; route the strip through the parser's balanced-scan helper via a thin `_strip_tool_xml` wrapper that all existing callers now use. - Llama-3 `<|python_tag|>[^\n<]*` stopped at any `<`, leaking the tail of any tool call whose argument contained a literal `<` (queries, code snippets). Relax to `[^\n]*` which keeps the strip confined to the actual end-of-line. 
--- .../core/inference/tool_call_parser.py | 137 ++++++++++++++++-- studio/backend/routes/inference.py | 43 ++++-- 2 files changed, 157 insertions(+), 23 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 7b0c9b1f62..270c05fe8d 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -52,9 +52,6 @@ re.compile(r".*?", re.DOTALL), re.compile(r".*?", re.DOTALL), re.compile(r"<\|tool_call>.*?", re.DOTALL), - re.compile(r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?", re.DOTALL), - # Mistral v11+ ``[TOOL_CALLS]name{json}`` (may chain), close at ``}``. - re.compile(r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), @@ -137,6 +134,106 @@ # ── Public API ────────────────────────────────────────────────────── +def _balanced_bracket_end(text: str, start: int) -> int | None: + """Index of the matching ``]`` for the ``[`` at ``text[start]``. + + Skips brackets inside JSON string literals. Returns ``None`` if no + matching close is found. + """ + if start >= len(text) or text[start] != "[": + return None + depth = 0 + in_string = False + esc = False + i = start + while i < len(text): + ch = text[i] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + elif ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _strip_mistral_closed_calls(text: str) -> str: + """Strip ``[TOOL_CALLS]`` blocks with balanced brace/bracket scanning. 
+ + Handles three Mistral emission shapes: + + - ``[TOOL_CALLS] [ {...}, {...} ]`` (v0.3 / Nemo / Small) + - ``[TOOL_CALLS] name { json }`` (v11+ / Magistral) + - ``[TOOL_CALLS] name [ARGS] { json }`` (Ministral / Large 3) + + The regex ``\\{.*?\\}`` truncates at the first ``}``, losing nested + JSON, so this walks balanced braces/brackets instead. Only matches + runs that close cleanly; unclosed trailing markup is left in place + for ``final=True`` cleanup via ``_TOOL_ALL_PATS``. + """ + n = len(text) + out = [] + cursor = 0 + while cursor < n: + idx = text.find(_MISTRAL_TRIGGER, cursor) + if idx == -1: + out.append(text[cursor:]) + break + out.append(text[cursor:idx]) + body_start = idx + len(_MISTRAL_TRIGGER) + # Skip whitespace + optional name + optional ``[ARGS]``. + i = body_start + while i < n and text[i] in " \t\n\r": + i += 1 + # Array shape: [TOOL_CALLS] [ ... ] + if i < n and text[i] == "[": + end = _balanced_bracket_end(text, i) + if end is None: + # Truncated mid-array; leave the trigger and everything + # after it in place so caller can buffer / final-strip. + out.append(text[idx:]) + break + cursor = end + 1 + # Optional trailing (Mistral EOS). + if text.startswith("", cursor): + cursor += len("") + continue + # Named shape: [TOOL_CALLS] name [ARGS]? { json } + name_match = _MISTRAL_V11_NAME_RE.match(text, i) + if not name_match: + out.append(text[idx:body_start]) + cursor = body_start + continue + i = name_match.end() + while i < n and text[i] in " \t\n\r": + i += 1 + if text.startswith(_MISTRAL_ARGS_MARKER, i): + i += len(_MISTRAL_ARGS_MARKER) + while i < n and text[i] in " \t\n\r": + i += 1 + if i >= n or text[i] != "{": + out.append(text[idx:i]) + cursor = i + continue + end = _balanced_brace_end(text, i) + if end is None: + out.append(text[idx:]) + break + cursor = end + 1 + return "".join(out) + + def strip_tool_markup(text: str, *, final: bool = False) -> str: """Strip tool-call markup from streamed text. 
@@ -144,6 +241,10 @@ def strip_tool_markup(text: str, *, final: bool = False) -> str: stays buffered. ``final=True`` also removes trailing unclosed runs and trims the result. """ + # Mistral patterns need balanced brace/bracket scanning -- a + # non-greedy regex would truncate at the first ``}`` inside a + # nested JSON object and leak the rest into user-visible text. + text = _strip_mistral_closed_calls(text) pats = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS for pat in pats: text = pat.sub("", text) @@ -315,9 +416,13 @@ def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: for kv in _LLAMA3_KV_RE.finditer(body): k = kv.group(1) if kv.group(2) is not None: + # json.loads on a wrapped JSON string handles + # \n / \t / \uXXXX escapes correctly while preserving + # literal UTF-8 bytes (emoji, CJK, etc.) that the older + # ``bytes.decode("unicode_escape")`` path mangled. try: - args[k] = bytes(kv.group(2), "utf-8").decode("unicode_escape") - except (UnicodeDecodeError, ValueError): + args[k] = json.loads('"' + kv.group(2) + '"') + except (json.JSONDecodeError, ValueError): args[k] = kv.group(2) elif kv.group(3) is not None: v = kv.group(3) @@ -401,18 +506,28 @@ def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: out: list[dict] = [] stripped = content.lstrip() # Strip leading Llama-3 sentinel tokens that sometimes precede the - # JSON (``<|eot_id|>`` from the prior turn, ``<|start_header_id|>``). - for sentinel in ( + # JSON (``<|eot_id|>`` from the prior turn, ``<|start_header_id|>``, + # ``<|begin_of_text|>``). Loop until no sentinel matches: the + # tokens can appear in any order and chain, so a single pass would + # leave later sentinels behind once an earlier one consumed its + # prefix. 
+ _sentinels = ( "<|begin_of_text|>", "<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>", "<|eom_id|>", - ): + ) + while True: stripped = stripped.lstrip() - if stripped.startswith(sentinel): - stripped = stripped[len(sentinel) :] - stripped = stripped.lstrip() + matched = False + for sentinel in _sentinels: + if stripped.startswith(sentinel): + stripped = stripped[len(sentinel) :] + matched = True + break + if not matched: + break if not stripped.startswith("{"): return out diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 39f2004fa5..8b0d854412 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -442,19 +442,38 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: # ``<|python_tag|>``, Mistral ``[TOOL_CALLS]`` pre-v11 array and v11+ # ``name{json}``, Gemma 4 ``<|tool_call>...``). Closed # pairs only so in-progress markup stays buffered upstream. +# Flat-regex patterns. Mistral ``[TOOL_CALLS]`` blocks need balanced +# brace/bracket scanning -- a non-greedy ``\{.*?\}`` truncates at the +# first ``}`` of a nested JSON arg, so those are handled by the parser +# module's ``_strip_mistral_closed_calls`` helper invoked by +# ``_strip_tool_xml`` below. _TOOL_XML_RE = _re.compile( "|".join( [ r".*?", r".*?", r"<\|tool_call>.*?", - r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?", - r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}", - r"<\|python_tag\|>[^\n<]*", + # ``<|python_tag|>...`` runs to end of line. ``[^\n<]`` was + # used to stop at any ``<`` but that leaked the tail of + # any tool call whose argument contained a literal ``<`` + # (queries, code snippets) into the user-visible stream. + r"<\|python_tag\|>[^\n]*", ] ), _re.DOTALL, ) + + +def _strip_tool_xml(text: str) -> str: + """Strip closed-pair tool-call markup with balanced brace scanning. 
+ + Combines the shared parser's Mistral helper (handles nested JSON + correctly) with ``_TOOL_XML_RE`` for the remaining flat patterns. + """ + from studio.backend.core.inference.tool_call_parser import ( + _strip_mistral_closed_calls, + ) + return _TOOL_XML_RE.sub("", _strip_mistral_closed_calls(text)) logger = get_logger(__name__) @@ -2418,7 +2437,7 @@ async def audio_input_stream(): if _msg.get("role") == "assistant" and isinstance( _msg.get("content"), str ): - _msg["content"] = _TOOL_XML_RE.sub("", _msg["content"]).strip() + _msg["content"] = _strip_tool_xml(_msg["content"]).strip() def gguf_generate_with_tools(): return llama_backend.generate_chat_completion_with_tools( @@ -2519,7 +2538,7 @@ async def gguf_tool_stream(): # the last sanitized snapshot so cross-chunk XML # tags are handled correctly. raw_cumulative = event.get("text", "") - clean_cumulative = _TOOL_XML_RE.sub("", raw_cumulative) + clean_cumulative = _strip_tool_xml(raw_cumulative) new_text = clean_cumulative[len(prev_text) :] prev_text = clean_cumulative if not new_text: @@ -2904,7 +2923,7 @@ async def gguf_stream_chunks(): _sf_chat_messages.append( { **_msg, - "content": _TOOL_XML_RE.sub("", _msg["content"]).strip(), + "content": _strip_tool_xml(_msg["content"]).strip(), } ) else: @@ -2991,7 +3010,7 @@ async def sf_tool_stream(): # Diff cumulative cleaned text against last snapshot. 
raw_cumulative = event.get("text", "") - clean_cumulative = _TOOL_XML_RE.sub("", raw_cumulative) + clean_cumulative = _strip_tool_xml(raw_cumulative) new_text = clean_cumulative[len(prev_text) :] prev_text = clean_cumulative if not new_text: @@ -3063,7 +3082,7 @@ def _drain_to_text(): if cancel_event.is_set(): break if event.get("type") == "content": - full_text = _TOOL_XML_RE.sub("", event.get("text", "")) + full_text = _strip_tool_xml(event.get("text", "")) return full_text content_text = await asyncio.to_thread(_drain_to_text) @@ -4490,7 +4509,7 @@ async def anthropic_messages( # Strip stale tool-call XML from conversation for _msg in openai_messages: if _msg.get("role") == "assistant" and isinstance(_msg.get("content"), str): - _msg["content"] = _TOOL_XML_RE.sub("", _msg["content"]).strip() + _msg["content"] = _strip_tool_xml(_msg["content"]).strip() def _run_tool_gen(): return llama_backend.generate_chat_completion_with_tools( @@ -4582,7 +4601,7 @@ async def _stream(): # Strip leaked tool-call XML from content events if event.get("type") == "content": event = dict(event) - event["text"] = _TOOL_XML_RE.sub("", event["text"]) + event["text"] = _strip_tool_xml(event["text"]) for line in emitter.feed(event): yield line except Exception as e: @@ -4673,7 +4692,7 @@ async def _anthropic_tool_non_streaming(run_gen, message_id, model_name): etype = event.get("type", "") if etype == "content": # Strip leaked tool-call XML - clean = _TOOL_XML_RE.sub("", event["text"]) + clean = _strip_tool_xml(event["text"]) new = clean[len(prev_text) :] prev_text = clean if new: @@ -5010,7 +5029,7 @@ async def _anthropic_passthrough_non_streaming( content_blocks = [] text = message.get("content") or "" if text: - text = _TOOL_XML_RE.sub("", text).strip() + text = _strip_tool_xml(text).strip() if text: content_blocks.append(AnthropicResponseTextBlock(text = text)) From 7d8e725aaabd1d392d29e755aadb784e26db3ba6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 14:56:43 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/routes/inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 8b0d854412..288a439512 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -473,7 +473,10 @@ def _strip_tool_xml(text: str) -> str: from studio.backend.core.inference.tool_call_parser import ( _strip_mistral_closed_calls, ) + return _TOOL_XML_RE.sub("", _strip_mistral_closed_calls(text)) + + logger = get_logger(__name__) From 7ef4b11590f408311ab5d93e65c10081fa4602d0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 22 May 2026 08:18:32 +0000 Subject: [PATCH 06/17] studio/routes: make python_tag strip multi-line aware Earlier revisions of _TOOL_XML_RE in studio.backend.routes.inference oscillated between two bug shapes: 5615 r"<\|python_tag\|>[^\n<]*" -- stopped at any literal "<" so code='if x < 10: pass' leaked '< 10: pass)' to the user. 5620.1 r"<\|python_tag\|>[^\n]*" -- single-line only; the second line of python.call(code="a\nb") leaked. The full parser (_parse_llama3_python_tag) already handles both via balanced-brace scanning, so the parsing path was fine; the LEAK was in the streaming strip path that runs on every cumulative emission while content is still arriving. Switch to r"<\|python_tag\|>(?:[^<]|<(?!\|))*" so the strip consumes: * any character that is not a "<" (newlines, JSON, code, ...), * a "<" only when it is NOT followed by "|" (i.e. NOT a Llama-3 sentinel start like <|eot_id|>, <|eom_id|>, <|begin_of_text|>). 
This means: * code='if x < 10' stays inside the strip (5615 fix preserved), * multi-line code stays inside the strip (5620 round 2), * the strip terminates at the next Llama-3 sentinel so trailing assistant content survives. Tests: TestRoutesPythonTagStrip (8 cases) pytest test_safetensors_tool_loop.py test_safetensors_capability_advertise.py -> 118 passed in 1.81s (was 110). --- studio/backend/routes/inference.py | 15 ++- .../tests/test_safetensors_tool_loop.py | 92 +++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 288a439512..4d713bad96 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -453,11 +453,16 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: r".*?", r".*?", r"<\|tool_call>.*?", - # ``<|python_tag|>...`` runs to end of line. ``[^\n<]`` was - # used to stop at any ``<`` but that leaked the tail of - # any tool call whose argument contained a literal ``<`` - # (queries, code snippets) into the user-visible stream. - r"<\|python_tag\|>[^\n]*", + # ``<|python_tag|>...`` runs until the next Llama-3 ``<|`` + # sentinel (``<|eot_id|>``, ``<|eom_id|>``, etc.) or end of + # text. Earlier revisions used ``[^\n<]*`` (leaked tools + # whose args contained a literal ``<`` like + # ``code="if x < 10"``) and then ``[^\n]*`` (single-line + # only; multi-line ``python.call(code="line1\nline2")`` + # leaked the second line). ``(?:[^<]|<(?!\|))*`` consumes + # any character that is not a sentinel start, so newlines, + # bare ``<``, and embedded JSON all stay inside the strip. 
+ r"<\|python_tag\|>(?:[^<]|<(?!\|))*", ] ), _re.DOTALL, diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index 842ce54e49..d7e2a13e38 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -1388,5 +1388,97 @@ def test_empty_or_none_returns_false(self): assert is_gpt_oss_model_name(None) is False +# ──────────────────────────────────────────────────────────────────── +# Routes-level python_tag strip (multi-line; stop on next sentinel) +# ──────────────────────────────────────────────────────────────────── + + +class TestRoutesPythonTagStrip: + """Earlier revisions of ``_TOOL_XML_RE`` in + ``studio.backend.routes.inference`` used either ``[^\\n<]*`` (5615 -- + leaked the tail of any tool call whose argument contained a literal + ``<`` like ``code="if x < 10"``) or ``[^\\n]*`` (5620 round one -- + single-line only, so the second line of + ``python.call(code="line1\\nline2")`` leaked). The current pattern + ``(?:[^<]|<(?!\\|))*`` consumes any character that is not a Llama-3 + ``<|`` sentinel start, so multi-line code, embedded JSON, and bare + ``<`` characters in code all stay inside the strip. + + The fully resolved strip is also exposed via + ``strip_tool_markup(text, final=True)`` in the parser; the + streaming path's routes-level strip is the regression-prone one + because it runs on every cumulative emission while content is + still arriving. + """ + + def _strip(self, text: str) -> str: + # Import inside the test so a routes-module import error does + # not blow up the entire test file at collection time. + from routes.inference import _strip_tool_xml + + return _strip_tool_xml(text) + + def test_single_line_python_tag_stripped(self): + # Floor: the original 5620 single-line behaviour still works. 
+ text = '<|python_tag|>brave_search.call(query="weather")' + assert self._strip(text) == "" + + def test_python_tag_with_less_than_in_code(self): + # 5615 regression: literal ``<`` inside code must NOT terminate + # the strip early. + text = '<|python_tag|>python.call(code="if x < 10: pass")' + assert self._strip(text) == "" + + def test_python_tag_multiline_code_stripped(self): + # 5620 round-1 regression: multi-line code's second line leaked. + text = '<|python_tag|>python.call(code="line1\nline2\nline3")' + assert self._strip(text) == "" + + def test_python_tag_multiline_with_less_than(self): + # Combined: multi-line code AND literal ``<`` in code. + text = ( + '<|python_tag|>python.call(code="for i in range(10):\n' + ' if i < 5:\n' + ' print(i)")' + ) + assert self._strip(text) == "" + + def test_python_tag_stops_at_eom_sentinel(self): + # Strip stops at the next Llama-3 ``<|`` sentinel so any + # trailing assistant content survives. + text = ( + '<|python_tag|>python.call(code="multi\nline")' + "<|eom_id|>final answer text" + ) + assert self._strip(text) == "<|eom_id|>final answer text" + + def test_python_tag_stops_at_eot_sentinel(self): + text = ( + "<|python_tag|>brave_search.call(query=\"x\")" + "<|eot_id|>after" + ) + assert self._strip(text) == "<|eot_id|>after" + + def test_python_tag_json_form_multiline_stripped(self): + # The JSON form of python_tag with newlines inside string args. + text = ( + '<|python_tag|>{"name":"python",' + '"parameters":{"code":"a = 1\nb = 2\nprint(a+b)"}}' + ) + assert self._strip(text) == "" + + def test_python_tag_with_eom_then_trailing_python_tag(self): + # Two python_tag emissions back-to-back across a sentinel: both + # should strip independently. + text = ( + '<|python_tag|>brave_search.call(query="a")' + "<|eom_id|>" + '<|python_tag|>python.call(code="x=1")' + ) + # ``<|eom_id|>`` between the two strips remains; both + # python_tag blocks are fully consumed. 
+ assert self._strip(text) == "<|eom_id|>" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 0b4fa12a4e7f373db213daf43a5a4f72246aa8b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 08:18:44 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/tests/test_safetensors_tool_loop.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index d7e2a13e38..729e9ce766 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -1438,7 +1438,7 @@ def test_python_tag_multiline_with_less_than(self): # Combined: multi-line code AND literal ``<`` in code. text = ( '<|python_tag|>python.call(code="for i in range(10):\n' - ' if i < 5:\n' + " if i < 5:\n" ' print(i)")' ) assert self._strip(text) == "" @@ -1453,10 +1453,7 @@ def test_python_tag_stops_at_eom_sentinel(self): assert self._strip(text) == "<|eom_id|>final answer text" def test_python_tag_stops_at_eot_sentinel(self): - text = ( - "<|python_tag|>brave_search.call(query=\"x\")" - "<|eot_id|>after" - ) + text = '<|python_tag|>brave_search.call(query="x")' "<|eot_id|>after" assert self._strip(text) == "<|eot_id|>after" def test_python_tag_json_form_multiline_stripped(self): From 7f9177a183acaac36895a1f0cb897668540f7677 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 12:21:22 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/tests/test_cpu_threads.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/studio/backend/tests/test_cpu_threads.py 
b/studio/backend/tests/test_cpu_threads.py index 0dbcbdb74b..1224941622 100644 --- a/studio/backend/tests/test_cpu_threads.py +++ b/studio/backend/tests/test_cpu_threads.py @@ -63,16 +63,18 @@ def test_cpu_thread_cap_is_opt_in(raw): # Anything that is not a positive integer raises a clear ValueError. -@pytest.mark.parametrize("raw", ["zero", "0", "-3", "1.5", "abc", "8a", "0x4", "1e3", "4 0"]) +@pytest.mark.parametrize( + "raw", ["zero", "0", "-3", "1.5", "abc", "8a", "0x4", "1e3", "4 0"] +) def test_cpu_thread_cap_requires_positive_integer(raw): - with pytest.raises(ValueError, match="must be a positive integer"): + with pytest.raises(ValueError, match = "must be a positive integer"): configure_cpu_threads({"UNSLOTH_CPU_THREADS": raw}) # env=None path uses real os.environ (production call from run.py / main.py). def test_cpu_thread_cap_uses_os_environ_when_env_is_none(monkeypatch): for variable in (*_THREAD_POOL_ENV_VARS, "UNSLOTH_CPU_THREADS"): - monkeypatch.delenv(variable, raising=False) + monkeypatch.delenv(variable, raising = False) monkeypatch.setenv("UNSLOTH_CPU_THREADS", "3") configure_cpu_threads() @@ -84,7 +86,7 @@ def test_cpu_thread_cap_uses_os_environ_when_env_is_none(monkeypatch): # Calling twice must not flip any seeded value. 
def test_cpu_thread_cap_idempotent(monkeypatch): for variable in (*_THREAD_POOL_ENV_VARS, "UNSLOTH_CPU_THREADS"): - monkeypatch.delenv(variable, raising=False) + monkeypatch.delenv(variable, raising = False) monkeypatch.setenv("UNSLOTH_CPU_THREADS", "5") configure_cpu_threads() @@ -138,9 +140,9 @@ def test_invalid_cpu_thread_cap_exits_without_traceback(entry_point): result = subprocess.run( [sys.executable, str(entry_point)], - env=env, - capture_output=True, - text=True, + env = env, + capture_output = True, + text = True, ) assert result.returncode == 1 From 1fe632d2f5aa87417cc18cdec042829e4ab46f5e Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 27 May 2026 12:37:05 +0000 Subject: [PATCH 09/17] studio: tighten verbose comments in tool-call parser sections Comments were narrating what the code already says. Cut historical "earlier revisions used X, then Y" narratives down to one-line WHY notes where the footgun still matters (canonical heal-key parity, balanced-brace vs non-greedy regex, ``(?:[^<]|<(?!\|))*`` over ``[^\n<]*``/``[^\n]*``). Drop section-header banners. No behaviour change. Re-ran: pytest studio/backend/tests/test_safetensors_tool_loop.py \ studio/backend/tests/test_safetensors_capability_advertise.py -q -> 118 passed. Regression replay (parser + _coerce_arguments on the 5 #5615 inputs) -> 21/21. 
--- studio/backend/core/inference/llama_cpp.py | 26 +- .../core/inference/safetensors_agentic.py | 44 ++-- .../core/inference/tool_call_parser.py | 239 ++++++------------ studio/backend/routes/inference.py | 64 ++--- 4 files changed, 118 insertions(+), 255 deletions(-) diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 44e6fa5bd9..f876a86c6e 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -40,11 +40,9 @@ parse_tool_calls_from_text, ) -# Stripping and signal-marker constants come from the multi-format -# parser so Llama-3 / Mistral / Gemma 4 emissions are also detected -# in the BUFFERING state machine and stripped from the assistant -# stream. Pre-PR-5615 we used the legacy two-format helper which -# only knew / / str: return text return _shared_strip_tool_markup(text, final = final) - # Markers the BUFFERING state machine watches for. Empty when - # auto-heal is off so the buffer never speculatively holds - # content. Covers all five emission formats the shared parser - # understands: Qwen , Qwen3.5 , Mistral [TOOL_CALLS], Gemma 4 <|tool_call>. + # Markers the BUFFERING state machine watches for; covers Qwen, + # Qwen3.5, Llama-3, Mistral, and Gemma 4. Empty when auto-heal + # is off so the buffer never speculatively holds content. _TOOL_XML_SIGNALS = _SHARED_TOOL_XML_SIGNALS if auto_heal_tool_calls else () _MAX_BUFFER_CHARS = 32 @@ -5020,12 +5016,10 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: arguments = json.loads(raw_args) except (json.JSONDecodeError, ValueError): if auto_heal_tool_calls: - # Per-tool canonical heal key so a bare - # string emission still runs the right - # tool: ``code`` for python, ``command`` - # for terminal, ``query`` for everything - # else (e.g. web_search). Mirrors - # safetensors_agentic._CANONICAL_HEAL_ARG. 
+ # Canonical per-tool heal key (must match + # safetensors_agentic._CANONICAL_HEAL_ARG) + # so bare-string emissions still run the + # intended tool. _heal_key = { "python": "code", "terminal": "command", diff --git a/studio/backend/core/inference/safetensors_agentic.py b/studio/backend/core/inference/safetensors_agentic.py index e1b9c0bc8c..36d02d59ed 100644 --- a/studio/backend/core/inference/safetensors_agentic.py +++ b/studio/backend/core/inference/safetensors_agentic.py @@ -43,21 +43,15 @@ # Buffer cap while waiting to disambiguate a possible tool-call prefix. _MAX_BUFFER_CHARS = 32 -# Forward-looking intent signals that indicate the model is describing -# what it *will* do rather than giving a final answer. Mirrors the GGUF -# path so safetensors / MLX nudge the model to act when it stalls on -# planning instead of calling a tool. Excludes "I can", "I should", -# "I want to", "let's" which appear in direct answers / explanations. +# Forward-looking intent ("I'll...", "First, ...", "Step 1:") that +# means the model is planning rather than answering. Used to nudge it +# to call a tool. Excludes "I can / I should / I want / let's" because +# those also appear in direct answers and explanations. Mirrors GGUF. _INTENT_SIGNAL = re.compile( r"(?i)(" - # Direct intent: "I'll", "I will", "Let me", "I am going to". r"\b(i['’](ll|m going to|m gonna)|i am (going to|gonna)|i will|i shall|let me|allow me)\b" - r"|" - # Step / plan framing: "First", "Step 1:", "Here's my plan". - r"\b(?:first\b|step \d+:?|here['’]?s (?:my |the |a )?(?:plan|approach))" - r"|" - # "Now I" / "Next I" patterns. - r"\b(?:now i|next i)\b" + r"|\b(?:first\b|step \d+:?|here['’]?s (?:my |the |a )?(?:plan|approach))" + r"|\b(?:now i|next i)\b" r")" ) _MAX_REPROMPTS = 3 @@ -178,8 +172,7 @@ def run_safetensors_tool_loop( _state_streaming = 1 _state_draining = 2 - # Reserve extra iterations for re-prompts so they do not eat the - # caller's tool-call budget. Mirrors GGUF (_MAX_REPROMPTS slots). 
+ # Reserve re-prompt slots so they don't eat the caller's tool budget. _extra_iters = _MAX_REPROMPTS if max_tool_iterations > 0 else 0 for iteration in range(max_tool_iterations + _extra_iters + 1): if cancel_event is not None and cancel_event.is_set(): @@ -271,11 +264,10 @@ def run_safetensors_tool_loop( if stripped and has_tool_signal(stripped): detect_state = _state_draining else: - # Emit the buffered content, then fall through to the - # STREAMING block so the intent re-prompt + safety-net - # parser still get a chance. Without this, a short - # intent emission like "Let me search." that never - # exits BUFFERING would silently terminate the loop. + # Drain the buffer and fall through to STREAMING so the + # intent re-prompt + safety-net parser can still fire on + # short emissions like "Let me search." that never exit + # BUFFERING (would otherwise silently end the loop). if content_buffer: cumulative_display += content_buffer cleaned = strip_tool_markup(cumulative_display, final = True) @@ -293,11 +285,9 @@ def run_safetensors_tool_loop( id_offset = next_call_id, ) if not safety_tc: - # Re-prompt on plan-without-action: if the model - # described what it intends to do but did not call a - # tool, nudge it to act. Mirrors the GGUF path. Only - # fires on responses that signal intent / planning -- - # direct answers like "4" or "Hello!" don't trigger. + # Re-prompt only when the model planned without acting + # (intent signal present); direct answers like "4" or + # "Hello!" never trigger. Mirrors GGUF. _stripped = content_accum.strip() if ( tools @@ -440,10 +430,8 @@ def run_safetensors_tool_loop( # Clear the status badge before the next turn. yield {"type": "status", "text": ""} - # Budget tracked against the caller-requested cap, ignoring - # the re-prompt slots so a stalling model still gets a final - # answer attempt. Tool-call iterations executed = iteration - - # reprompt_count. 
+ # Track against the caller-requested cap, excluding re-prompt + # slots so a stalling model still gets a final-answer attempt. _tool_iters_done = iteration + 1 - reprompt_count if _tool_iters_done >= max_tool_iterations and not final_attempt_done: # Budget exhausted; nudge a final plain answer. diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 270c05fe8d..2d31061765 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -26,13 +26,8 @@ from typing import Any -# ── Streaming-buffer signal markers ───────────────────────────────── - - -# Prefixes the safetensors / MLX streaming buffer watches for to gate -# in-progress text. When ANY of these appear in the cumulative text, -# the state machine switches from STREAMING to DRAINING so we don't -# leak partial markup to the user before we can parse it. +# Markers that flip the streaming buffer from STREAMING to DRAINING so +# partial markup never leaks before the parser sees it. TOOL_XML_SIGNALS = ( "", ".*?", re.DOTALL), re.compile(r".*?", re.DOTALL), @@ -62,9 +53,6 @@ ] -# ── Nudges + error-result prefixes ────────────────────────────────── - - TOOL_ERROR_PREFIXES = ( "Error", "Search failed", @@ -95,19 +83,16 @@ ) -# ── Format-specific regexes ───────────────────────────────────────── - - -# Qwen / Hermes {json} +# Qwen / Hermes ``{json}``. _TC_JSON_START_RE = re.compile(r"\s*\{") -# Qwen3.5 / Hermes XML form v +# Qwen3.5 / Hermes XML ``v``. _TC_FUNC_START_RE = re.compile(r"\s*") _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") _TC_PARAM_START_RE = re.compile(r"\s*") _TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") -# Llama-3 <|python_tag|>NAME.call(...) +# Llama-3 ``<|python_tag|>NAME.call(...)``. 
_LLAMA3_PYTHON_TAG = "<|python_tag|>" _LLAMA3_PY_CALL_RE = re.compile( r"<\|python_tag\|>\s*([\w\.\-]+)\s*\.\s*call\s*\(", @@ -117,29 +102,23 @@ re.VERBOSE, ) -# Mistral [TOOL_CALLS] trigger. v11+ chains multiple triggers, each -# followed by a bare name then either ``{json}`` (Magistral) or -# ``[ARGS]{json}`` (Ministral / Mistral Large 3). +# Mistral ``[TOOL_CALLS]`` trigger. v11+ chains them, each followed by +# a bare name plus ``{json}`` (Magistral) or ``[ARGS]{json}`` (Ministral +# / Large 3). _MISTRAL_TRIGGER = "[TOOL_CALLS]" _MISTRAL_ARGS_MARKER = "[ARGS]" _MISTRAL_V11_NAME_RE = re.compile(r"\s*([\w\.\-]+)\s*") -# Gemma 4 <|tool_call>call:NAME{...}. ``<|"|>`` wraps strings. +# Gemma 4: ``<|tool_call>call:NAME{...}``, ``<|"|>`` wraps strings. _GEMMA_TC_RE = re.compile(r"<\|tool_call>\s*call\s*:\s*([\w\.\-]+)\s*\{") _GEMMA_STR_BEGIN = '<|"|>' _GEMMA_STR_END = '<|"|>' _GEMMA_TC_END = "" -# ── Public API ────────────────────────────────────────────────────── - - def _balanced_bracket_end(text: str, start: int) -> int | None: - """Index of the matching ``]`` for the ``[`` at ``text[start]``. - - Skips brackets inside JSON string literals. Returns ``None`` if no - matching close is found. - """ + """Index of `]` matching `[` at ``text[start]``; ignores brackets + in JSON strings. None if unmatched.""" if start >= len(text) or text[start] != "[": return None depth = 0 @@ -169,18 +148,11 @@ def _balanced_bracket_end(text: str, start: int) -> int | None: def _strip_mistral_closed_calls(text: str) -> str: - """Strip ``[TOOL_CALLS]`` blocks with balanced brace/bracket scanning. - - Handles three Mistral emission shapes: + """Strip cleanly-closed ``[TOOL_CALLS]`` blocks (array, ``name{json}``, + or ``name[ARGS]{json}``) via balanced brace/bracket scanning. 
- - ``[TOOL_CALLS] [ {...}, {...} ]`` (v0.3 / Nemo / Small) - - ``[TOOL_CALLS] name { json }`` (v11+ / Magistral) - - ``[TOOL_CALLS] name [ARGS] { json }`` (Ministral / Large 3) - - The regex ``\\{.*?\\}`` truncates at the first ``}``, losing nested - JSON, so this walks balanced braces/brackets instead. Only matches - runs that close cleanly; unclosed trailing markup is left in place - for ``final=True`` cleanup via ``_TOOL_ALL_PATS``. + A non-greedy ``\\{.*?\\}`` would truncate at the first ``}`` and lose + nested JSON. Unclosed runs are left for ``final=True`` cleanup. """ n = len(text) out = [] @@ -192,24 +164,21 @@ def _strip_mistral_closed_calls(text: str) -> str: break out.append(text[cursor:idx]) body_start = idx + len(_MISTRAL_TRIGGER) - # Skip whitespace + optional name + optional ``[ARGS]``. i = body_start while i < n and text[i] in " \t\n\r": i += 1 - # Array shape: [TOOL_CALLS] [ ... ] + # Array shape: ``[TOOL_CALLS] [...]``. if i < n and text[i] == "[": end = _balanced_bracket_end(text, i) if end is None: - # Truncated mid-array; leave the trigger and everything - # after it in place so caller can buffer / final-strip. + # Truncated; let caller buffer / final-strip. out.append(text[idx:]) break cursor = end + 1 - # Optional trailing (Mistral EOS). if text.startswith("", cursor): cursor += len("") continue - # Named shape: [TOOL_CALLS] name [ARGS]? { json } + # Named shape: ``[TOOL_CALLS] name [ARGS]? { json }``. name_match = _MISTRAL_V11_NAME_RE.match(text, i) if not name_match: out.append(text[idx:body_start]) @@ -235,15 +204,9 @@ def _strip_mistral_closed_calls(text: str) -> str: def strip_tool_markup(text: str, *, final: bool = False) -> str: - """Strip tool-call markup from streamed text. - - ``final=False`` only removes closed pairs so in-progress markup - stays buffered. ``final=True`` also removes trailing unclosed runs - and trims the result. 
- """ - # Mistral patterns need balanced brace/bracket scanning -- a - # non-greedy regex would truncate at the first ``}`` inside a - # nested JSON object and leak the rest into user-visible text. + """Strip tool-call markup. ``final=False`` keeps in-progress + markup buffered; ``final=True`` also drops trailing unclosed runs + and trims.""" text = _strip_mistral_closed_calls(text) pats = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS for pat in pats: @@ -252,53 +215,29 @@ def strip_tool_markup(text: str, *, final: bool = False) -> str: def has_tool_signal(text: str) -> bool: - """True if ``text`` contains any known tool-call signal.""" return any(s in text for s in TOOL_XML_SIGNALS) def parse_tool_calls_from_text(content: str, *, id_offset: int = 0) -> list[dict]: - """Parse OpenAI-format ``tool_calls`` from model text. + """Return OpenAI-format tool calls. Tries each format and returns + as soon as one matches so we never double-count.""" + for parser in ( + _parse_tool_call_json, # Qwen / Hermes + _parse_function_xml, # Qwen3.5 / Hermes XML + _parse_llama3_python_tag, # Llama-3 + _parse_mistral_tool_calls, # Mistral + _parse_gemma_tool_calls, # Gemma 4 + ): + calls = parser(content, id_offset = id_offset) + if calls: + return calls - Returns ``[{"id", "type", "function": {"name", "arguments"}}]`` - where ``arguments`` is always a JSON string. Tries each known - emission format in turn; returns as soon as one yields calls so - we never double-count. - """ - # Qwen / Hermes {json} - calls = _parse_tool_call_json(content, id_offset = id_offset) - if calls: - return calls - - # Qwen3.5 / Hermes v - calls = _parse_function_xml(content, id_offset = id_offset) - if calls: - return calls - - # Llama-3 <|python_tag|>... - calls = _parse_llama3_python_tag(content, id_offset = id_offset) - if calls: - return calls - - # Mistral [TOOL_CALLS]... - calls = _parse_mistral_tool_calls(content, id_offset = id_offset) - if calls: - return calls - - # Gemma 4 <|tool_call>... 
- calls = _parse_gemma_tool_calls(content, id_offset = id_offset) - if calls: - return calls - - # Llama-3.2 bare JSON ``{"name":..., "parameters":...}`` (no tag). - # Strict: only fires when stripped content STARTS with ``{`` and - # parses as ``{name: str, parameters|arguments: dict}``. Keeps - # plain assistant prose unaffected. + # Llama-3.2 bare ``{"name":..., "parameters":...}``. Strict: only + # fires on content that starts with ``{`` and parses as the right + # shape, so plain prose stays untouched. return _parse_llama3_bare_json(content, id_offset = id_offset) -# ── Per-format parsers ────────────────────────────────────────────── - - def _parse_tool_call_json(content: str, *, id_offset: int) -> list[dict]: out: list[dict] = [] for m in _TC_JSON_START_RE.finditer(content): @@ -375,17 +314,15 @@ def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: - """Llama-3 emission shapes: - <|python_tag|>NAME.call(arg="v", ...) (built-in tools) - <|python_tag|>{"name":"NAME", "parameters":{...}} (custom tools) - <|python_tag|>{"name":...}; {"name":...} (multi-call, ``; `` sep) - Accepts both ``parameters`` and ``arguments`` keys per Llama 3.1/3.2. + """Parse the four Llama-3 emissions: ``<|python_tag|>NAME.call(...)`` + (built-in), ``<|python_tag|>{"name":..., "parameters":...}`` (custom), + multi-call via ``; `` separators, ``parameters`` or ``arguments`` key. """ out: list[dict] = [] if _LLAMA3_PYTHON_TAG not in content: return out - # 1. NAME.call(...) built-in form. + # 1. ``NAME.call(...)`` built-in form. 
for m in _LLAMA3_PY_CALL_RE.finditer(content): name = m.group(1) i = m.end() @@ -416,10 +353,10 @@ def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: for kv in _LLAMA3_KV_RE.finditer(body): k = kv.group(1) if kv.group(2) is not None: - # json.loads on a wrapped JSON string handles - # \n / \t / \uXXXX escapes correctly while preserving - # literal UTF-8 bytes (emoji, CJK, etc.) that the older - # ``bytes.decode("unicode_escape")`` path mangled. + # ``json.loads`` on a quoted string handles \n/\t/\uXXXX + # escapes correctly AND keeps literal UTF-8 bytes (emoji + # / CJK) intact -- the older ``bytes.decode('unicode_escape')`` + # path mangled non-ASCII. try: args[k] = json.loads('"' + kv.group(2) + '"') except (json.JSONDecodeError, ValueError): @@ -437,24 +374,19 @@ def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: } ) - # 2. <|python_tag|>{"name":..., "parameters":...} JSON form. Use a - # streaming JSON decoder (raw_decode) so we can peel multiple - # objects out of the same emission (separated by ``; `` per - # Llama 3 template). + # 2. ``<|python_tag|>{"name":..., "parameters":...}``. ``raw_decode`` + # peels multiple ``; ``-separated objects from one emission. if not out: decoder = json.JSONDecoder() idx = content.find(_LLAMA3_PYTHON_TAG) while idx >= 0: search_from = idx + len(_LLAMA3_PYTHON_TAG) - # Scan all `{` from this trigger; raw_decode jumps the - # cursor past each parsed object, but if a `{` falls - # inside an already-decoded object we skip it. cursor = search_from while cursor < len(content): brace = content.find("{", cursor) if brace < 0: break - # Stop if we've hit the next <|python_tag|>. + # Stop at the next ``<|python_tag|>``. 
next_tag = content.find(_LLAMA3_PYTHON_TAG, search_from, brace) if next_tag >= 0: break @@ -492,25 +424,13 @@ def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: - """Llama-3.2 ``custom_tools`` shape -- bare JSON ``{"name":..., - "parameters":{...}}`` emitted directly, no ``<|python_tag|>``. - - Strict to avoid firing on tool-message echoes: - - * Content must start with ``{`` once whitespace and any leading - ``<|begin_of_text|>`` / ``<|eot_id|>`` etc. sentinels are stripped. - * Object must have ``name`` (non-empty str) plus a dict in - ``parameters`` or ``arguments``. - * Loops via ``raw_decode`` to peel multiple ``;``-separated calls. - """ + """Llama-3.2 ``custom_tools``: bare ``{"name":..., "parameters":{...}}`` + without ``<|python_tag|>``. Strict (must start with ``{`` after sentinel + strip; ``name`` non-empty; ``parameters`` or ``arguments`` is a dict) so + plain prose and tool-message echoes don't trigger.""" out: list[dict] = [] stripped = content.lstrip() - # Strip leading Llama-3 sentinel tokens that sometimes precede the - # JSON (``<|eot_id|>`` from the prior turn, ``<|start_header_id|>``, - # ``<|begin_of_text|>``). Loop until no sentinel matches: the - # tokens can appear in any order and chain, so a single pass would - # leave later sentinels behind once an earlier one consumed its - # prefix. + # Sentinels can chain in any order, so loop until none match. _sentinels = ( "<|begin_of_text|>", "<|eot_id|>", @@ -535,7 +455,7 @@ def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: cursor = 0 n = len(stripped) while cursor < n: - # Skip whitespace and Llama 3 inter-call separator ``;``. + # Skip whitespace and the Llama-3 ``;`` inter-call separator. 
while cursor < n and stripped[cursor] in " \t\n\r;": cursor += 1 if cursor >= n or stripped[cursor] != "{": @@ -573,20 +493,16 @@ def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: - """Mistral emissions covered: - Pre-v11 array: ``[TOOL_CALLS] [{"name":..., "arguments":...}, ...]`` - Pre-v11 single: ``[TOOL_CALLS]{"name":..., "arguments":...}`` - v11+ single: ``[TOOL_CALLS]name{json_args}`` - v11+ parallel: ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}`` - v11+ w/ [ARGS]: ``[TOOL_CALLS]name[ARGS]{json_args}`` (Ministral / Large 3) - """ + """Parse all Mistral emissions: pre-v11 ``[TOOL_CALLS][...]`` / + ``[TOOL_CALLS]{...}`` and v11+ ``[TOOL_CALLS]name{json}`` / + ``[TOOL_CALLS]name[ARGS]{json}`` (parallel-friendly).""" out: list[dict] = [] idx = content.find(_MISTRAL_TRIGGER) if idx < 0: return out - # Decide whether the FIRST occurrence is array / single-object - # (pre-v11) or v11+ bare-name. Skip whitespace, peek at next char. + # Disambiguate the first occurrence: array (pre-v11), single object + # (pre-v11), or bare-name (v11+). j = idx + len(_MISTRAL_TRIGGER) k = j while k < len(content) and content[k] in " \t\n\r": @@ -598,10 +514,8 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: return _parse_mistral_array(content, k, id_offset) if content[k] == "{": - # Could be pre-v11 single object ``{"name": ...}`` or a JSON - # blob immediately following the trigger (rare). Try parsing - # as an object that exposes ``name``; if not, fall through to - # v11+ handling so we don't drop emission silently. + # Pre-v11 single ``{"name":...}``; fall through if it doesn't + # carry a ``name`` so v11+ handling still gets a chance. 
end = _balanced_brace_end(content, k) if end is not None: try: @@ -612,8 +526,8 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: except (json.JSONDecodeError, ValueError): pass - # v11+ path: walk every ``[TOOL_CALLS]`` and parse ``name{json}`` - # or ``name[ARGS]{json}`` after each trigger. + # v11+: walk every ``[TOOL_CALLS]``, parsing ``name{json}`` or + # ``name[ARGS]{json}`` after each trigger. pos = idx while pos >= 0: cur = pos + len(_MISTRAL_TRIGGER) @@ -623,7 +537,6 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: continue name = nm.group(1) after_name = nm.end() - # Optional ``[ARGS]`` marker. if content.startswith(_MISTRAL_ARGS_MARKER, after_name): after_name += len(_MISTRAL_ARGS_MARKER) while after_name < len(content) and content[after_name] in " \t\n\r": @@ -657,7 +570,7 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: def _parse_mistral_array(content: str, start: int, id_offset: int) -> list[dict]: - """Parse pre-v11 ``[TOOL_CALLS] [{...}, ...]`` JSON array form.""" + """Pre-v11 ``[TOOL_CALLS] [{...}, ...]`` array form.""" out: list[dict] = [] j = start depth = 0 @@ -694,7 +607,7 @@ def _parse_mistral_array(content: str, start: int, id_offset: int) -> list[dict] except (json.JSONDecodeError, ValueError): pass - # Healing path: walk objects manually for unclosed array. + # Healing path for unclosed arrays: walk objects by hand. 
for m in re.finditer(r"\{", body): end = _balanced_brace_end(body, m.start()) if end is None: @@ -729,7 +642,7 @@ def _consume_mistral_call(obj_text: str, out: list[dict], id_offset: int) -> Non def _parse_gemma_tool_calls(content: str, *, id_offset: int) -> list[dict]: - """Gemma 4: <|tool_call>call:NAME{k:<|"|>v<|"|>, ...}.""" + """Gemma 4: ``<|tool_call>call:NAME{k:<|"|>v<|"|>, ...}``.""" out: list[dict] = [] for m in _GEMMA_TC_RE.finditer(content): name = m.group(1) @@ -754,12 +667,9 @@ def _parse_gemma_tool_calls(content: str, *, id_offset: int) -> list[dict]: return out -# ── Brace-balancing helpers ───────────────────────────────────────── - - def _balanced_brace_end(text: str, brace_pos: int) -> int | None: - """Index of `}` matching `{` at ``brace_pos`` -- ignores `{` `}` - inside JSON strings. Returns None if unmatched.""" + """Index of `}` matching `{` at ``brace_pos``; ignores braces inside + JSON strings. None if unmatched.""" if brace_pos >= len(text) or text[brace_pos] != "{": return None depth = 0 @@ -789,8 +699,8 @@ def _balanced_brace_end(text: str, brace_pos: int) -> int | None: def _gemma_balanced_brace_end(text: str, brace_pos: int, hard_stop: int) -> int | None: - """Same as ``_balanced_brace_end`` but respects Gemma ``<|"|>`` - string runs and matches `{`/`[` symmetrically.""" + """Like ``_balanced_brace_end`` but skips ``<|"|>`` strings and + matches `{`/`[` symmetrically.""" if brace_pos >= len(text) or text[brace_pos] != "{": return None depth = 0 @@ -814,8 +724,7 @@ def _gemma_balanced_brace_end(text: str, brace_pos: int, hard_stop: int) -> int def _gemma_parse_value(text: str, i: int): - """Parse one Gemma argument value starting at ``i``. 
Returns - ``(value, next_index)``.""" + """Parse one Gemma arg value at ``i``; returns ``(value, next_index)``.""" if text.startswith(_GEMMA_STR_BEGIN, i): close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) if close < 0: @@ -854,7 +763,7 @@ def _gemma_parse_value(text: str, i: int): v, k = _gemma_parse_value(body, k) items.append(v) return items, j + 1 - # Primitive: number, true/false/null, or bare identifier (rare). + # Primitive: number / true/false/null / bare identifier. end = i while ( end < len(text) @@ -881,7 +790,7 @@ def _gemma_parse_value(text: str, i: int): def _gemma_parse_mapping_body(body: str) -> dict[str, Any]: - """Parse content between `{` and `}` for a Gemma argument mapping.""" + """Parse a Gemma argument mapping (content between `{` and `}`).""" out: dict[str, Any] = {} i = 0 n = len(body) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 118207ce54..08c5d980eb 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -260,16 +260,11 @@ def _detect_safetensors_features(backend, chat_template: Optional[str]) -> dict: "supports_tools": False, } ) - # The safetensors / MLX loop parses these emission formats: - # Qwen ``{json}``, Qwen3.5 ``...``, - # Llama-3 ``<|python_tag|>``, Llama-3.2 bare JSON ``{"name":..., - # "parameters":...}``, Mistral ``[TOOL_CALLS]`` (pre-v11 array + - # v11+ ``name{json}``), and Gemma 4 ``<|tool_call>...``. If the - # template advertises tools but does NOT use any of these markers, - # the parser cannot honour the emission - drop the pill. ``{"name":`` - # catches Llama-3.2's ``custom_tools`` shape whose template instructs - # the model to "Respond in the format {\"name\": ..., \"parameters\": - # ...}" without a ``<|python_tag|>`` prefix. + # Markers the safetensors / MLX parser recognises. If the template + # advertises tools but uses none of them, drop the pill (parser + # can't honour the emission). 
The two ``{"name":`` variants cover + # Llama-3.2 ``custom_tools`` whose template prompts the bare-JSON + # form without a ``<|python_tag|>`` prefix. _PARSER_MARKERS = ( "", " None: " Do NOT output code blocks -- use the python tool instead." ) -# Regex for stripping leaked tool-call markup from assistant messages / -# stream. Covers every emission format the shared parser handles -# (Qwen / Hermes ````, Qwen3.5 ````, Llama-3 -# ``<|python_tag|>``, Mistral ``[TOOL_CALLS]`` pre-v11 array and v11+ -# ``name{json}``, Gemma 4 ``<|tool_call>...``). -# -# Mistral ``[TOOL_CALLS]`` blocks need balanced brace/bracket scanning -# -- a non-greedy ``\{.*?\}`` truncates at the first ``}`` of a nested -# JSON arg, so those are handled by the parser module's -# ``_strip_mistral_closed_calls`` helper invoked by ``_strip_tool_xml`` -# below. -# -# We also have to scrub four leak shapes the speculative buffer in -# ``core/inference/llama_cpp.py`` can split across the visible/DRAIN -# boundary: -# 1. well-formed ``...`` / ``...`` -# 2. orphan opening to EOF (close was DRAINED) -- match to ``\Z`` -# 3. bare orphan close (open was DRAINED) -# 4. tail-only ```` (outer close truncated by EOS); anchored -# to ``\Z`` so mid-text ```` in user code samples survives. +# Strip leaked tool-call markup. Covers every shared-parser format AND +# the four leak shapes the speculative buffer in ``llama_cpp.py`` splits +# across the visible/DRAIN boundary (closed pair, orphan open to EOF, +# bare orphan close, tail-only ````). Mistral ``[TOOL_CALLS]`` +# is delegated to the parser's balanced-brace helper -- a non-greedy +# ``\{.*?\}`` here would truncate nested JSON at the first ``}``. _TOOL_XML_RE = _re.compile( "|".join( [ @@ -467,19 +448,14 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: r"<(?:tool_call|function=\w+)>.*?(?:|\Z)", # Bare orphan close (open was DRAINED upstream). r"", - # Gemma 4 ``<|tool_call>...``. + # Gemma 4. 
r"<\|tool_call>.*?", - # ``<|python_tag|>...`` runs until the next Llama-3 ``<|`` - # sentinel (``<|eot_id|>``, ``<|eom_id|>``, etc.) or end of - # text. Earlier revisions used ``[^\n<]*`` (leaked tools - # whose args contained a literal ``<`` like - # ``code="if x < 10"``) and then ``[^\n]*`` (single-line - # only; multi-line ``python.call(code="line1\nline2")`` - # leaked the second line). ``(?:[^<]|<(?!\|))*`` consumes - # any character that is not a sentinel start, so newlines, - # bare ``<``, and embedded JSON all stay inside the strip. + # Llama-3 ``<|python_tag|>...`` to the next ``<|`` sentinel + # or EOF. ``(?:[^<]|<(?!\|))*`` (not ``[^\n<]*`` or + # ``[^\n]*``) keeps literal ``<``, newlines, and embedded + # JSON inside the strip. r"<\|python_tag\|>(?:[^<]|<(?!\|))*", - # Tail-only ```` (truncated outer close, EOS). + # Tail-only ```` (anchored so mid-text survives). r"\s*\Z", ] ), @@ -488,11 +464,7 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: def _strip_tool_xml(text: str) -> str: - """Strip closed-pair tool-call markup with balanced brace scanning. - - Combines the shared parser's Mistral helper (handles nested JSON - correctly) with ``_TOOL_XML_RE`` for the remaining flat patterns. 
- """ + """Combine the Mistral balanced-brace helper with ``_TOOL_XML_RE``.""" from studio.backend.core.inference.tool_call_parser import ( _strip_mistral_closed_calls, ) From 9b3a6c665a26fe32b53ec657ed96e7ce4a62d324 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 12:37:48 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/core/inference/tool_call_parser.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 2d31061765..9f603b7839 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -222,11 +222,11 @@ def parse_tool_calls_from_text(content: str, *, id_offset: int = 0) -> list[dict """Return OpenAI-format tool calls. Tries each format and returns as soon as one matches so we never double-count.""" for parser in ( - _parse_tool_call_json, # Qwen / Hermes - _parse_function_xml, # Qwen3.5 / Hermes XML - _parse_llama3_python_tag, # Llama-3 - _parse_mistral_tool_calls, # Mistral - _parse_gemma_tool_calls, # Gemma 4 + _parse_tool_call_json, # Qwen / Hermes + _parse_function_xml, # Qwen3.5 / Hermes XML + _parse_llama3_python_tag, # Llama-3 + _parse_mistral_tool_calls, # Mistral + _parse_gemma_tool_calls, # Gemma 4 ): calls = parser(content, id_offset = id_offset) if calls: From 2a76a9b1efd42ad10c96281083f62afe3f679448 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 27 May 2026 13:31:18 +0000 Subject: [PATCH 11/17] studio: parser robustness fixes for PR #5620 Three surgical extensions to the multi-format tool-call parser, each covering a real fine-tune / template emission shape that the current parser silently drops. No path narrows; all changes widen what is accepted. 1. 
`_parse_tool_call_json` now accepts both `arguments` and `parameters` keys. A Hermes / Qwen `<tool_call>{json}</tool_call>` wrapper around a Llama-3.2 fine-tune that emits the `parameters` key was extracting the tool name and silently discarding the args, producing a working-shaped call with an empty payload. The bare-JSON and python_tag paths already accepted both keys; this path now matches them. 2. `_TC_FUNC_START_RE`, `_TC_PARAM_START_RE`, and `_TC_PARAM_CLOSE_RE` now also match the attribute form `<function name="X"><param name="K">v</param></function>` used by MiniCPM-5 and MiniMax-M2. Names land in either capture group, and `</param>` is accepted as a short close. 3. `_parse_llama3_bare_json` sentinel-strip now consumes the role label inserted between `<|start_header_id|>` and `<|end_header_id|>` by Meta's official Llama-3.x chat template. Without this, every assistant turn re-fed through the template prefix `<|start_header_id|>assistant<|end_header_id|>\n\n{json}` parsed to zero calls, so any history-with-tool-call round-trip in production silently dropped. Tests in `studio/backend/tests/test_safetensors_tool_loop.py`: * `TestParserRobustness::test_tool_call_json_accepts_parameters_key` * `TestParserRobustness::test_function_xml_attribute_form` * `TestParserRobustness::test_function_xml_attribute_form_multi_param` * `TestParserRobustness::test_function_xml_legacy_equals_form_still_works` (regression guard for the existing `<function=name>` syntax) * `TestParserRobustness::test_llama3_chat_template_round_trip` * `TestParserRobustness::test_llama3_round_trip_all_roles` * `TestParserRobustness::test_llama3_round_trip_with_eot_prefix` `pytest studio/backend/tests/test_safetensors_tool_loop.py studio/backend/tests/test_safetensors_capability_advertise.py -q` goes from 118 to 125 passed.
--- .../core/inference/tool_call_parser.py | 39 +++++-- .../tests/test_safetensors_tool_loop.py | 105 ++++++++++++++++++ 2 files changed, 136 insertions(+), 8 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 9f603b7839..8473a83260 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -85,12 +85,18 @@ # Qwen / Hermes ``{json}``. _TC_JSON_START_RE = re.compile(r"\s*\{") -# Qwen3.5 / Hermes XML ``v``. -_TC_FUNC_START_RE = re.compile(r"\s*") +# Qwen3.5 / Hermes ``v`` AND the attribute +# form ``v`` used by MiniCPM-5, +# MiniMax-M2, etc. Name lands in group(1) or group(2). +_TC_FUNC_START_RE = re.compile( + r'\s*' +) _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") -_TC_PARAM_START_RE = re.compile(r"\s*") -_TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") +_TC_PARAM_START_RE = re.compile( + r'<(?:parameter|param)(?:=([\w\.\-]+)|\s+name="([\w\.\-]+)")>\s*' +) +_TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") # Llama-3 ``<|python_tag|>NAME.call(...)``. _LLAMA3_PYTHON_TAG = "<|python_tag|>" @@ -250,7 +256,12 @@ def _parse_tool_call_json(content: str, *, id_offset: int) -> list[dict]: except (json.JSONDecodeError, ValueError): continue name = obj.get("name", "") - args = obj.get("arguments", {}) + # Accept both ``arguments`` (Hermes/Qwen) and ``parameters`` + # (Llama-3 template drift) so a fine-tune that swaps the key + # keeps its payload instead of silently parsing to ``{}``. 
+ args = obj.get("arguments") + if args is None: + args = obj.get("parameters", {}) if isinstance(args, dict): args_str = json.dumps(args) elif isinstance(args, str): @@ -273,7 +284,8 @@ def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: out: list[dict] = [] func_starts = list(_TC_FUNC_START_RE.finditer(content)) for idx, fm in enumerate(func_starts): - func_name = fm.group(1) + # group(1) is ````, group(2) is ````. + func_name = fm.group(1) or fm.group(2) body_start = fm.end() next_func = ( func_starts[idx + 1].start() if idx + 1 < len(func_starts) else len(content) @@ -291,7 +303,7 @@ def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: if len(param_starts) == 1: pm = param_starts[0] val = _TC_PARAM_CLOSE_RE.sub("", body[pm.end() :]) - args[pm.group(1)] = val.strip() + args[pm.group(1) or pm.group(2)] = val.strip() else: for pidx, pm in enumerate(param_starts): val_start = pm.end() @@ -301,7 +313,7 @@ def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: else len(body) ) val = _TC_PARAM_CLOSE_RE.sub("", body[val_start:next_param]) - args[pm.group(1)] = val.strip() + args[pm.group(1) or pm.group(2)] = val.strip() out.append( { @@ -438,12 +450,23 @@ def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: "<|end_header_id|>", "<|eom_id|>", ) + # Role labels Meta's Llama-3 chat template inserts between + # ``<|start_header_id|>`` and ``<|end_header_id|>`` -- consume so a + # round-trip like + # ``<|start_header_id|>assistant<|end_header_id|>\n\n{json}`` + # reaches the JSON body. 
+ _header_roles = ("assistant", "user", "system", "tool", "ipython") while True: stripped = stripped.lstrip() matched = False for sentinel in _sentinels: if stripped.startswith(sentinel): stripped = stripped[len(sentinel) :] + if sentinel == "<|start_header_id|>": + for role in _header_roles: + if stripped.startswith(role): + stripped = stripped[len(role) :] + break matched = True break if not matched: diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index 729e9ce766..cbe6da078d 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -1477,5 +1477,110 @@ def test_python_tag_with_eom_then_trailing_python_tag(self): assert self._strip(text) == "<|eom_id|>" +# ──────────────────────────────────────────────────────────────────── +# Robustness fixes uncovered while validating against vLLM / sglang. +# ──────────────────────────────────────────────────────────────────── + + +class TestParserRobustness: + def test_tool_call_json_accepts_parameters_key(self): + # Hermes wrapper around a Llama-3.2 bare-JSON object that uses + # ``parameters`` instead of ``arguments``. The bare-JSON and + # python_tag paths already accept both keys; this path now does + # too. Was extracting name only and silently dropping the args. + import json + text = ( + '\n' + '{"name": "search", "parameters": {"q": "ramen"}}\n' + '' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "search" + assert json.loads(result[0]["function"]["arguments"]) == {"q": "ramen"} + + def test_function_xml_attribute_form(self): + # MiniCPM-5 / MiniMax-M2 attribute syntax: + # ``v``. 
+ import json + text = ( + '' + 'Tokyo' + '' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_function_xml_attribute_form_multi_param(self): + import json + text = ( + '' + 'Tokyo' + 'celsius' + '' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"city": "Tokyo", "unit": "celsius"} + + def test_function_xml_legacy_equals_form_still_works(self): + # Regression guard: the old ``v`` + # syntax must keep parsing after the regex broadening. + import json + text = ( + '' + 'Tokyo' + '' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_llama3_chat_template_round_trip(self): + # Meta's official Llama-3.x chat template prefixes every + # assistant turn with + # ``<|start_header_id|>assistant<|end_header_id|>\n\n``. The + # sentinel-strip in ``_parse_llama3_bare_json`` must reach past + # the role label to the JSON body, else every round-tripped + # tool call in history silently drops. + import json + text = ( + '<|start_header_id|>assistant<|end_header_id|>\n\n' + '{"name": "get_weather", "parameters": {"city": "Tokyo"}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_llama3_round_trip_all_roles(self): + # Same logic must work for every role the chat template inserts. 
+ import json + for role in ("assistant", "user", "system", "tool", "ipython"): + text = ( + f'<|start_header_id|>{role}<|end_header_id|>\n\n' + '{"name": "f", "parameters": {"x": 1}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1, f"failed for role={role}" + assert json.loads(result[0]["function"]["arguments"]) == {"x": 1} + + def test_llama3_round_trip_with_eot_prefix(self): + # Prior assistant turn closes with ``<|eot_id|>``, then the + # new header opens. Both sentinels + the role must be consumed. + import json + text = ( + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + '{"name": "f", "parameters": {}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "f" + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 1919e2a953a72eb15560bf750b27db7dd5941510 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 13:31:43 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/inference/tool_call_parser.py | 4 +-- .../tests/test_safetensors_tool_loop.py | 25 +++++++++++-------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 8473a83260..dd0281948e 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -88,9 +88,7 @@ # Qwen3.5 / Hermes ``v`` AND the attribute # form ``v`` used by MiniCPM-5, # MiniMax-M2, etc. Name lands in group(1) or group(2). 
-_TC_FUNC_START_RE = re.compile( - r'\s*' -) +_TC_FUNC_START_RE = re.compile(r'\s*') _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") _TC_PARAM_START_RE = re.compile( diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index cbe6da078d..ca86392fc2 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -1489,10 +1489,11 @@ def test_tool_call_json_accepts_parameters_key(self): # python_tag paths already accept both keys; this path now does # too. Was extracting name only and silently dropping the args. import json + text = ( - '\n' + "\n" '{"name": "search", "parameters": {"q": "ramen"}}\n' - '' + "" ) result = parse_tool_calls_from_text(text) assert len(result) == 1 @@ -1503,10 +1504,11 @@ def test_function_xml_attribute_form(self): # MiniCPM-5 / MiniMax-M2 attribute syntax: # ``v``. import json + text = ( '' 'Tokyo' - '' + "" ) result = parse_tool_calls_from_text(text) assert len(result) == 1 @@ -1515,11 +1517,12 @@ def test_function_xml_attribute_form(self): def test_function_xml_attribute_form_multi_param(self): import json + text = ( '' 'Tokyo' 'celsius' - '' + "" ) result = parse_tool_calls_from_text(text) assert len(result) == 1 @@ -1530,10 +1533,9 @@ def test_function_xml_legacy_equals_form_still_works(self): # Regression guard: the old ``v`` # syntax must keep parsing after the regex broadening. import json + text = ( - '' - 'Tokyo' - '' + "" "Tokyo" "" ) result = parse_tool_calls_from_text(text) assert len(result) == 1 @@ -1548,8 +1550,9 @@ def test_llama3_chat_template_round_trip(self): # the role label to the JSON body, else every round-tripped # tool call in history silently drops. 
import json + text = ( - '<|start_header_id|>assistant<|end_header_id|>\n\n' + "<|start_header_id|>assistant<|end_header_id|>\n\n" '{"name": "get_weather", "parameters": {"city": "Tokyo"}}' ) result = parse_tool_calls_from_text(text) @@ -1560,9 +1563,10 @@ def test_llama3_chat_template_round_trip(self): def test_llama3_round_trip_all_roles(self): # Same logic must work for every role the chat template inserts. import json + for role in ("assistant", "user", "system", "tool", "ipython"): text = ( - f'<|start_header_id|>{role}<|end_header_id|>\n\n' + f"<|start_header_id|>{role}<|end_header_id|>\n\n" '{"name": "f", "parameters": {"x": 1}}' ) result = parse_tool_calls_from_text(text) @@ -1573,8 +1577,9 @@ def test_llama3_round_trip_with_eot_prefix(self): # Prior assistant turn closes with ``<|eot_id|>``, then the # new header opens. Both sentinels + the role must be consumed. import json + text = ( - '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" '{"name": "f", "parameters": {}}' ) result = parse_tool_calls_from_text(text) From c4bbecfa99dad1d8045f4f5d610567d960a6c83c Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 27 May 2026 14:23:43 +0000 Subject: [PATCH 13/17] studio: terminate function-XML body at , not just `_parse_function_xml` was looking for `` (the Hermes wrapper) as the body terminator. When a model emits a standalone `v` followed by explanatory prose (which models routinely do), no `` is present, so the body extended to end-of-string and the trailing prose leaked into the LAST parameter value. Pre-existing on main (the legacy `` form had this bug too). Same affects PR #5620's new attribute-form `v` emission used by MiniCPM-5 / MiniMax-M2. Fix: `_TC_END_TAG_RE` now matches either `` OR ``. The existing `_TC_FUNC_CLOSE_RE` / `_TC_PARAM_CLOSE_RE` strips are unchanged. 
Multi-call inputs still bound each function at the next `` is preserved because the embedded close tag is ``, not ``). `pytest studio/backend/tests/test_safetensors_tool_loop.py studio/backend/tests/test_safetensors_capability_advertise.py -q` goes from 125 to 127 passed. --- .../core/inference/tool_call_parser.py | 6 +++- .../tests/test_safetensors_tool_loop.py | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 6858b92a73..927e90c10b 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -93,7 +93,11 @@ # with hyphens (mcp__srv__list-issues) and dotted module names parse # the same as the built-ins. Name lands in group(1) or group(2). _TC_FUNC_START_RE = re.compile(r'\s*') -_TC_END_TAG_RE = re.compile(r"") +# Body terminates at either ```` (Hermes wrapper) OR +# ```` (Qwen3.5 / MiniCPM-5 standalone) so the parser stops +# at the close tag even when prose follows. Without ````, +# trailing prose leaked into the last parameter value. +_TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") _TC_PARAM_START_RE = re.compile( r'<(?:parameter|param)(?:=([\w\.\-]+)|\s+name="([\w\.\-]+)")>\s*' diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index ca86392fc2..4ff949f699 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -1586,6 +1586,35 @@ def test_llama3_round_trip_with_eot_prefix(self): assert len(result) == 1 assert result[0]["function"]["name"] == "f" + def test_function_xml_followed_by_prose(self): + # Models routinely follow a tool call with explanatory prose. + # Body must terminate at ```` even without a + # ```` wrapper, else trailing prose leaks into the + # last parameter value. 
+ import json + + text = ( + "" + "Tokyo" + "\n\nHere is what I found." + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_function_attribute_xml_followed_by_prose(self): + # Same expectation for the MiniCPM-5 attribute form. + import json + + text = ( + '' + 'Tokyo' + "\n\nLet me know if you need anything else." + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 85660c1e05787d0873f80be22c357d742be1937b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 27 May 2026 14:40:08 +0000 Subject: [PATCH 14/17] Studio: tighten Llama-3.2 bare-JSON guard A fuzz pass on PR #5811 turned up that ``_parse_llama3_bare_json`` accepted ``parameters`` as a string, contradicting the docstring's "parameters or arguments is a dict" guard. Prose JSON like ``{"name":"foo","parameters":"a sentence"}`` would wrongly fire the parser, which the agentic loop would then heal into a real ``foo(query="a sentence")`` call. Same code lives on this branch, so the same fix applies here. Tightened guard: - ``parameters`` must be a dict (Llama-3 spec). - ``arguments`` may be a dict, or a JSON-encoded string that decodes to a dict (OpenAI shape, e.g. ``"arguments":"{\"q\":\"x\"}"``). Plain non-JSON strings or JSON-strings of lists / scalars / null no longer pass. Mirrors the fix landed in PR #5811 commit 615b8608. Adds the same 4 regression tests under TestParserMultiFormat. Existing test suite stays green: 127 -> 131 passing. 
--- .../core/inference/tool_call_parser.py | 24 ++++++++++---- .../tests/test_safetensors_tool_loop.py | 31 +++++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 927e90c10b..4de422dbc9 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -498,16 +498,28 @@ def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: name = obj.get("name") or obj.get("function") or "" if not isinstance(name, str) or not name: break + # ``parameters`` must be a dict (Llama-3 spec). + # ``arguments`` may be a dict or a JSON-string of a dict (OpenAI shape). + # Anything looser would fire on prose like ``{"name":"x","parameters":"sentence"}``. if "parameters" in obj: args = obj.get("parameters") + if not isinstance(args, dict): + break + args_str = json.dumps(args) elif "arguments" in obj: args = obj.get("arguments") - else: - break - if isinstance(args, dict): - args_str = json.dumps(args) - elif isinstance(args, str): - args_str = args + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + try: + parsed = json.loads(args) + except (json.JSONDecodeError, ValueError): + break + if not isinstance(parsed, dict): + break + args_str = args + else: + break else: break out.append( diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index 4ff949f699..d43f0c6641 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -260,6 +260,37 @@ def test_llama3_2_bare_json_args_not_dict_does_not_fire(self): text = '{"name":"x","parameters":42}' assert parse_tool_calls_from_text(text) == [] + def test_llama3_2_bare_json_string_parameters_does_not_fire(self): + # Llama-3 spec: parameters must be a dict. 
Prose like + # ``{"name":"foo","parameters":"a sentence"}`` must NOT trigger. + text = '{"name":"foo","parameters":"this is a sentence"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_string_arguments_not_json_does_not_fire(self): + # OpenAI ``arguments`` may be a JSON-string of a dict, but a + # plain non-JSON string must not pass the guard. + text = '{"name":"foo","arguments":"not json"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_string_arguments_json_dict_fires(self): + # OpenAI shape: arguments is a JSON-encoded string of a dict. + text = '{"name":"foo","arguments":"{\\"q\\":\\"x\\"}"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "foo" + # arguments stays as the original JSON-string. + assert result[0]["function"]["arguments"] == '{"q":"x"}' + + def test_llama3_2_bare_json_string_arguments_json_non_dict_does_not_fire(self): + # JSON-string that parses to a list / scalar / null must NOT fire. + for bad in ( + '{"name":"foo","arguments":"[1,2,3]"}', + '{"name":"foo","arguments":"\\"plain\\""}', + '{"name":"foo","arguments":"null"}', + '{"name":"foo","arguments":"42"}', + ): + assert parse_tool_calls_from_text(bad) == [], bad + # ── Mistral pre-v11 ─────────────────────────────────────────── def test_mistral_pre_v11_array(self): From bc13a389babc91fbad51c84b7d763df2459cc1be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 31 May 2026 13:17:27 +0000 Subject: [PATCH 15/17] studio: fix safetensors tool-call parser gaps vs llama.cpp (Mistral CALL_ID / THINK, attribute-form signal) Three GGUF-parity fixes to the safetensors tool-call parser, each matching llama.cpp's reference behaviour: - Mistral Small 3.2 emits [TOOL_CALLS]name[CALL_ID][ARGS]{json}. The parser stopped after the name on seeing [CALL_ID] (neither [ARGS] nor {), dropping the call. Skip an optional [CALL_ID] segment in both the parse and strip paths. 
llama.cpp parses this (test-chat.cpp:4785). - Magistral wraps reasoning in [THINK]...[/THINK]. A [TOOL_CALLS] inside the reasoning was parsed as a real call, producing a phantom call. Strip a leading [THINK] block before scanning so only the post-reasoning call counts (test-chat.cpp:2285); a literal [THINK] inside a later argument is left intact. - The standalone MiniCPM-5 / MiniMax-M2 attribute form parsed correctly but was absent from TOOL_XML_SIGNALS and the markup strip patterns, so the streaming safety-net parse was gated off (dropping the call) and markup leaked into displayed text. Add the signal and broaden the strip regexes. Adds regression tests for all three. --- .../core/inference/tool_call_parser.py | 58 ++++++++++++++- .../tests/test_safetensors_tool_loop.py | 72 +++++++++++++++++++ 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 4de422dbc9..87041134ef 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -31,6 +31,7 @@ TOOL_XML_SIGNALS = ( "", "", "[TOOL_CALLS]", "<|tool_call>", @@ -43,12 +44,12 @@ # (mcp__srv__list-issues) parse the same as the built-ins. _TOOL_CLOSED_PATS = [ re.compile(r".*?", re.DOTALL), - re.compile(r".*?", re.DOTALL), + re.compile(r'.*?', re.DOTALL), re.compile(r"<\|tool_call>.*?", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), - re.compile(r".*$", re.DOTALL), + re.compile(r'.*$', re.DOTALL), re.compile(r"<\|tool_call>.*$", re.DOTALL), re.compile(r"\[TOOL_CALLS\].*$", re.DOTALL), re.compile(r"<\|python_tag\|>.*$", re.DOTALL), @@ -119,6 +120,14 @@ # / Large 3). _MISTRAL_TRIGGER = "[TOOL_CALLS]" _MISTRAL_ARGS_MARKER = "[ARGS]" +# Mistral Small 3.2 emits ``name[CALL_ID][ARGS]{json}``; the call-id +# segment is absent on Ministral / Magistral. 
llama.cpp distinguishes the +# two on the presence of ``[CALL_ID]`` (common/chat.cpp). +_MISTRAL_CALL_ID_MARKER = "[CALL_ID]" +# Magistral wraps reasoning in ``[THINK] ... [/THINK]`` before the answer. +# A ``[TOOL_CALLS]`` inside that block is chain-of-thought, not a real call. +_MISTRAL_THINK_OPEN = "[THINK]" +_MISTRAL_THINK_CLOSE = "[/THINK]" _MISTRAL_V11_NAME_RE = re.compile(r"\s*([\w\.\-]+)\s*") # Gemma 4: ``<|tool_call>call:NAME{...}``, ``<|"|>`` wraps strings. @@ -159,6 +168,48 @@ def _balanced_bracket_end(text: str, start: int) -> int | None: return None +def _skip_mistral_call_id(text: str, pos: int) -> int: + """Skip an optional ``[CALL_ID]`` segment (Mistral Small 3.2) at + ``pos``. Returns the position of the next meaningful token (``[ARGS]`` + or ``{``), or ``pos`` unchanged when no ``[CALL_ID]`` is present.""" + n = len(text) + i = pos + while i < n and text[i] in " \t\n\r": + i += 1 + if not text.startswith(_MISTRAL_CALL_ID_MARKER, i): + return pos + i += len(_MISTRAL_CALL_ID_MARKER) + while i < n and text[i] in " \t\n\r": + i += 1 + # The id is a short opaque token; stop at whitespace or the next marker. + while i < n and text[i] not in " \t\n\r[{": + i += 1 + while i < n and text[i] in " \t\n\r": + i += 1 + return i + + +def _strip_mistral_reasoning(content: str) -> str: + """Drop a leading Magistral ``[THINK] ... [/THINK]`` reasoning block so a + ``[TOOL_CALLS]`` emitted *inside* the chain-of-thought is not mistaken for + a real call (llama.cpp parses reasoning separately; see test-chat.cpp). + + Only a leading block is removed -- the reasoning prefix is always first, + so a literal ``[THINK]`` inside a later tool argument is left untouched. 
+ An unclosed leading ``[THINK]`` (still streaming) means nothing has been + committed yet, so everything from it onward is dropped.""" + i = 0 + n = len(content) + while i < n and content[i] in " \t\n\r": + i += 1 + if not content.startswith(_MISTRAL_THINK_OPEN, i): + return content + close = content.find(_MISTRAL_THINK_CLOSE, i + len(_MISTRAL_THINK_OPEN)) + if close == -1: + return content[:i] + return content[:i] + content[close + len(_MISTRAL_THINK_CLOSE):] + + def _strip_mistral_closed_calls(text: str) -> str: """Strip cleanly-closed ``[TOOL_CALLS]`` blocks (array, ``name{json}``, or ``name[ARGS]{json}``) via balanced brace/bracket scanning. @@ -199,6 +250,7 @@ def _strip_mistral_closed_calls(text: str) -> str: i = name_match.end() while i < n and text[i] in " \t\n\r": i += 1 + i = _skip_mistral_call_id(text, i) if text.startswith(_MISTRAL_ARGS_MARKER, i): i += len(_MISTRAL_ARGS_MARKER) while i < n and text[i] in " \t\n\r": @@ -538,6 +590,7 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: ``[TOOL_CALLS]{...}`` and v11+ ``[TOOL_CALLS]name{json}`` / ``[TOOL_CALLS]name[ARGS]{json}`` (parallel-friendly).""" out: list[dict] = [] + content = _strip_mistral_reasoning(content) idx = content.find(_MISTRAL_TRIGGER) if idx < 0: return out @@ -578,6 +631,7 @@ def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: continue name = nm.group(1) after_name = nm.end() + after_name = _skip_mistral_call_id(content, after_name) if content.startswith(_MISTRAL_ARGS_MARKER, after_name): after_name += len(_MISTRAL_ARGS_MARKER) while after_name < len(content) and content[after_name] in " \t\n\r": diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index d43f0c6641..c790fdc7e5 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -358,6 +358,66 @@ def test_mistral_strip_markup_v11(self): text = 
'[TOOL_CALLS]add{"a":1}' assert strip_tool_markup(text, final = True) == "" + def test_mistral_call_id_form(self): + # Mistral Small 3.2: ``[TOOL_CALLS]name[CALL_ID][ARGS]{json}``. + # The ``[CALL_ID]`` segment must be skipped, not treated as a stop + # (llama.cpp test-chat.cpp:4785 parses this to one call). + import json + + text = '[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{"arg1": 1}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "special_function" + assert json.loads(result[0]["function"]["arguments"]) == {"arg1": 1} + + def test_mistral_call_id_form_parallel(self): + text = ( + '[TOOL_CALLS]special_function[CALL_ID]000000001[ARGS]{"arg1": 1}' + '[TOOL_CALLS]special_function_with_opt[CALL_ID]000000002' + '[ARGS]{"arg1": 1, "arg2": 2}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "special_function" + assert result[1]["function"]["name"] == "special_function_with_opt" + + def test_mistral_call_id_form_stripped(self): + text = '[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{"arg1": 1}' + assert strip_tool_markup(text, final = True) == "" + + def test_mistral_think_reasoning_ignored(self): + # Magistral wraps reasoning in ``[THINK]...[/THINK]``. A ``[TOOL_CALLS]`` + # inside the reasoning is chain-of-thought, not a real call; only the + # call after ``[/THINK]`` counts (llama.cpp test-chat.cpp:2285). + import json + + text = ( + "[THINK]Let me think about [TOOL_CALLS]fake[ARGS]{\"x\":1} " + "and more[/THINK][TOOL_CALLS]real_fn[ARGS]{\"y\":2}" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "real_fn" + assert json.loads(result[0]["function"]["arguments"]) == {"y": 2} + + def test_mistral_think_reasoning_no_real_call(self): + # Reasoning that merely mentions a tool call but does not emit one + # after ``[/THINK]`` yields no calls. 
+ text = '[THINK]I might call [TOOL_CALLS]fake[ARGS]{"x":1}[/THINK]Done.' + assert parse_tool_calls_from_text(text) == [] + + def test_mistral_think_literal_in_argument_preserved(self): + # A literal ``[THINK]`` inside a real tool argument (after the call) + # must not be stripped or corrupt the parse. + import json + + text = '[TOOL_CALLS]search[ARGS]{"q":"explain the [THINK] token"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == { + "q": "explain the [THINK] token" + } + # ── Gemma 4 ─────────────────────────────────────────────────── def test_gemma4_simple_call(self): @@ -1573,6 +1633,18 @@ def test_function_xml_legacy_equals_form_still_works(self): assert result[0]["function"]["name"] == "get_weather" assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + def test_function_attribute_form_has_tool_signal(self): + # The standalone ```` attribute form must flip + # the streaming buffer; otherwise the end-of-turn safety-net parse in + # the agentic loop is gated off and the real call is dropped. + assert has_tool_signal('') is True + + def test_function_attribute_form_strip_markup(self): + # The attribute form must also be stripped from displayed text, like + # the legacy ```` form. 
+ text = 'result X' + assert strip_tool_markup(text, final = True) == "result" + def test_llama3_chat_template_round_trip(self): # Meta's official Llama-3.x chat template prefixes every # assistant turn with From 5a021206fb0cd2b78e02663f270429fd50f2700f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 May 2026 13:17:44 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- studio/backend/core/inference/tool_call_parser.py | 2 +- studio/backend/tests/test_safetensors_tool_loop.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 87041134ef..0bb881d1a7 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -207,7 +207,7 @@ def _strip_mistral_reasoning(content: str) -> str: close = content.find(_MISTRAL_THINK_CLOSE, i + len(_MISTRAL_THINK_OPEN)) if close == -1: return content[:i] - return content[:i] + content[close + len(_MISTRAL_THINK_CLOSE):] + return content[:i] + content[close + len(_MISTRAL_THINK_CLOSE) :] def _strip_mistral_closed_calls(text: str) -> str: diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index c790fdc7e5..9bf6ce154a 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -373,7 +373,7 @@ def test_mistral_call_id_form(self): def test_mistral_call_id_form_parallel(self): text = ( '[TOOL_CALLS]special_function[CALL_ID]000000001[ARGS]{"arg1": 1}' - '[TOOL_CALLS]special_function_with_opt[CALL_ID]000000002' + "[TOOL_CALLS]special_function_with_opt[CALL_ID]000000002" '[ARGS]{"arg1": 1, "arg2": 2}' ) result = parse_tool_calls_from_text(text) @@ -392,8 +392,8 @@ def 
test_mistral_think_reasoning_ignored(self): import json text = ( - "[THINK]Let me think about [TOOL_CALLS]fake[ARGS]{\"x\":1} " - "and more[/THINK][TOOL_CALLS]real_fn[ARGS]{\"y\":2}" + '[THINK]Let me think about [TOOL_CALLS]fake[ARGS]{"x":1} ' + 'and more[/THINK][TOOL_CALLS]real_fn[ARGS]{"y":2}' ) result = parse_tool_calls_from_text(text) assert len(result) == 1 From c6d1478d9e2e9bb847f92b599ecdad577af06439 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 31 May 2026 13:51:20 +0000 Subject: [PATCH 17/17] studio: fire safetensors tool calls for the bare-JSON (Llama-3.2) form The agentic loop's streaming safety-net parse was gated on has_tool_signal(), which is False for the Llama-3.1 / 3.2 bare-JSON tool form {"name":..,"parameters":..} (no XML marker). Real tool calls were therefore dropped: the loop logged "model planned without calling tools", re-prompted three times, then gave up with zero tool calls, while GGUF's llama-server parses the same emission natively. Run parse_tool_calls_from_text() unconditionally in the safety net. The parser is strict (only fires on a valid tool-call shape) so plain answers are unaffected. Reproduced on a real unsloth/Llama-3.1-8B-Instruct run: the model emits {"name":"web_search","parameters":{...}} which now executes the tool instead of being re-prompted into a no-op. Adds a loop regression test for the bare-JSON form. 
--- .../core/inference/safetensors_agentic.py | 18 +++++++++++------- .../tests/test_safetensors_tool_loop.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/studio/backend/core/inference/safetensors_agentic.py b/studio/backend/core/inference/safetensors_agentic.py index 36d02d59ed..1f442161e9 100644 --- a/studio/backend/core/inference/safetensors_agentic.py +++ b/studio/backend/core/inference/safetensors_agentic.py @@ -277,13 +277,17 @@ def run_safetensors_tool_loop( detect_state = _state_streaming if detect_state == _state_streaming: - # No tool detected mid-stream -- check for late tool XML. - safety_tc = None - if has_tool_signal(content_accum): - safety_tc = parse_tool_calls_from_text( - content_accum, - id_offset = next_call_id, - ) + # No tool XML detected mid-stream -- run the parser anyway. + # The Llama-3.2 bare-JSON tool form ``{"name":..,"parameters":..}`` + # carries no XML signal, so gating this on has_tool_signal() + # silently dropped real tool calls and re-prompted the model into + # giving up. parse_tool_calls_from_text is strict (it only fires + # on a valid tool-call shape), so plain answers stay untouched. + # This mirrors what llama-server already does for GGUF. 
+ safety_tc = parse_tool_calls_from_text( + content_accum, + id_offset = next_call_id, + ) if not safety_tc: # Re-prompt only when the model planned without acting # (intent signal present); direct answers like "4" or diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index c790fdc7e5..3ec61f4e64 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -676,6 +676,25 @@ def test_llama3_python_tag_form(self): contents = [e for e in events if e["type"] == "content"] assert "sunny" in contents[-1]["text"].lower() + def test_llama3_bare_json_form_fires_tool(self): + # Llama-3.1 / 3.2 emit a bare-JSON tool call + # ``{"name":..,"parameters":..}`` with NO XML signal. The loop's + # safety-net parse must still fire the tool instead of treating the + # turn as "planned without calling tools" and re-prompting the model + # into giving up. Regression for the has_tool_signal gate that + # dropped these; GGUF's llama-server parses them natively. + loop, exec_fn = _make_loop( + turns = [ + ['{"name": "web_search", "parameters": {"query": "weather in SF"}}'], + ["The weather is sunny."], + ], + exec_results = ["Sunny, 18C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather in SF"})] + contents = [e for e in events if e["type"] == "content"] + assert "sunny" in contents[-1]["text"].lower() + def test_mistral_pre_v11_form(self): # Pre-v11 Mistral emission: ``[TOOL_CALLS] [{...}]``. loop, exec_fn = _make_loop(