diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py
index a0ab8a2a53..7b0c9b1f62 100644
--- a/studio/backend/core/inference/tool_call_parser.py
+++ b/studio/backend/core/inference/tool_call_parser.py
@@ -2,32 +2,72 @@
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
-Backend-neutral tool-call XML parser shared by GGUF and safetensors.
-Tolerates missing closing tags in either ``{json}``
-or ``v...`` shape.
+Backend-neutral tool-call parser shared by GGUF, safetensors, and MLX.
+
+Covers the emission formats so the safetensors + MLX agentic loop sees
+the same call shape llama-server normalises for GGUF:
+
+ - ``<tool_call>{json}</tool_call>`` (Qwen / Hermes)
+ - ``<function=name><parameter=k>v</parameter></function>`` (Qwen3.5 xml)
+ - ``<|python_tag|>NAME.call(k="v", ...)`` (Llama-3 built-in tools)
+ - ``<|python_tag|>{"name":..., "parameters":...}`` (Llama-3 custom)
+ - ``{"name":..., "parameters":...}`` (Llama-3.2 bare JSON)
+ - ``[TOOL_CALLS] [{...}, ...]`` (Mistral v0.3 / Nemo / Small)
+ - ``[TOOL_CALLS]name{json}`` (Mistral v11+ / Magistral)
+ - ``[TOOL_CALLS]name[ARGS]{json}`` (Ministral / Mistral Large 3)
+ - ``<|tool_call>call:NAME{k:<|"|>v<|"|>}<tool_call|>`` (Gemma 4)
+
+Closing tags / brackets are tolerated when missing because models
+frequently truncate them mid-stream.
"""
import json
import re
+from typing import Any
+
+
+# ── Streaming-buffer signal markers ─────────────────────────────────
+
+
+# Prefixes the safetensors / MLX streaming buffer watches for to gate
+# in-progress text. When ANY of these appear in the cumulative text,
+# the state machine switches from STREAMING to DRAINING so we don't
+# leak partial markup to the user before we can parse it.
+#
+# Every entry must be a non-empty literal: the empty string is a
+# substring of all text, so a blank entry would make
+# ``has_tool_signal`` fire on plain prose.
+TOOL_XML_SIGNALS = (
+    "<tool_call>",  # Qwen / Hermes JSON form
+    "<function=",  # Qwen3.5 XML form
+    "<|python_tag|>",  # Llama-3
+    "[TOOL_CALLS]",  # Mistral
+    "<|tool_call>",  # Gemma 4
+)
-# _TOOL_CLOSED_PATS: closed pairs only. _TOOL_ALL_PATS: also trailing
-# unclosed runs so truncated tails don't leak markup.
+# ── Strip patterns for ``strip_tool_markup`` ────────────────────────
+
+
+# _TOOL_CLOSED_PATS: closed pairs only (used during streaming so
+# in-progress XML stays buffered). _TOOL_ALL_PATS: also matches trailing
+# unclosed runs so truncated tails don't leak markup at end-of-turn.
_TOOL_CLOSED_PATS = [
re.compile(r".*?", re.DOTALL),
re.compile(r".*?", re.DOTALL),
+ re.compile(r"<\|tool_call>.*?", re.DOTALL),
+ re.compile(r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?", re.DOTALL),
+ # Mistral v11+ ``[TOOL_CALLS]name{json}`` (may chain), close at ``}``.
+ re.compile(r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}", re.DOTALL),
]
_TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [
re.compile(r".*$", re.DOTALL),
re.compile(r".*$", re.DOTALL),
+ re.compile(r"<\|tool_call>.*$", re.DOTALL),
+ re.compile(r"\[TOOL_CALLS\].*$", re.DOTALL),
+ re.compile(r"<\|python_tag\|>.*$", re.DOTALL),
]
-# Prefixes the streaming buffer watches for to gate in-progress text.
-TOOL_XML_SIGNALS = ("", "{json}
_TC_JSON_START_RE = re.compile(r"\s*\{")
-_TC_FUNC_START_RE = re.compile(r"\s*")
+# Qwen3.5 / Hermes XML form <function=name><parameter=k>v</parameter></function>
+_TC_FUNC_START_RE = re.compile(r"\s*")
_TC_END_TAG_RE = re.compile(r"")
_TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$")
-_TC_PARAM_START_RE = re.compile(r"\s*")
+_TC_PARAM_START_RE = re.compile(r"\s*")
_TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$")
+# Llama-3 <|python_tag|>NAME.call(...)
+_LLAMA3_PYTHON_TAG = "<|python_tag|>"
+# Group 1 is the tool name; the match ends just past the opening ``(``.
+_LLAMA3_PY_CALL_RE = re.compile(
+    r"<\|python_tag\|>\s*([\w\.\-]+)\s*\.\s*call\s*\(",
+)
+# One ``key=value`` pair. Group 1 is the key and at most one of groups
+# 2-4 is set: 2 = body of a double-quoted string (escapes still raw),
+# 3 = integer / decimal literal, 4 = ``true`` / ``false`` / ``null``.
+_LLAMA3_KV_RE = re.compile(
+    r"""(\w+)\s*=\s*(?:"((?:\\.|[^"\\])*)"|(-?\d+(?:\.\d+)?)|(true|false|null))""",
+    re.VERBOSE,
+)
+
+# Mistral [TOOL_CALLS] trigger. v11+ chains multiple triggers, each
+# followed by a bare name then either ``{json}`` (Magistral) or
+# ``[ARGS]{json}`` (Ministral / Mistral Large 3).
+_MISTRAL_TRIGGER = "[TOOL_CALLS]"
+_MISTRAL_ARGS_MARKER = "[ARGS]"
+# Group 1 is the bare tool name; surrounding whitespace is consumed.
+_MISTRAL_V11_NAME_RE = re.compile(r"\s*([\w\.\-]+)\s*")
+
+# Gemma 4 <|tool_call>call:NAME{...}<tool_call|>. ``<|"|>`` wraps strings.
+# Group 1 of the regex is the tool name; the match ends on the opening ``{``.
+_GEMMA_TC_RE = re.compile(r"<\|tool_call>\s*call\s*:\s*([\w\.\-]+)\s*\{")
+_GEMMA_STR_BEGIN = '<|"|>'
+_GEMMA_STR_END = '<|"|>'
+# Must be non-empty: ``str.find("", pos)`` returns ``pos``, which would
+# clamp every argument scan to a zero-length window and drop the call.
+_GEMMA_TC_END = "<tool_call|>"
+
+
+# ── Public API ──────────────────────────────────────────────────────
+
def strip_tool_markup(text: str, *, final: bool = False) -> str:
- """Strip tool-call XML from streamed text.
+ """Strip tool-call markup from streamed text.
- ``final=False`` only removes closed pairs (used during streaming so
- in-progress XML stays buffered). ``final=True`` also removes a
- trailing unclosed run and trims the result.
+ ``final=False`` only removes closed pairs so in-progress markup
+ stays buffered. ``final=True`` also removes trailing unclosed runs
+ and trims the result.
"""
pats = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS
for pat in pats:
@@ -80,125 +150,651 @@ def strip_tool_markup(text: str, *, final: bool = False) -> str:
return text.strip() if final else text
+def has_tool_signal(text: str) -> bool:
+    """Report whether ``text`` contains at least one ``TOOL_XML_SIGNALS`` entry."""
+    for marker in TOOL_XML_SIGNALS:
+        if marker in text:
+            return True
+    return False
+
+
def parse_tool_calls_from_text(content: str, *, id_offset: int = 0) -> list[dict]:
"""Parse OpenAI-format ``tool_calls`` from model text.
- Returns a list of ``{"id", "type", "function": {"name", "arguments"}}``
- dicts. ``arguments`` is always a JSON string so callers can hand it
- straight back into an OpenAI-style response.
+ Returns ``[{"id", "type", "function": {"name", "arguments"}}]``
+ where ``arguments`` is always a JSON string. Tries each known
+ emission format in turn; returns as soon as one yields calls so
+ we never double-count.
+ """
# Qwen / Hermes <tool_call>{json}</tool_call>
+ calls = _parse_tool_call_json(content, id_offset = id_offset)
+ if calls:
+ return calls
- Handles two shapes:
+ # Qwen3.5 / Hermes <function=name><parameter=k>v</parameter></function>
+ calls = _parse_function_xml(content, id_offset = id_offset)
+ if calls:
+ return calls
- - JSON inside ```` tags:
- ``{"name":"web_search","arguments":{"query":"..."}}``
- - XML-style function blocks:
- ``v``
+ # Llama-3 <|python_tag|>...
+ calls = _parse_llama3_python_tag(content, id_offset = id_offset)
+ if calls:
+ return calls
+
+ # Mistral [TOOL_CALLS]...
+ calls = _parse_mistral_tool_calls(content, id_offset = id_offset)
+ if calls:
+ return calls
+
+ # Gemma 4 <|tool_call>...
+ calls = _parse_gemma_tool_calls(content, id_offset = id_offset)
+ if calls:
+ return calls
+
+ # Llama-3.2 bare JSON ``{"name":..., "parameters":...}`` (no tag).
+ # Strict: only fires when stripped content STARTS with ``{`` and
+ # parses as ``{name: str, parameters|arguments: dict}``. Keeps
+ # plain assistant prose unaffected.
+ return _parse_llama3_bare_json(content, id_offset = id_offset)
- Closing tags (````, ````, ````)
- are all optional since models frequently omit them.
- """
- tool_calls: list[dict] = []
- # Pattern 1: {json}. Balanced-brace scan that skips
- # braces inside JSON strings.
+# ── Per-format parsers ──────────────────────────────────────────────
+
+
+def _parse_tool_call_json(content: str, *, id_offset: int) -> list[dict]:
+ out: list[dict] = []
for m in _TC_JSON_START_RE.finditer(content):
- brace_start = m.end() - 1 # position of the opening {
- depth, i = 0, brace_start
+ brace_start = m.end() - 1
+ end = _balanced_brace_end(content, brace_start)
+ if end is None:
+ continue
+ try:
+ obj = json.loads(content[brace_start : end + 1])
+ except (json.JSONDecodeError, ValueError):
+ continue
+ name = obj.get("name", "")
+ args = obj.get("arguments", {})
+ if isinstance(args, dict):
+ args_str = json.dumps(args)
+ elif isinstance(args, str):
+ args_str = args
+ else:
+ args_str = json.dumps({"value": args})
+ if not name:
+ continue
+ out.append(
+ {
+ "id": f"call_{id_offset + len(out)}",
+ "type": "function",
+ "function": {"name": name, "arguments": args_str},
+ }
+ )
+ return out
+
+
+def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]:
+    """Extract calls written in the XML-style function / parameter form.
+
+    Every ``_TC_FUNC_START_RE`` match opens one call. Its body runs to
+    the first ``_TC_END_TAG_RE`` match after the opener, the next
+    function opener, or the end of ``content`` -- whichever comes first.
+    Parameter values are whitespace-trimmed strings and ``arguments`` is
+    the JSON encoding of the resulting dict.
+    """
+    calls: list[dict] = []
+    openers = list(_TC_FUNC_START_RE.finditer(content))
+    # A body may not run into the following opener; the last body is
+    # bounded by the end of the text.
+    limits = [nxt.start() for nxt in openers[1:]] + [len(content)]
+    for opener, limit in zip(openers, limits):
+        begin = opener.end()
+        closer = _TC_END_TAG_RE.search(content[begin:])
+        stop = begin + closer.start() if closer else len(content)
+        stop = min(stop, limit)
+        body = _TC_FUNC_CLOSE_RE.sub("", content[begin:stop])
+
+        params = list(_TC_PARAM_START_RE.finditer(body))
+        # Each value extends to the next parameter opener, or to the end
+        # of the body for the last (or only) parameter.
+        value_ends = [nxt.start() for nxt in params[1:]] + [len(body)]
+        args: dict = {
+            param.group(1): _TC_PARAM_CLOSE_RE.sub(
+                "", body[param.end() : value_end]
+            ).strip()
+            for param, value_end in zip(params, value_ends)
+        }
+
+        calls.append(
+            {
+                "id": f"call_{id_offset + len(calls)}",
+                "type": "function",
+                "function": {
+                    "name": opener.group(1),
+                    "arguments": json.dumps(args),
+                },
+            }
+        )
+    return calls
+
+
+def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]:
+ """Llama-3 emission shapes:
+ <|python_tag|>NAME.call(arg="v", ...) (built-in tools)
+ <|python_tag|>{"name":"NAME", "parameters":{...}} (custom tools)
+ <|python_tag|>{"name":...}; {"name":...} (multi-call, ``; `` sep)
+ Accepts both ``parameters`` and ``arguments`` keys per Llama 3.1/3.2.
+ """
+ out: list[dict] = []
+ if _LLAMA3_PYTHON_TAG not in content:
+ return out
+
+ # 1. NAME.call(...) built-in form.
+ for m in _LLAMA3_PY_CALL_RE.finditer(content):
+ name = m.group(1)
+ i = m.end()
+ depth = 1
in_string = False
- while i < len(content):
+ esc = False
+ while i < len(content) and depth > 0:
ch = content[i]
if in_string:
- if ch == "\\" and i + 1 < len(content):
- i += 2
- continue
- if ch == '"':
+ if esc:
+ esc = False
+ elif ch == "\\":
+ esc = True
+ elif ch == '"':
in_string = False
+ else:
+ if ch == '"':
+ in_string = True
+ elif ch == "(":
+ depth += 1
+ elif ch == ")":
+ depth -= 1
+ if depth == 0:
+ break
+ i += 1
+ body = content[m.end() : i]
+ args: dict[str, Any] = {}
+ for kv in _LLAMA3_KV_RE.finditer(body):
+ k = kv.group(1)
+ if kv.group(2) is not None:
+ try:
+ args[k] = bytes(kv.group(2), "utf-8").decode("unicode_escape")
+ except (UnicodeDecodeError, ValueError):
+ args[k] = kv.group(2)
+ elif kv.group(3) is not None:
+ v = kv.group(3)
+ args[k] = float(v) if "." in v else int(v)
+ elif kv.group(4) is not None:
+ args[k] = {"true": True, "false": False, "null": None}[kv.group(4)]
+ out.append(
+ {
+ "id": f"call_{id_offset + len(out)}",
+ "type": "function",
+ "function": {"name": name, "arguments": json.dumps(args)},
+ }
+ )
+
+ # 2. <|python_tag|>{"name":..., "parameters":...} JSON form. Use a
+ # streaming JSON decoder (raw_decode) so we can peel multiple
+ # objects out of the same emission (separated by ``; `` per
+ # Llama 3 template).
+ if not out:
+ decoder = json.JSONDecoder()
+ idx = content.find(_LLAMA3_PYTHON_TAG)
+ while idx >= 0:
+ search_from = idx + len(_LLAMA3_PYTHON_TAG)
+ # Scan all `{` from this trigger; raw_decode jumps the
+ # cursor past each parsed object, but if a `{` falls
+ # inside an already-decoded object we skip it.
+ cursor = search_from
+ while cursor < len(content):
+ brace = content.find("{", cursor)
+ if brace < 0:
+ break
+ # Stop if we've hit the next <|python_tag|>.
+ next_tag = content.find(_LLAMA3_PYTHON_TAG, search_from, brace)
+ if next_tag >= 0:
+ break
+ try:
+ obj, end_offset = decoder.raw_decode(content[brace:])
+ except (json.JSONDecodeError, ValueError):
+ cursor = brace + 1
+ continue
+ if not isinstance(obj, dict):
+ cursor = brace + end_offset
+ continue
+ name = obj.get("name") or obj.get("function") or ""
+ args = (
+ obj.get("parameters")
+ if "parameters" in obj
+ else obj.get("arguments", {})
+ )
+ if isinstance(args, dict):
+ args_str = json.dumps(args)
+ elif isinstance(args, str):
+ args_str = args
+ else:
+ args_str = json.dumps({"value": args})
+ if name:
+ out.append(
+ {
+ "id": f"call_{id_offset + len(out)}",
+ "type": "function",
+ "function": {"name": name, "arguments": args_str},
+ }
+ )
+ cursor = brace + end_offset
+ idx = content.find(_LLAMA3_PYTHON_TAG, cursor)
+ return out
+
+
+def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]:
+    """Parse Llama-3.2 ``custom_tools`` bare-JSON calls (no ``<|python_tag|>``).
+
+    Deliberately strict so ordinary prose and tool-message echoes do not
+    register as calls:
+
+    * After dropping leading whitespace and any run of leading Llama-3
+      sentinel tokens (in any order, repeated or not), the text must
+      start with ``{``.
+    * Each object needs a non-empty string ``name`` (``function`` is
+      accepted as a fallback key) plus a ``parameters`` or ``arguments``
+      entry holding a dict or an already-encoded string.
+    * Objects are peeled one at a time with ``raw_decode``; whitespace
+      and ``;`` between them are skipped. Parsing stops at the first
+      item that does not fit, keeping the calls collected so far.
+
+    Args:
+        content: Raw model text.
+        id_offset: Added to each call's index to build its ``call_<n>`` id.
+
+    Returns:
+        OpenAI-style tool-call dicts; empty when nothing qualifies.
+    """
+    out: list[dict] = []
+    sentinels = (
+        "<|begin_of_text|>",
+        "<|eot_id|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|eom_id|>",
+    )
+    stripped = content.lstrip()
+    # Keep peeling until a full pass removes nothing. A single pass in
+    # tuple order would miss e.g. ``<|eom_id|><|eot_id|>{...}`` because
+    # ``<|eot_id|>`` is tested before ``<|eom_id|>`` has been removed.
+    peeled = True
+    while peeled:
+        peeled = False
+        for sentinel in sentinels:
+            if stripped.startswith(sentinel):
+                stripped = stripped[len(sentinel) :].lstrip()
+                peeled = True
+    if not stripped.startswith("{"):
+        return out
+
+    decoder = json.JSONDecoder()
+    cursor = 0
+    n = len(stripped)
+    while cursor < n:
+        # Skip whitespace and the Llama 3 inter-call separator ``;``.
+        while cursor < n and stripped[cursor] in " \t\n\r;":
+            cursor += 1
+        if cursor >= n or stripped[cursor] != "{":
+            break
+        try:
+            obj, end_offset = decoder.raw_decode(stripped[cursor:])
+        except (json.JSONDecodeError, ValueError):
+            break
+        if not isinstance(obj, dict):
+            break
+        name = obj.get("name") or obj.get("function") or ""
+        if not isinstance(name, str) or not name:
+            break
+        if "parameters" in obj:
+            args = obj.get("parameters")
+        elif "arguments" in obj:
+            args = obj.get("arguments")
+        else:
+            break
+        if isinstance(args, dict):
+            args_str = json.dumps(args)
+        elif isinstance(args, str):
+            args_str = args
+        else:
+            break
+        out.append(
+            {
+                "id": f"call_{id_offset + len(out)}",
+                "type": "function",
+                "function": {"name": name, "arguments": args_str},
+            }
+        )
+        # ``end_offset`` is relative to the slice handed to raw_decode.
+        cursor += end_offset
+    return out
+
+
+def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]:
+    """Parse Mistral ``[TOOL_CALLS]`` emissions into OpenAI-style tool calls.
+
+    Emissions covered:
+        Pre-v11 array:   ``[TOOL_CALLS] [{"name":..., "arguments":...}, ...]``
+        Pre-v11 single:  ``[TOOL_CALLS]{"name":..., "arguments":...}``
+        v11+ single:     ``[TOOL_CALLS]name{json_args}``
+        v11+ parallel:   ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}``
+        v11+ w/ [ARGS]:  ``[TOOL_CALLS]name[ARGS]{json_args}`` (Ministral / Large 3)
+
+    The first non-whitespace character after the FIRST trigger picks the
+    branch: ``[`` hands off to ``_parse_mistral_array``; ``{`` is tried
+    as a single ``{"name": ...}`` object; anything else (or a ``{``
+    object without a ``name``) goes through the v11+ walk over every
+    trigger in ``content``.
+
+    Returns an empty list when there is no trigger or nothing follows it.
+    """
+    out: list[dict] = []
+    idx = content.find(_MISTRAL_TRIGGER)
+    if idx < 0:
+        return out
+
+    # Decide whether the FIRST occurrence is array / single-object
+    # (pre-v11) or v11+ bare-name. Skip whitespace, peek at next char.
+    j = idx + len(_MISTRAL_TRIGGER)
+    k = j
+    while k < len(content) and content[k] in " \t\n\r":
+        k += 1
+    if k >= len(content):
+        return out
+
+    if content[k] == "[":
+        return _parse_mistral_array(content, k, id_offset)
+
+    if content[k] == "{":
+        # Could be pre-v11 single object ``{"name": ...}`` or a JSON
+        # blob immediately following the trigger (rare). Try parsing
+        # as an object that exposes ``name``; if not, fall through to
+        # v11+ handling so we don't drop emission silently.
+        end = _balanced_brace_end(content, k)
+        if end is not None:
+            try:
+                obj = json.loads(content[k : end + 1])
+                if isinstance(obj, dict) and obj.get("name"):
+                    _consume_mistral_call(content[k : end + 1], out, id_offset)
+                    return out
+            except (json.JSONDecodeError, ValueError):
+                pass
+
+    # v11+ path: walk every ``[TOOL_CALLS]`` and parse ``name{json}``
+    # or ``name[ARGS]{json}`` after each trigger.
+    pos = idx
+    while pos >= 0:
+        cur = pos + len(_MISTRAL_TRIGGER)
+        nm = _MISTRAL_V11_NAME_RE.match(content, cur)
+        if not nm:
+            # No bare name after this trigger; try the next one.
+            pos = content.find(_MISTRAL_TRIGGER, cur)
+            continue
+        name = nm.group(1)
+        after_name = nm.end()
+        # Optional ``[ARGS]`` marker.
+        if content.startswith(_MISTRAL_ARGS_MARKER, after_name):
+            after_name += len(_MISTRAL_ARGS_MARKER)
+        while after_name < len(content) and content[after_name] in " \t\n\r":
+            after_name += 1
+        if after_name >= len(content) or content[after_name] != "{":
+            # A name with no ``{`` argument object is not a call.
+            pos = content.find(_MISTRAL_TRIGGER, cur)
+            continue
+        end = _balanced_brace_end(content, after_name)
+        if end is None:
+            # Braces never balance before the end of the text: stop and
+            # return whatever was collected so far.
+            break
+        try:
+            args = json.loads(content[after_name : end + 1])
+        except (json.JSONDecodeError, ValueError):
+            pos = content.find(_MISTRAL_TRIGGER, end + 1)
+            continue
+        if not isinstance(args, dict):
+            pos = content.find(_MISTRAL_TRIGGER, end + 1)
+            continue
+        out.append(
+            {
+                "id": f"call_{id_offset + len(out)}",
+                "type": "function",
+                "function": {
+                    "name": name,
+                    "arguments": json.dumps(args),
+                },
+            }
+        )
+        pos = content.find(_MISTRAL_TRIGGER, end + 1)
+    return out
+
+
+def _parse_mistral_array(content: str, start: int, id_offset: int) -> list[dict]:
+ """Parse pre-v11 ``[TOOL_CALLS] [{...}, ...]`` JSON array form."""
+ out: list[dict] = []
+ j = start
+ depth = 0
+ in_string = False
+ esc = False
+ while j < len(content):
+ ch = content[j]
+ if in_string:
+ if esc:
+ esc = False
+ elif ch == "\\":
+ esc = True
elif ch == '"':
+ in_string = False
+ else:
+ if ch == '"':
in_string = True
- elif ch == "{":
+ elif ch == "[":
depth += 1
- elif ch == "}":
+ elif ch == "]":
depth -= 1
if depth == 0:
break
- i += 1
- if depth == 0:
- json_str = content[brace_start : i + 1]
- try:
- obj = json.loads(json_str)
- tc = {
- "id": f"call_{id_offset + len(tool_calls)}",
- "type": "function",
- "function": {
- "name": obj.get("name", ""),
- "arguments": obj.get("arguments", {}),
- },
- }
- if isinstance(tc["function"]["arguments"], dict):
- tc["function"]["arguments"] = json.dumps(
- tc["function"]["arguments"]
- )
- tool_calls.append(tc)
- except (json.JSONDecodeError, ValueError):
- pass
+ j += 1
+ body = content[start : j + 1] if depth == 0 else content[start:]
- # Pattern 2: v... -- closing tags
- # optional; don't use as body boundary because code
- # values can contain that literal.
- if not tool_calls:
- func_starts = list(_TC_FUNC_START_RE.finditer(content))
- for idx, fm in enumerate(func_starts):
- func_name = fm.group(1)
- body_start = fm.end()
- next_func = (
- func_starts[idx + 1].start()
- if idx + 1 < len(func_starts)
- else len(content)
- )
- end_tag = _TC_END_TAG_RE.search(content[body_start:])
- if end_tag:
- body_end = body_start + end_tag.start()
- else:
- body_end = len(content)
- body_end = min(body_end, next_func)
- body = content[body_start:body_end]
- body = _TC_FUNC_CLOSE_RE.sub("", body)
-
- arguments: dict = {}
- param_starts = list(_TC_PARAM_START_RE.finditer(body))
- if len(param_starts) == 1:
- # Single param: take everything to body end so
- # embedded in code strings is preserved.
- pm = param_starts[0]
- val = body[pm.end() :]
- val = _TC_PARAM_CLOSE_RE.sub("", val)
- arguments[pm.group(1)] = val.strip()
- else:
- for pidx, pm in enumerate(param_starts):
- param_name = pm.group(1)
- val_start = pm.end()
- next_param = (
- param_starts[pidx + 1].start()
- if pidx + 1 < len(param_starts)
- else len(body)
- )
- val = body[val_start:next_param]
- val = _TC_PARAM_CLOSE_RE.sub("", val)
- arguments[param_name] = val.strip()
+ try:
+ arr = json.loads(body)
+ if isinstance(arr, list):
+ for obj in arr:
+ if isinstance(obj, dict):
+ _consume_mistral_call(json.dumps(obj), out, id_offset)
+ return out
+ except (json.JSONDecodeError, ValueError):
+ pass
- tc = {
- "id": f"call_{id_offset + len(tool_calls)}",
+ # Healing path: walk objects manually for unclosed array.
+ for m in re.finditer(r"\{", body):
+ end = _balanced_brace_end(body, m.start())
+ if end is None:
+ continue
+ _consume_mistral_call(body[m.start() : end + 1], out, id_offset)
+ return out
+
+
+def _consume_mistral_call(obj_text: str, out: list[dict], id_offset: int) -> None:
+ try:
+ obj = json.loads(obj_text)
+ except (json.JSONDecodeError, ValueError):
+ return
+ if not isinstance(obj, dict):
+ return
+ name = obj.get("name") or ""
+ args = obj.get("arguments") or {}
+ if isinstance(args, dict):
+ args_str = json.dumps(args)
+ elif isinstance(args, str):
+ args_str = args
+ else:
+ args_str = json.dumps({"value": args})
+ if name:
+ out.append(
+ {
+ "id": obj.get("id") or f"call_{id_offset + len(out)}",
"type": "function",
- "function": {
- "name": func_name,
- "arguments": json.dumps(arguments),
- },
+ "function": {"name": name, "arguments": args_str},
}
- tool_calls.append(tc)
+ )
- return tool_calls
+def _parse_gemma_tool_calls(content: str, *, id_offset: int) -> list[dict]:
+    """Collect Gemma 4 ``<|tool_call>call:NAME{...}`` emissions as tool calls.
+
+    The argument mapping is scanned no further than the next
+    ``_GEMMA_TC_END`` marker, or the end of ``content`` when the marker
+    is absent. A call whose closing bracket is not found inside that
+    window is dropped; a call whose mapping fails to parse keeps its
+    name with empty arguments.
+    """
+    calls: list[dict] = []
+    for match in _GEMMA_TC_RE.finditer(content):
+        # The regex ends on the opening ``{``, so step back one char.
+        open_brace = match.end() - 1
+        marker_at = content.find(_GEMMA_TC_END, open_brace)
+        limit = len(content) if marker_at < 0 else marker_at
+        close_brace = _gemma_balanced_brace_end(content, open_brace, limit)
+        if close_brace is None:
+            continue
+        try:
+            parsed = _gemma_parse_mapping_body(content[open_brace + 1 : close_brace])
+        except Exception:
+            # Best effort: still surface the call, just without arguments.
+            parsed = {}
+        calls.append(
+            {
+                "id": f"call_{id_offset + len(calls)}",
+                "type": "function",
+                "function": {
+                    "name": match.group(1),
+                    "arguments": json.dumps(parsed),
+                },
+            }
+        )
+    return calls
-def has_tool_signal(text: str) -> bool:
- """Return True if ``text`` contains any tool-call XML signal."""
- return any(s in text for s in TOOL_XML_SIGNALS)
+
+# ── Brace-balancing helpers ─────────────────────────────────────────
+
+
+def _balanced_brace_end(text: str, brace_pos: int) -> int | None:
+    """Locate the ``}`` that closes the ``{`` found at ``brace_pos``.
+
+    Braces inside double-quoted JSON strings are ignored, and a
+    backslash inside a string escapes the character after it.
+
+    Returns the index of the matching ``}``, or ``None`` when
+    ``text[brace_pos]`` is not ``{`` or the brace is never closed.
+    """
+    if brace_pos >= len(text) or text[brace_pos] != "{":
+        return None
+    nesting = 0
+    inside_str = False
+    escaped = False
+    for pos in range(brace_pos, len(text)):
+        char = text[pos]
+        if inside_str:
+            if escaped:
+                escaped = False
+            elif char == "\\":
+                escaped = True
+            elif char == '"':
+                inside_str = False
+            continue
+        if char == '"':
+            inside_str = True
+        elif char == "{":
+            nesting += 1
+        elif char == "}":
+            nesting -= 1
+            if nesting == 0:
+                return pos
+    return None
+
+
+def _gemma_balanced_brace_end(text: str, brace_pos: int, hard_stop: int) -> int | None:
+    """Find the bracket that closes the ``{`` at ``brace_pos`` in Gemma syntax.
+
+    ``<|"|>``-delimited string runs are skipped whole, so brackets inside
+    them are not counted. ``{`` and ``[`` both open a level and ``}`` and
+    ``]`` both close one, sharing a single depth counter. No bracket at
+    or beyond ``hard_stop`` is examined.
+
+    Returns the closing index, or ``None`` when ``text[brace_pos]`` is
+    not ``{``, a string run is unterminated, or the depth never returns
+    to zero before ``hard_stop``.
+    """
+    if brace_pos >= len(text) or text[brace_pos] != "{":
+        return None
+    opener_len = len(_GEMMA_STR_BEGIN)
+    closer_len = len(_GEMMA_STR_END)
+    level = 0
+    pos = brace_pos
+    while pos < hard_stop:
+        if text.startswith(_GEMMA_STR_BEGIN, pos):
+            str_close = text.find(_GEMMA_STR_END, pos + opener_len)
+            if str_close < 0:
+                return None
+            # Jump straight past the closing string delimiter.
+            pos = str_close + closer_len
+            continue
+        symbol = text[pos]
+        if symbol in "{[":
+            level += 1
+        elif symbol in "}]":
+            level -= 1
+            if level == 0:
+                return pos
+        pos += 1
+    return None
+
+
+def _gemma_parse_value(text: str, i: int) -> tuple[Any, int]:
+    """Parse one Gemma argument value starting at ``text[i]``.
+
+    Shapes are tried in this order:
+
+    * ``<|"|>...<|"|>`` string -- returned verbatim; an unterminated
+      string swallows the rest of ``text``.
+    * ``{...}`` nested mapping -- parsed by ``_gemma_parse_mapping_body``;
+      an unbalanced one yields ``{}`` and consumes the rest of ``text``.
+    * ``[...]`` list -- every item is parsed recursively.
+    * primitive -- runs to the next ``,`` / ``}`` / ``]`` / string
+      opener and is read as ``true`` / ``false`` / ``null``, then int,
+      then float, otherwise kept as the trimmed raw text.
+
+    Args:
+        text: Argument text; ``i`` must be a valid index into it.
+        i: Index of the first character of the value.
+
+    Returns:
+        ``(value, next_index)``. ``next_index`` equals ``i`` only when
+        ``text[i]`` is itself a ``,`` / ``}`` / ``]`` terminator.
+    """
+    if text.startswith(_GEMMA_STR_BEGIN, i):
+        close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN))
+        if close < 0:
+            return text[i + len(_GEMMA_STR_BEGIN) :], len(text)
+        return text[i + len(_GEMMA_STR_BEGIN) : close], close + len(_GEMMA_STR_END)
+    if text[i] == "{":
+        end = _gemma_balanced_brace_end(text, i, len(text))
+        if end is None:
+            return {}, len(text)
+        return _gemma_parse_mapping_body(text[i + 1 : end]), end + 1
+    if text[i] == "[":
+        # Find the matching ``]``, skipping string runs so brackets
+        # inside strings are not counted.
+        j, depth = i, 0
+        while j < len(text):
+            if text.startswith(_GEMMA_STR_BEGIN, j):
+                k = text.find(_GEMMA_STR_END, j + len(_GEMMA_STR_BEGIN))
+                if k < 0:
+                    j = len(text)
+                    break
+                j = k + len(_GEMMA_STR_END)
+                continue
+            ch = text[j]
+            if ch == "[":
+                depth += 1
+            elif ch == "]":
+                depth -= 1
+                if depth == 0:
+                    break
+            j += 1
+        body = text[i + 1 : j]
+        items: list[Any] = []
+        k = 0
+        while k < len(body):
+            if body[k] in " \t\n\r,":
+                k += 1
+                continue
+            v, nxt = _gemma_parse_value(body, k)
+            if nxt <= k:
+                # A stray ``}`` / ``]`` yields a zero-width primitive.
+                # Step over it; re-parsing the same index would loop
+                # forever and grow ``items`` without bound.
+                k += 1
+                continue
+            items.append(v)
+            k = nxt
+        return items, j + 1
+    # Primitive: number, true/false/null, or bare identifier (rare).
+    end = i
+    while (
+        end < len(text)
+        and text[end] not in ",}]"
+        and not text.startswith(_GEMMA_STR_BEGIN, end)
+    ):
+        end += 1
+    raw = text[i:end].strip()
+    if raw == "true":
+        return True, end
+    if raw == "false":
+        return False, end
+    if raw == "null":
+        return None, end
+    try:
+        return int(raw), end
+    except ValueError:
+        pass
+    try:
+        return float(raw), end
+    except ValueError:
+        pass
+    return raw, end
+
+
+def _gemma_parse_mapping_body(body: str) -> dict[str, Any]:
+    """Turn the text between a Gemma ``{`` and ``}`` into a dict.
+
+    Keys are either ``<|"|>``-wrapped strings or bare text running up to
+    the next ``:`` (trimmed). Values are read by ``_gemma_parse_value``.
+    A key with nothing after it maps to ``None``; an unterminated
+    wrapped key ends parsing with the entries gathered so far.
+    """
+    blanks = " \t\n\r"
+    size = len(body)
+
+    def _skip(pos: int, chars: str) -> int:
+        # Advance past any run of the given characters.
+        while pos < size and body[pos] in chars:
+            pos += 1
+        return pos
+
+    result: dict[str, Any] = {}
+    pos = 0
+    while pos < size:
+        pos = _skip(pos, blanks + ",")
+        if pos >= size:
+            break
+        if body.startswith(_GEMMA_STR_BEGIN, pos):
+            key_from = pos + len(_GEMMA_STR_BEGIN)
+            key_to = body.find(_GEMMA_STR_END, key_from)
+            if key_to < 0:
+                break
+            key = body[key_from:key_to]
+            pos = key_to + len(_GEMMA_STR_END)
+        else:
+            colon = body.find(":", pos)
+            key_to = size if colon < 0 else colon
+            key = body[pos:key_to].strip()
+            pos = key_to
+        pos = _skip(pos, blanks)
+        if pos < size and body[pos] == ":":
+            pos += 1
+        pos = _skip(pos, blanks)
+        if pos >= size:
+            result[key] = None
+            break
+        value, pos = _gemma_parse_value(body, pos)
+        result[key] = value
+    return result
diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py
index 1b4e7051b0..39f2004fa5 100644
--- a/studio/backend/routes/inference.py
+++ b/studio/backend/routes/inference.py
@@ -256,16 +256,29 @@ def _detect_safetensors_features(backend, chat_template: Optional[str]) -> dict:
"supports_tools": False,
}
)
- # Our safetensors loop only parses {json}
- # and .... Llama uses <|python_tag|>,
- # Mistral uses [TOOL_CALLS]; advertising tools for those would
- # enable a pill the parser cannot honour. GGUF is unaffected --
- # llama-server normalises every format into structured deltas.
+ # The safetensors / MLX loop parses these emission formats:
+ # Qwen ``{json}``, Qwen3.5 ``...``,
+ # Llama-3 ``<|python_tag|>``, Llama-3.2 bare JSON ``{"name":...,
+ # "parameters":...}``, Mistral ``[TOOL_CALLS]`` (pre-v11 array +
+ # v11+ ``name{json}``), and Gemma 4 ``<|tool_call>...``. If the
+ # template advertises tools but does NOT use any of these markers,
+ # the parser cannot honour the emission - drop the pill. ``{"name":``
+ # catches Llama-3.2's ``custom_tools`` shape whose template instructs
+ # the model to "Respond in the format {\"name\": ..., \"parameters\":
+ # ...}" without a ``<|python_tag|>`` prefix.
+ _PARSER_MARKERS = (
+ "",
+ "",
+ "[TOOL_CALLS]",
+ "<|tool_call>",
+ '{"name":',
+ '{\\"name\\":',
+ )
if (
flags.get("supports_tools")
and chat_template
- and "" not in chat_template
- and " None:
" Do NOT output code blocks -- use the python tool instead."
)
-# Regex for stripping leaked tool-call XML from assistant messages/stream
+# Regex for stripping leaked tool-call markup from assistant messages /
+# stream. Covers every emission format the shared parser handles
+# (Qwen / Hermes ````, Qwen3.5 ````, Llama-3
+# ``<|python_tag|>``, Mistral ``[TOOL_CALLS]`` pre-v11 array and v11+
+# ``name{json}``, Gemma 4 ``<|tool_call>...``). Closed
+# pairs only so in-progress markup stays buffered upstream.
_TOOL_XML_RE = _re.compile(
- r".*?|.*?",
+ "|".join(
+ [
+ r".*?",
+ r".*?",
+ r"<\|tool_call>.*?",
+ r"\[TOOL_CALLS\]\s*\[.*?\](?:\s*)?",
+ r"\[TOOL_CALLS\]\s*[\w\.\-]+\s*(?:\[ARGS\])?\s*\{.*?\}",
+ r"<\|python_tag\|>[^\n<]*",
+ ]
+ ),
_re.DOTALL,
)
logger = get_logger(__name__)
diff --git a/studio/backend/tests/test_safetensors_capability_advertise.py b/studio/backend/tests/test_safetensors_capability_advertise.py
index c3ee5b9ff1..b63e835d2a 100644
--- a/studio/backend/tests/test_safetensors_capability_advertise.py
+++ b/studio/backend/tests/test_safetensors_capability_advertise.py
@@ -129,11 +129,11 @@ def test_detect_safetensors_features_gptoss_disables_tools():
assert flags["supports_tools"] is False
-# Llama-3 / Mistral templates advertise tool handling but the model emits
-# tool calls in <|python_tag|> / [TOOL_CALLS] format -- not the
-# / , [TOOL_CALLS], and
+# <|tool_call>). The route helper must surface supports_tools=True for
+# all of them so the UI enables the pill. Only templates whose tool
+# format is NONE of the five known markers should be suppressed.
LLAMA3_TEMPLATE = """
{%- if tools %}
@@ -165,27 +165,88 @@ def test_detect_safetensors_features_gptoss_disables_tools():
{%- endfor %}
"""
+GEMMA4_TEMPLATE = """
+{%- if tools %}
+ {{- 'Tools available. Emit calls as ' }}
+ {{- '<|tool_call>call:NAME{key:<|"|>val<|"|>}' }}
+ {%- for tool in tools %}
+ {{- tool | tojson }}
+ {%- endfor %}
+{%- endif %}
+"""
+
-def test_detect_safetensors_features_llama3_template_suppresses_tools():
- """Llama-3 emits <|python_tag|>; safetensors loop cannot parse it."""
+def test_detect_safetensors_features_llama3_template_keeps_tools_on():
+ """Llama-3 emits <|python_tag|>; parser now supports it."""
from routes.inference import _detect_safetensors_features
backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct")
flags = _detect_safetensors_features(backend, LLAMA3_TEMPLATE)
- assert flags["supports_tools"] is False
+ assert flags["supports_tools"] is True
-def test_detect_safetensors_features_mistral_template_suppresses_tools():
- """Mistral emits [TOOL_CALLS]; safetensors loop cannot parse it."""
+def test_detect_safetensors_features_mistral_template_keeps_tools_on():
+ """Mistral emits [TOOL_CALLS]; parser now supports it."""
from routes.inference import _detect_safetensors_features
backend = SimpleNamespace(active_model_name = "unsloth/mistral-7b-instruct-v0.3")
flags = _detect_safetensors_features(backend, MISTRAL_TEMPLATE)
+ assert flags["supports_tools"] is True
+
+
+def test_detect_safetensors_features_gemma4_template_keeps_tools_on():
+    """A template containing the Gemma 4 ``<|tool_call>`` marker keeps
+    ``supports_tools`` enabled."""
+    from routes.inference import _detect_safetensors_features
+
+    # The helper only needs an object exposing ``active_model_name``.
+    backend = SimpleNamespace(active_model_name = "unsloth/gemma-4-E2B-it-UD-MLX-4bit")
+    flags = _detect_safetensors_features(backend, GEMMA4_TEMPLATE)
+    assert flags["supports_tools"] is True
+
+
+LLAMA3_2_BARE_JSON_TEMPLATE = """
+{%- if tools %}
+ {{- 'Given the following functions, respond with JSON for a function call.' }}
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary}.' }}
+ {%- for tool in tools %}
+ {{- tool | tojson }}
+ {%- endfor %}
+{%- endif %}
+{%- for message in messages %}
+ {%- if 'tool_calls' in message %}
+ {{- '{"name": "' + message.tool_calls[0].function.name + '", '}}
+ {{- '"parameters": ' + (message.tool_calls[0].function.arguments | tojson) + '}' }}
+ {%- endif %}
+{%- endfor %}
+"""
+
+
+def test_detect_safetensors_features_llama3_2_bare_json_keeps_tools_on():
+    """Llama-3.2 emits bare JSON ``{"name":..., "parameters":...}`` -- the
+    parser now handles that path, so the pill must stay enabled.
+
+    ``LLAMA3_2_BARE_JSON_TEMPLATE`` contains no tag-style marker, only
+    the literal ``{"name":`` text in its instructions.
+    """
+    from routes.inference import _detect_safetensors_features
+
+    backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct")
+    flags = _detect_safetensors_features(backend, LLAMA3_2_BARE_JSON_TEMPLATE)
+    assert flags["supports_tools"] is True
+
+
+def test_detect_safetensors_features_unknown_format_suppresses_tools():
+ """A template that advertises tools but uses no known marker must
+ be suppressed so the UI does not enable an unsupported pill."""
+ from routes.inference import _detect_safetensors_features
+
+ tpl = (
+ "{%- if tools %}<|im_start|>system\n"
+ "Emit tool calls as JSON-RPC notifications inside the response."
+ "<|im_end|>{%- endif %}"
+ )
+ backend = SimpleNamespace(active_model_name = "custom/unknown-tool-format")
+ flags = _detect_safetensors_features(backend, tpl)
assert flags["supports_tools"] is False
def test_detect_safetensors_features_qwen_tool_call_keeps_tools_on():
- """Sanity check: gate only suppresses non-Qwen formats."""
+ """Sanity check: Qwen marker still flips supports_tools."""
from routes.inference import _detect_safetensors_features
backend = SimpleNamespace(active_model_name = "unsloth/Qwen3-0.6B")
diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py
index 923af87c4f..c838cab72d 100644
--- a/studio/backend/tests/test_safetensors_tool_loop.py
+++ b/studio/backend/tests/test_safetensors_tool_loop.py
@@ -130,6 +130,292 @@ def test_strip_markup_unclosed_final(self):
assert "partial" in strip_tool_markup(text)
+class TestParserMultiFormat:
+ """Parser coverage for Llama-3 / Mistral / Gemma 4 emission formats.
+
+ Each model family upstream of GGUF emits a different tool-call
+ shape. The shared parser must turn all of them into the same
+ OpenAI ``{name, arguments}`` shape so the safetensors / MLX
+ agentic loop is family-agnostic.
+ """
+
+ # ── Llama-3 ────────────────────────────────────────────────────
+
+ def test_llama3_python_tag_dot_call(self):
+ # Llama-3 built-in tools: <|python_tag|>NAME.call(k="v", ...).
+ import json
+
+ text = '<|python_tag|>brave_search.call(query="weather in Tokyo")'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "brave_search"
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"query": "weather in Tokyo"}
+
+ def test_llama3_python_tag_dot_call_multi_arg(self):
+ import json
+
+ text = (
+ "<|python_tag|>get_weather.call("
+ 'location="Tokyo", units="celsius", days=5)'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"location": "Tokyo", "units": "celsius", "days": 5}
+
+ def test_llama3_python_tag_json_form(self):
+ import json
+
+ text = (
+ '<|python_tag|>{"name":"web_search",' '"parameters":{"query":"hi","n":5}}'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "web_search"
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"query": "hi", "n": 5}
+
+ def test_llama3_python_tag_json_form_with_eom(self):
+ # Llama-3 emits ``<|eom_id|>`` after the JSON; must not break parsing.
+ import json
+
+ text = (
+ '<|python_tag|>{"name":"python",'
+ '"parameters":{"code":"print(2+2)"}}<|eom_id|>'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"code": "print(2+2)"}
+
+ def test_llama3_strip_markup_final(self):
+ text = '<|python_tag|>brave_search.call(query="x")'
+ assert strip_tool_markup(text, final = True) == ""
+
+ # ── Llama-3.2 bare JSON ``custom_tools`` ─────────────────────
+
+ def test_llama3_2_bare_json_parameters(self):
+ # Llama-3.2-Instruct emits bare JSON directly as content; no
+ # <|python_tag|> prefix per its training template.
+ import json
+
+ text = '{"name":"web_search","parameters":{"query":"Tokyo weather"}}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "web_search"
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"query": "Tokyo weather"}
+
+ def test_llama3_2_bare_json_arguments_key(self):
+ import json
+
+ text = '{"name":"add","arguments":{"a":1,"b":2}}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"a": 1, "b": 2}
+
+ def test_llama3_2_bare_json_multi_call(self):
+ # Llama-3 may chain calls with ``; `` per training template.
+ text = '{"name":"a","parameters":{}}; ' '{"name":"b","parameters":{}}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 2
+ assert result[0]["function"]["name"] == "a"
+ assert result[1]["function"]["name"] == "b"
+
+ def test_llama3_2_bare_json_with_eom_sentinel(self):
+ text = '{"name":"x","parameters":{"y":1}}<|eom_id|>'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "x"
+
+ def test_llama3_2_bare_json_leading_sentinel_skipped(self):
+ # Sometimes prior <|eot_id|> leaks into the next turn.
+ text = '<|eot_id|>{"name":"x","parameters":{}}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "x"
+
+ def test_llama3_2_bare_json_plain_prose_does_not_fire(self):
+ # Defensive: must NOT fire on plain assistant prose.
+ text = "Hello world, how are you today?"
+ assert parse_tool_calls_from_text(text) == []
+
+    def test_llama3_2_bare_json_embedded_in_prose_does_not_fire(self):
+        # Defensive: JSON embedded in prose must NOT fire (content must
+        # START with `{`, ignoring leading sentinels such as <|eot_id|>).
+        text = 'The tool result was: {"name":"foo"}'
+        assert parse_tool_calls_from_text(text) == []
+
+ def test_llama3_2_bare_json_missing_name_does_not_fire(self):
+ text = '{"result":"ok","data":[1,2,3]}'
+ assert parse_tool_calls_from_text(text) == []
+
+ def test_llama3_2_bare_json_missing_args_does_not_fire(self):
+ text = '{"name":"x"}'
+ assert parse_tool_calls_from_text(text) == []
+
+ def test_llama3_2_bare_json_args_not_dict_does_not_fire(self):
+ text = '{"name":"x","parameters":42}'
+ assert parse_tool_calls_from_text(text) == []
+
+ # ── Mistral pre-v11 ───────────────────────────────────────────
+
+ def test_mistral_pre_v11_array(self):
+ import json
+
+ text = (
+ '[TOOL_CALLS] [{"name":"web_search",'
+ '"arguments":{"query":"hello"},"id":"abc"}]'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "web_search"
+ # Mistral provides its own id; preserve it.
+ assert result[0]["id"] == "abc"
+ assert json.loads(result[0]["function"]["arguments"]) == {"query": "hello"}
+
+ def test_mistral_pre_v11_array_multi(self):
+ text = (
+ '[TOOL_CALLS] [{"name":"a","arguments":{"x":1},"id":"id1"},'
+ '{"name":"b","arguments":{"y":2},"id":"id2"}]'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 2
+ assert result[0]["function"]["name"] == "a"
+ assert result[1]["function"]["name"] == "b"
+
+ def test_mistral_pre_v11_unclosed_array(self):
+ # Closing ``]`` truncated -- parser must heal off individual objects.
+ text = '[TOOL_CALLS] [{"name":"web_search","arguments":{"q":"x"},"id":"id"}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "web_search"
+
+ # ── Mistral v11+ ───────────────────────────────────────────────
+
+ def test_mistral_v11_single(self):
+ # Magistral / Mistral Small 3.1: bare ``name{json}`` after trigger.
+ import json
+
+ text = '[TOOL_CALLS]add{"a":3.5,"b":4}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "add"
+ assert json.loads(result[0]["function"]["arguments"]) == {"a": 3.5, "b": 4}
+
+ def test_mistral_v11_parallel(self):
+ # v11+ parallel: ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}``.
+ text = '[TOOL_CALLS]add{"a":1}[TOOL_CALLS]sub{"b":2}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 2
+ assert result[0]["function"]["name"] == "add"
+ assert result[1]["function"]["name"] == "sub"
+
+ def test_mistral_v11_with_args_marker(self):
+ # Ministral / Mistral Large 3: ``[TOOL_CALLS]name[ARGS]{json}``.
+ import json
+
+ text = '[TOOL_CALLS]add[ARGS]{"a":1,"b":2}'
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "add"
+ assert json.loads(result[0]["function"]["arguments"]) == {"a": 1, "b": 2}
+
+ def test_mistral_strip_markup_v11(self):
+ text = '[TOOL_CALLS]add{"a":1}'
+ assert strip_tool_markup(text, final = True) == ""
+
+ # ── Gemma 4 ───────────────────────────────────────────────────
+
+ def test_gemma4_simple_call(self):
+ import json
+
+ text = (
+ "<|tool_call>call:get_weather{"
+ 'location:<|"|>Tokyo<|"|>,units:<|"|>celsius<|"|>}'
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "get_weather"
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {"location": "Tokyo", "units": "celsius"}
+
+ def test_gemma4_with_primitives(self):
+ import json
+
+ text = (
+ "<|tool_call>call:set_pref{"
+ "enabled:true,attempts:5,threshold:1.5,nickname:null}"
+ )
+ result = parse_tool_calls_from_text(text)
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args == {
+ "enabled": True,
+ "attempts": 5,
+ "threshold": 1.5,
+ "nickname": None,
+ }
+
+ def test_gemma4_nested_args(self):
+ # Gemma 4 nests dicts / lists with bare keys and ``<|"|>`` strings.
+ import json
+
+ text = (
+ "<|tool_call>call:search{"
+ 'query:<|"|>foo<|"|>,filters:{site:<|"|>example.com<|"|>,recent:true},'
+ 'tags:[<|"|>a<|"|>,<|"|>b<|"|>]}'
+ )
+ result = parse_tool_calls_from_text(text)
+ args = json.loads(result[0]["function"]["arguments"])
+ assert args["query"] == "foo"
+ assert args["filters"] == {"site": "example.com", "recent": True}
+ assert args["tags"] == ["a", "b"]
+
+ def test_gemma4_multi_call(self):
+ text = (
+ "<|tool_call>call:a{x:1}" "<|tool_call>call:b{y:2}"
+ )
+ result = parse_tool_calls_from_text(text)
+ assert len(result) == 2
+ assert result[0]["function"]["name"] == "a"
+ assert result[1]["function"]["name"] == "b"
+
+ def test_gemma4_unclosed_does_not_raise(self):
+ # Truncated mid-stream; must not raise.
+ text = '<|tool_call>call:foo{x:<|"|>bar<|"|>'
+ result = parse_tool_calls_from_text(text)
+ assert isinstance(result, list)
+
+ def test_gemma4_strip_markup_final(self):
+ text = "<|tool_call>call:foo{x:1}"
+ assert strip_tool_markup(text, final = True) == ""
+
+ # ── Cross-format sentinels ────────────────────────────────────
+
+    def test_all_markers_in_tool_xml_signals(self):
+        # Every emission marker must be registered or the stream never gates.
+        from core.inference.tool_call_parser import TOOL_XML_SIGNALS
+
+        for marker in (
+            "<tool_call>",
+            "<function=",
+            "<|python_tag|>",
+            "[TOOL_CALLS]",
+            "<|tool_call>",
+        ):
+            assert (
+                marker in TOOL_XML_SIGNALS
+            ), f"streaming loop would not wake on {marker!r}"
+
+ def test_has_tool_signal_for_all_formats(self):
+ assert has_tool_signal('<|python_tag|>brave_search.call(q="x")')
+ assert has_tool_signal('[TOOL_CALLS] [{"name":"x"}]')
+ assert has_tool_signal('[TOOL_CALLS]add{"a":1}')
+ assert has_tool_signal("<|tool_call>call:foo{}")
+
+
# ────────────────────────────────────────────────────────────────────
# run_safetensors_tool_loop
# ────────────────────────────────────────────────────────────────────
@@ -280,6 +566,71 @@ def test_function_xml_form(self):
contents = [e for e in events if e["type"] == "content"]
assert "Result: 1" in contents[-1]["text"]
+ def test_llama3_python_tag_form(self):
+ # The agentic loop must recognise Llama-3's <|python_tag|>
+ # marker, drain the rest of the turn, and execute the call.
+ loop, exec_fn = _make_loop(
+ turns = [
+ [
+ "<|python_tag|>web_search.call(",
+ 'query="weather in Tokyo"',
+ ")",
+ ],
+ ["The weather is sunny."],
+ ],
+ exec_results = ["Sunny, 22C"],
+ )
+ events = _collect_events(loop)
+ assert exec_fn.calls == [("web_search", {"query": "weather in Tokyo"})]
+ contents = [e for e in events if e["type"] == "content"]
+ assert "sunny" in contents[-1]["text"].lower()
+
+ def test_mistral_pre_v11_form(self):
+ # Pre-v11 Mistral emission: ``[TOOL_CALLS] [{...}]``.
+ loop, exec_fn = _make_loop(
+ turns = [
+ [
+ '[TOOL_CALLS] [{"name":"web_search",',
+ '"arguments":{"query":"hi"},"id":"abc"}]',
+ ],
+ ["done"],
+ ],
+ exec_results = ["ok"],
+ )
+ events = _collect_events(loop)
+ assert exec_fn.calls == [("web_search", {"query": "hi"})]
+ # Mistral-provided ids must propagate to tool_start events.
+ tool_start = next(e for e in events if e["type"] == "tool_start")
+ assert tool_start["tool_call_id"] == "abc"
+
+ def test_mistral_v11_form(self):
+ # v11+ Mistral emission: bare ``name{json}`` after the trigger.
+ loop, exec_fn = _make_loop(
+ turns = [
+ ['[TOOL_CALLS]web_search{"query":"hi"}'],
+ ["done"],
+ ],
+ exec_results = ["ok"],
+ )
+ events = _collect_events(loop)
+ assert exec_fn.calls == [("web_search", {"query": "hi"})]
+
+ def test_gemma4_form(self):
+ # Gemma 4 emission: ``<|tool_call>call:NAME{...}``.
+ loop, exec_fn = _make_loop(
+ turns = [
+ [
+ "<|tool_call>call:web_search{",
+ 'query:<|"|>weather<|"|>',
+ "}",
+ ],
+ ["sunny"],
+ ],
+ exec_results = ["Sunny, 22C"],
+ )
+ events = _collect_events(loop)
+ assert exec_fn.calls == [("web_search", {"query": "weather"})]
+
def test_truncated_unclosed_tool_call(self):
loop, exec_fn = _make_loop(
turns = [