diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 0325620b2d..3f8704d7bc 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -38,7 +38,14 @@ _TOOL_ALL_PATS, _TOOL_CLOSED_PATS, parse_tool_calls_from_text, - strip_tool_call_markup, +) + +# Share strip / signal constants with the multi-format parser so the +# BUFFERING state machine also catches Llama-3 / Mistral / Gemma 4 +# emissions (legacy helper only knew / str: if not auto_heal_tool_calls: return text - return strip_tool_call_markup(text, final = final) + return _shared_strip_tool_markup(text, final = final) - # XML prefixes that signal a tool call in content. - # Empty when auto_heal is disabled so the buffer never - # speculatively holds content for XML detection. - _TOOL_XML_SIGNALS = ( - ("", " str: arguments = json.loads(raw_args) except (json.JSONDecodeError, ValueError): if auto_heal_tool_calls: - arguments = {"query": raw_args} + # Canonical per-tool heal key (must match + # safetensors_agentic._CANONICAL_HEAL_ARG) + # so bare-string emissions still run the + # intended tool. + _heal_key = { + "python": "code", + "terminal": "command", + }.get(tool_name, "query") + arguments = {_heal_key: raw_args} else: arguments = {"raw": raw_args} else: diff --git a/studio/backend/core/inference/safetensors_agentic.py b/studio/backend/core/inference/safetensors_agentic.py index 73bb3d090a..1f442161e9 100644 --- a/studio/backend/core/inference/safetensors_agentic.py +++ b/studio/backend/core/inference/safetensors_agentic.py @@ -18,6 +18,7 @@ """ import json +import re import threading from typing import Callable, Generator, Optional from urllib.parse import urlparse @@ -42,6 +43,24 @@ # Buffer cap while waiting to disambiguate a possible tool-call prefix. _MAX_BUFFER_CHARS = 32 +# Forward-looking intent ("I'll...", "First, ...", "Step 1:") that +# means the model is planning rather than answering. Used to nudge it +# to call a tool. Excludes "I can / I should / I want / let's" because +# those also appear in direct answers and explanations. Mirrors GGUF. +_INTENT_SIGNAL = re.compile( + r"(?i)(" + r"\b(i['’](ll|m going to|m gonna)|i am (going to|gonna)|i will|i shall|let me|allow me)\b" + r"|\b(?:first\b|step \d+:?|here['’]?s (?:my |the |a )?(?:plan|approach))" + r"|\b(?:now i|next i)\b" + r")" +) +_MAX_REPROMPTS = 3 +_REPROMPT_MAX_CHARS = 2000 +_REPROMPT_INSTRUCTION = ( + "STOP. Do NOT write code or explain. You MUST call a tool NOW. " + "Call web_search or python immediately." +) + def _status_for_tool(tool_name: str, arguments: dict) -> str: """Return a human-readable status line matching the GGUF path.""" @@ -142,6 +161,7 @@ def run_safetensors_tool_loop( if (tool.get("function") or {}).get("name") } next_call_id = 0 + reprompt_count = 0 if max_tool_iterations <= 0: # 0 = disabled (same contract as the GGUF loop). @@ -152,7 +172,9 @@ def run_safetensors_tool_loop( _state_streaming = 1 _state_draining = 2 - for iteration in range(max_tool_iterations + 1): + # Reserve re-prompt slots so they don't eat the caller's tool budget. + _extra_iters = _MAX_REPROMPTS if max_tool_iterations > 0 else 0 + for iteration in range(max_tool_iterations + _extra_iters + 1): if cancel_event is not None and cancel_event.is_set(): return @@ -242,24 +264,57 @@ def run_safetensors_tool_loop( if stripped and has_tool_signal(stripped): detect_state = _state_draining else: + # Drain the buffer and fall through to STREAMING so the + # intent re-prompt + safety-net parser can still fire on + # short emissions like "Let me search." that never exit + # BUFFERING (would otherwise silently end the loop). if content_buffer: cumulative_display += content_buffer - yield { - "type": "content", - "text": strip_tool_markup(cumulative_display, final = True), - } - yield {"type": "status", "text": ""} - return + cleaned = strip_tool_markup(cumulative_display, final = True) + if len(cleaned) > len(last_emitted): + last_emitted = cleaned + yield {"type": "content", "text": cleaned} + detect_state = _state_streaming if detect_state == _state_streaming: - # No tool detected mid-stream -- check for late tool XML. - safety_tc = None - if has_tool_signal(content_accum): - safety_tc = parse_tool_calls_from_text( - content_accum, - id_offset = next_call_id, - ) + # No tool XML detected mid-stream -- run the parser anyway. + # The Llama-3.2 bare-JSON tool form ``{"name":..,"parameters":..}`` + # carries no XML signal, so gating this on has_tool_signal() + # silently dropped real tool calls and re-prompted the model into + # giving up. parse_tool_calls_from_text is strict (it only fires + # on a valid tool-call shape), so plain answers stay untouched. + # This mirrors what llama-server already does for GGUF. + safety_tc = parse_tool_calls_from_text( + content_accum, + id_offset = next_call_id, + ) if not safety_tc: + # Re-prompt only when the model planned without acting + # (intent signal present); direct answers like "4" or + # "Hello!" never trigger. Mirrors GGUF. + _stripped = content_accum.strip() + if ( + tools + and reprompt_count < _MAX_REPROMPTS + and 0 < len(_stripped) < _REPROMPT_MAX_CHARS + and _INTENT_SIGNAL.search(_stripped) + and not final_attempt_done + ): + reprompt_count += 1 + logger.info( + "Safetensors re-prompt %d/%d: model planned without " + "calling tools (%d chars)", + reprompt_count, + _MAX_REPROMPTS, + len(_stripped), + ) + conversation.append({"role": "assistant", "content": _stripped}) + conversation.append( + {"role": "user", "content": _REPROMPT_INSTRUCTION} + ) + yield {"type": "status", "text": ""} + continue + # Final answer: streaming already emitted content. # Skip a final=True re-strip so literal "" # in prose survives when no real tool call parsed. @@ -379,7 +434,10 @@ def run_safetensors_tool_loop( # Clear the status badge before the next turn. yield {"type": "status", "text": ""} - if iteration + 1 >= max_tool_iterations and not final_attempt_done: + # Track against the caller-requested cap, excluding re-prompt + # slots so a stalling model still gets a final-answer attempt. + _tool_iters_done = iteration + 1 - reprompt_count + if _tool_iters_done >= max_tool_iterations and not final_attempt_done: # Budget exhausted; nudge a final plain answer. final_attempt_done = True conversation.append( diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index 2f94990623..0bb881d1a7 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -2,35 +2,60 @@ # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ -Backend-neutral tool-call XML parser shared by GGUF and safetensors. -Tolerates missing closing tags in either ``{json}`` -or ``v...`` shape. +Backend-neutral tool-call parser shared by GGUF, safetensors, and MLX. + +Covers the emission formats so the safetensors + MLX agentic loop sees +the same call shape llama-server normalises for GGUF: + + - ``{json}`` (Qwen / Hermes) + - ``v`` (Qwen3.5 xml) + - ``<|python_tag|>NAME.call(k="v", ...)`` (Llama-3 built-in tools) + - ``<|python_tag|>{"name":..., "parameters":...}`` (Llama-3 custom) + - ``{"name":..., "parameters":...}`` (Llama-3.2 bare JSON) + - ``[TOOL_CALLS] [{...}, ...]`` (Mistral v0.3 / Nemo / Small) + - ``[TOOL_CALLS]name{json}`` (Mistral v11+ / Magistral) + - ``[TOOL_CALLS]name[ARGS]{json}`` (Ministral / Mistral Large 3) + - ``<|tool_call>call:NAME{k:<|"|>v<|"|>}`` (Gemma 4) + +Closing tags / brackets are tolerated when missing because models +frequently truncate them mid-stream. """ import json import re +from typing import Any + + +# Markers that flip the streaming buffer from STREAMING to DRAINING so +# partial markup never leaks before the parser sees it. +TOOL_XML_SIGNALS = ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", +) -# _TOOL_CLOSED_PATS: closed pairs only. _TOOL_ALL_PATS: also trailing -# unclosed runs so truncated tails don't leak markup. -# Function-name char set tracks OpenAI's ^[a-zA-Z0-9_-]{1,64}$ so MCP -# tool names that contain a hyphen (e.g. mcp__srv__list-issues) parse -# the same as the built-in web_search/python/terminal names. +# Closed pairs only (mid-stream); _TOOL_ALL_PATS also eats unclosed +# tails for end-of-turn cleanup. ``[\w-]+`` on ```` tracks +# OpenAI's ``^[a-zA-Z0-9_-]{1,64}$`` so MCP tool names with hyphens +# (mcp__srv__list-issues) parse the same as the built-ins. _TOOL_CLOSED_PATS = [ re.compile(r".*?", re.DOTALL), - re.compile(r".*?", re.DOTALL), + re.compile(r'.*?', re.DOTALL), + re.compile(r"<\|tool_call>.*?", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), - re.compile(r".*$", re.DOTALL), + re.compile(r'.*$', re.DOTALL), + re.compile(r"<\|tool_call>.*$", re.DOTALL), + re.compile(r"\[TOOL_CALLS\].*$", re.DOTALL), + re.compile(r"<\|python_tag\|>.*$", re.DOTALL), ] -# Prefixes the streaming buffer watches for to gate in-progress text. -TOOL_XML_SIGNALS = ("", "{json}``. _TC_JSON_START_RE = re.compile(r"\s*\{") -_TC_FUNC_START_RE = re.compile(r"\s*") -_TC_END_TAG_RE = re.compile(r"") +# Qwen3.5 / Hermes ``v`` AND the attribute +# form ``v`` used by MiniCPM-5, +# MiniMax-M2, etc. Name char class is ``[\w\.\-]+`` so MCP tool names +# with hyphens (mcp__srv__list-issues) and dotted module names parse +# the same as the built-ins. Name lands in group(1) or group(2). +_TC_FUNC_START_RE = re.compile(r'\s*') +# Body terminates at either ```` (Hermes wrapper) OR +# ```` (Qwen3.5 / MiniCPM-5 standalone) so the parser stops +# at the close tag even when prose follows. Without ````, +# trailing prose leaked into the last parameter value. +_TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") -# Parameter names can carry hyphens too (e.g. MCP tool schemas with -# `issue-number`, `repo-name`); using `\w+` here dropped those keys. -_TC_PARAM_START_RE = re.compile(r"\s*") -_TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") +_TC_PARAM_START_RE = re.compile( + r'<(?:parameter|param)(?:=([\w\.\-]+)|\s+name="([\w\.\-]+)")>\s*' +) +_TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") +# Llama-3 ``<|python_tag|>NAME.call(...)``. +_LLAMA3_PYTHON_TAG = "<|python_tag|>" +_LLAMA3_PY_CALL_RE = re.compile( + r"<\|python_tag\|>\s*([\w\.\-]+)\s*\.\s*call\s*\(", +) +_LLAMA3_KV_RE = re.compile( + r"""(\w+)\s*=\s*(?:"((?:\\.|[^"\\])*)"|(-?\d+(?:\.\d+)?)|(true|false|null))""", + re.VERBOSE, +) -def strip_tool_markup(text: str, *, final: bool = False) -> str: - """Strip tool-call XML from streamed text. +# Mistral ``[TOOL_CALLS]`` trigger. v11+ chains them, each followed by +# a bare name plus ``{json}`` (Magistral) or ``[ARGS]{json}`` (Ministral +# / Large 3). +_MISTRAL_TRIGGER = "[TOOL_CALLS]" +_MISTRAL_ARGS_MARKER = "[ARGS]" +# Mistral Small 3.2 emits ``name[CALL_ID][ARGS]{json}``; the call-id +# segment is absent on Ministral / Magistral. llama.cpp distinguishes the +# two on the presence of ``[CALL_ID]`` (common/chat.cpp). +_MISTRAL_CALL_ID_MARKER = "[CALL_ID]" +# Magistral wraps reasoning in ``[THINK] ... [/THINK]`` before the answer. +# A ``[TOOL_CALLS]`` inside that block is chain-of-thought, not a real call. +_MISTRAL_THINK_OPEN = "[THINK]" +_MISTRAL_THINK_CLOSE = "[/THINK]" +_MISTRAL_V11_NAME_RE = re.compile(r"\s*([\w\.\-]+)\s*") - ``final=False`` only removes closed pairs (used during streaming so - in-progress XML stays buffered). ``final=True`` also removes a - trailing unclosed run and trims the result. +# Gemma 4: ``<|tool_call>call:NAME{...}``, ``<|"|>`` wraps strings. +_GEMMA_TC_RE = re.compile(r"<\|tool_call>\s*call\s*:\s*([\w\.\-]+)\s*\{") +_GEMMA_STR_BEGIN = '<|"|>' +_GEMMA_STR_END = '<|"|>' +_GEMMA_TC_END = "" + + +def _balanced_bracket_end(text: str, start: int) -> int | None: + """Index of `]` matching `[` at ``text[start]``; ignores brackets + in JSON strings. None if unmatched.""" + if start >= len(text) or text[start] != "[": + return None + depth = 0 + in_string = False + esc = False + i = start + while i < len(text): + ch = text[i] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + elif ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _skip_mistral_call_id(text: str, pos: int) -> int: + """Skip an optional ``[CALL_ID]`` segment (Mistral Small 3.2) at + ``pos``. Returns the position of the next meaningful token (``[ARGS]`` + or ``{``), or ``pos`` unchanged when no ``[CALL_ID]`` is present.""" + n = len(text) + i = pos + while i < n and text[i] in " \t\n\r": + i += 1 + if not text.startswith(_MISTRAL_CALL_ID_MARKER, i): + return pos + i += len(_MISTRAL_CALL_ID_MARKER) + while i < n and text[i] in " \t\n\r": + i += 1 + # The id is a short opaque token; stop at whitespace or the next marker. + while i < n and text[i] not in " \t\n\r[{": + i += 1 + while i < n and text[i] in " \t\n\r": + i += 1 + return i + + +def _strip_mistral_reasoning(content: str) -> str: + """Drop a leading Magistral ``[THINK] ... [/THINK]`` reasoning block so a + ``[TOOL_CALLS]`` emitted *inside* the chain-of-thought is not mistaken for + a real call (llama.cpp parses reasoning separately; see test-chat.cpp). + + Only a leading block is removed -- the reasoning prefix is always first, + so a literal ``[THINK]`` inside a later tool argument is left untouched. + An unclosed leading ``[THINK]`` (still streaming) means nothing has been + committed yet, so everything from it onward is dropped.""" + i = 0 + n = len(content) + while i < n and content[i] in " \t\n\r": + i += 1 + if not content.startswith(_MISTRAL_THINK_OPEN, i): + return content + close = content.find(_MISTRAL_THINK_CLOSE, i + len(_MISTRAL_THINK_OPEN)) + if close == -1: + return content[:i] + return content[:i] + content[close + len(_MISTRAL_THINK_CLOSE) :] + + +def _strip_mistral_closed_calls(text: str) -> str: + """Strip cleanly-closed ``[TOOL_CALLS]`` blocks (array, ``name{json}``, + or ``name[ARGS]{json}``) via balanced brace/bracket scanning. + + A non-greedy ``\\{.*?\\}`` would truncate at the first ``}`` and lose + nested JSON. Unclosed runs are left for ``final=True`` cleanup. """ + n = len(text) + out = [] + cursor = 0 + while cursor < n: + idx = text.find(_MISTRAL_TRIGGER, cursor) + if idx == -1: + out.append(text[cursor:]) + break + out.append(text[cursor:idx]) + body_start = idx + len(_MISTRAL_TRIGGER) + i = body_start + while i < n and text[i] in " \t\n\r": + i += 1 + # Array shape: ``[TOOL_CALLS] [...]``. + if i < n and text[i] == "[": + end = _balanced_bracket_end(text, i) + if end is None: + # Truncated; let caller buffer / final-strip. + out.append(text[idx:]) + break + cursor = end + 1 + if text.startswith("", cursor): + cursor += len("") + continue + # Named shape: ``[TOOL_CALLS] name [ARGS]? { json }``. + name_match = _MISTRAL_V11_NAME_RE.match(text, i) + if not name_match: + out.append(text[idx:body_start]) + cursor = body_start + continue + i = name_match.end() + while i < n and text[i] in " \t\n\r": + i += 1 + i = _skip_mistral_call_id(text, i) + if text.startswith(_MISTRAL_ARGS_MARKER, i): + i += len(_MISTRAL_ARGS_MARKER) + while i < n and text[i] in " \t\n\r": + i += 1 + if i >= n or text[i] != "{": + out.append(text[idx:i]) + cursor = i + continue + end = _balanced_brace_end(text, i) + if end is None: + out.append(text[idx:]) + break + cursor = end + 1 + return "".join(out) + + +def strip_tool_markup(text: str, *, final: bool = False) -> str: + """Strip tool-call markup. ``final=False`` keeps in-progress + markup buffered; ``final=True`` also drops trailing unclosed runs + and trims.""" + text = _strip_mistral_closed_calls(text) pats = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS for pat in pats: text = pat.sub("", text) return text.strip() if final else text +def has_tool_signal(text: str) -> bool: + return any(s in text for s in TOOL_XML_SIGNALS) + + def parse_tool_calls_from_text(content: str, *, id_offset: int = 0) -> list[dict]: - """Parse OpenAI-format ``tool_calls`` from model text. + """Return OpenAI-format tool calls. Tries each format and returns + as soon as one matches so we never double-count.""" + for parser in ( + _parse_tool_call_json, # Qwen / Hermes + _parse_function_xml, # Qwen3.5 / Hermes XML + _parse_llama3_python_tag, # Llama-3 + _parse_mistral_tool_calls, # Mistral + _parse_gemma_tool_calls, # Gemma 4 + ): + calls = parser(content, id_offset = id_offset) + if calls: + return calls + + # Llama-3.2 bare ``{"name":..., "parameters":...}``. Strict: only + # fires on content that starts with ``{`` and parses as the right + # shape, so plain prose stays untouched. + return _parse_llama3_bare_json(content, id_offset = id_offset) + + +def _parse_tool_call_json(content: str, *, id_offset: int) -> list[dict]: + out: list[dict] = [] + for m in _TC_JSON_START_RE.finditer(content): + brace_start = m.end() - 1 + end = _balanced_brace_end(content, brace_start) + if end is None: + continue + try: + obj = json.loads(content[brace_start : end + 1]) + except (json.JSONDecodeError, ValueError): + continue + name = obj.get("name", "") + # Accept both ``arguments`` (Hermes/Qwen) and ``parameters`` + # (Llama-3 template drift) so a fine-tune that swaps the key + # keeps its payload instead of silently parsing to ``{}``. + args = obj.get("arguments") + if args is None: + args = obj.get("parameters", {}) + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if not name: + continue + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + return out + - Returns a list of ``{"id", "type", "function": {"name", "arguments"}}`` - dicts. ``arguments`` is always a JSON string so callers can hand it - straight back into an OpenAI-style response. +def _parse_function_xml(content: str, *, id_offset: int) -> list[dict]: + out: list[dict] = [] + func_starts = list(_TC_FUNC_START_RE.finditer(content)) + for idx, fm in enumerate(func_starts): + # group(1) is ````, group(2) is ````. + func_name = fm.group(1) or fm.group(2) + body_start = fm.end() + next_func = ( + func_starts[idx + 1].start() if idx + 1 < len(func_starts) else len(content) + ) + end_tag = _TC_END_TAG_RE.search(content[body_start:]) + if end_tag: + body_end = body_start + end_tag.start() + else: + body_end = len(content) + body_end = min(body_end, next_func) + body = _TC_FUNC_CLOSE_RE.sub("", content[body_start:body_end]) - Handles two shapes: + args: dict = {} + param_starts = list(_TC_PARAM_START_RE.finditer(body)) + if len(param_starts) == 1: + pm = param_starts[0] + val = _TC_PARAM_CLOSE_RE.sub("", body[pm.end() :]) + args[pm.group(1) or pm.group(2)] = val.strip() + else: + for pidx, pm in enumerate(param_starts): + val_start = pm.end() + next_param = ( + param_starts[pidx + 1].start() + if pidx + 1 < len(param_starts) + else len(body) + ) + val = _TC_PARAM_CLOSE_RE.sub("", body[val_start:next_param]) + args[pm.group(1) or pm.group(2)] = val.strip() - - JSON inside ```` tags: - ``{"name":"web_search","arguments":{"query":"..."}}`` - - XML-style function blocks: - ``v`` + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": func_name, "arguments": json.dumps(args)}, + } + ) + return out - Closing tags (````, ````, ````) - are all optional since models frequently omit them. + +def _parse_llama3_python_tag(content: str, *, id_offset: int) -> list[dict]: + """Parse the four Llama-3 emissions: ``<|python_tag|>NAME.call(...)`` + (built-in), ``<|python_tag|>{"name":..., "parameters":...}`` (custom), + multi-call via ``; `` separators, ``parameters`` or ``arguments`` key. """ - tool_calls: list[dict] = [] + out: list[dict] = [] + if _LLAMA3_PYTHON_TAG not in content: + return out - # Pattern 1: {json}. Balanced-brace scan that skips - # braces inside JSON strings. - for m in _TC_JSON_START_RE.finditer(content): - brace_start = m.end() - 1 # position of the opening { - depth, i = 0, brace_start + # 1. ``NAME.call(...)`` built-in form. + for m in _LLAMA3_PY_CALL_RE.finditer(content): + name = m.group(1) + i = m.end() + depth = 1 in_string = False - while i < len(content): + esc = False + while i < len(content) and depth > 0: ch = content[i] if in_string: - if ch == "\\" and i + 1 < len(content): - i += 2 - continue - if ch == '"': + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': in_string = False + else: + if ch == '"': + in_string = True + elif ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + break + i += 1 + body = content[m.end() : i] + args: dict[str, Any] = {} + for kv in _LLAMA3_KV_RE.finditer(body): + k = kv.group(1) + if kv.group(2) is not None: + # ``json.loads`` on a quoted string handles \n/\t/\uXXXX + # escapes correctly AND keeps literal UTF-8 bytes (emoji + # / CJK) intact -- the older ``bytes.decode('unicode_escape')`` + # path mangled non-ASCII. + try: + args[k] = json.loads('"' + kv.group(2) + '"') + except (json.JSONDecodeError, ValueError): + args[k] = kv.group(2) + elif kv.group(3) is not None: + v = kv.group(3) + args[k] = float(v) if "." in v else int(v) + elif kv.group(4) is not None: + args[k] = {"true": True, "false": False, "null": None}[kv.group(4)] + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ) + + # 2. ``<|python_tag|>{"name":..., "parameters":...}``. ``raw_decode`` + # peels multiple ``; ``-separated objects from one emission. + if not out: + decoder = json.JSONDecoder() + idx = content.find(_LLAMA3_PYTHON_TAG) + while idx >= 0: + search_from = idx + len(_LLAMA3_PYTHON_TAG) + cursor = search_from + while cursor < len(content): + brace = content.find("{", cursor) + if brace < 0: + break + # Stop at the next ``<|python_tag|>``. + next_tag = content.find(_LLAMA3_PYTHON_TAG, search_from, brace) + if next_tag >= 0: + break + try: + obj, end_offset = decoder.raw_decode(content[brace:]) + except (json.JSONDecodeError, ValueError): + cursor = brace + 1 + continue + if not isinstance(obj, dict): + cursor = brace + end_offset + continue + name = obj.get("name") or obj.get("function") or "" + args = ( + obj.get("parameters") + if "parameters" in obj + else obj.get("arguments", {}) + ) + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if name: + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + cursor = brace + end_offset + idx = content.find(_LLAMA3_PYTHON_TAG, cursor) + return out + + +def _parse_llama3_bare_json(content: str, *, id_offset: int) -> list[dict]: + """Llama-3.2 ``custom_tools``: bare ``{"name":..., "parameters":{...}}`` + without ``<|python_tag|>``. Strict (must start with ``{`` after sentinel + strip; ``name`` non-empty; ``parameters`` or ``arguments`` is a dict) so + plain prose and tool-message echoes don't trigger.""" + out: list[dict] = [] + stripped = content.lstrip() + # Sentinels can chain in any order, so loop until none match. + _sentinels = ( + "<|begin_of_text|>", + "<|eot_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", + ) + # Role labels Meta's Llama-3 chat template inserts between + # ``<|start_header_id|>`` and ``<|end_header_id|>`` -- consume so a + # round-trip like + # ``<|start_header_id|>assistant<|end_header_id|>\n\n{json}`` + # reaches the JSON body. + _header_roles = ("assistant", "user", "system", "tool", "ipython") + while True: + stripped = stripped.lstrip() + matched = False + for sentinel in _sentinels: + if stripped.startswith(sentinel): + stripped = stripped[len(sentinel) :] + if sentinel == "<|start_header_id|>": + for role in _header_roles: + if stripped.startswith(role): + stripped = stripped[len(role) :] + break + matched = True + break + if not matched: + break + if not stripped.startswith("{"): + return out + + decoder = json.JSONDecoder() + cursor = 0 + n = len(stripped) + while cursor < n: + # Skip whitespace and the Llama-3 ``;`` inter-call separator. + while cursor < n and stripped[cursor] in " \t\n\r;": + cursor += 1 + if cursor >= n or stripped[cursor] != "{": + break + try: + obj, end_offset = decoder.raw_decode(stripped[cursor:]) + except (json.JSONDecodeError, ValueError): + break + if not isinstance(obj, dict): + break + name = obj.get("name") or obj.get("function") or "" + if not isinstance(name, str) or not name: + break + # ``parameters`` must be a dict (Llama-3 spec). + # ``arguments`` may be a dict or a JSON-string of a dict (OpenAI shape). + # Anything looser would fire on prose like ``{"name":"x","parameters":"sentence"}``. + if "parameters" in obj: + args = obj.get("parameters") + if not isinstance(args, dict): + break + args_str = json.dumps(args) + elif "arguments" in obj: + args = obj.get("arguments") + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + try: + parsed = json.loads(args) + except (json.JSONDecodeError, ValueError): + break + if not isinstance(parsed, dict): + break + args_str = args + else: + break + else: + break + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": args_str}, + } + ) + cursor += end_offset + return out + + +def _parse_mistral_tool_calls(content: str, *, id_offset: int) -> list[dict]: + """Parse all Mistral emissions: pre-v11 ``[TOOL_CALLS][...]`` / + ``[TOOL_CALLS]{...}`` and v11+ ``[TOOL_CALLS]name{json}`` / + ``[TOOL_CALLS]name[ARGS]{json}`` (parallel-friendly).""" + out: list[dict] = [] + content = _strip_mistral_reasoning(content) + idx = content.find(_MISTRAL_TRIGGER) + if idx < 0: + return out + + # Disambiguate the first occurrence: array (pre-v11), single object + # (pre-v11), or bare-name (v11+). + j = idx + len(_MISTRAL_TRIGGER) + k = j + while k < len(content) and content[k] in " \t\n\r": + k += 1 + if k >= len(content): + return out + + if content[k] == "[": + return _parse_mistral_array(content, k, id_offset) + + if content[k] == "{": + # Pre-v11 single ``{"name":...}``; fall through if it doesn't + # carry a ``name`` so v11+ handling still gets a chance. + end = _balanced_brace_end(content, k) + if end is not None: + try: + obj = json.loads(content[k : end + 1]) + if isinstance(obj, dict) and obj.get("name"): + _consume_mistral_call(content[k : end + 1], out, id_offset) + return out + except (json.JSONDecodeError, ValueError): + pass + + # v11+: walk every ``[TOOL_CALLS]``, parsing ``name{json}`` or + # ``name[ARGS]{json}`` after each trigger. + pos = idx + while pos >= 0: + cur = pos + len(_MISTRAL_TRIGGER) + nm = _MISTRAL_V11_NAME_RE.match(content, cur) + if not nm: + pos = content.find(_MISTRAL_TRIGGER, cur) + continue + name = nm.group(1) + after_name = nm.end() + after_name = _skip_mistral_call_id(content, after_name) + if content.startswith(_MISTRAL_ARGS_MARKER, after_name): + after_name += len(_MISTRAL_ARGS_MARKER) + while after_name < len(content) and content[after_name] in " \t\n\r": + after_name += 1 + if after_name >= len(content) or content[after_name] != "{": + pos = content.find(_MISTRAL_TRIGGER, cur) + continue + end = _balanced_brace_end(content, after_name) + if end is None: + break + try: + args = json.loads(content[after_name : end + 1]) + except (json.JSONDecodeError, ValueError): + pos = content.find(_MISTRAL_TRIGGER, end + 1) + continue + if not isinstance(args, dict): + pos = content.find(_MISTRAL_TRIGGER, end + 1) + continue + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args), + }, + } + ) + pos = content.find(_MISTRAL_TRIGGER, end + 1) + return out + + +def _parse_mistral_array(content: str, start: int, id_offset: int) -> list[dict]: + """Pre-v11 ``[TOOL_CALLS] [{...}, ...]`` array form.""" + out: list[dict] = [] + j = start + depth = 0 + in_string = False + esc = False + while j < len(content): + ch = content[j] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True elif ch == '"': + in_string = False + else: + if ch == '"': in_string = True - elif ch == "{": + elif ch == "[": depth += 1 - elif ch == "}": + elif ch == "]": depth -= 1 if depth == 0: break - i += 1 - if depth == 0: - json_str = content[brace_start : i + 1] - try: - obj = json.loads(json_str) - tc = { - "id": f"call_{id_offset + len(tool_calls)}", - "type": "function", - "function": { - "name": obj.get("name", ""), - "arguments": obj.get("arguments", {}), - }, - } - if isinstance(tc["function"]["arguments"], dict): - tc["function"]["arguments"] = json.dumps( - tc["function"]["arguments"] - ) - tool_calls.append(tc) - except (json.JSONDecodeError, ValueError): - pass + j += 1 + body = content[start : j + 1] if depth == 0 else content[start:] - # Pattern 2: v... -- closing tags - # optional; don't use as body boundary because code - # values can contain that literal. - if not tool_calls: - func_starts = list(_TC_FUNC_START_RE.finditer(content)) - for idx, fm in enumerate(func_starts): - func_name = fm.group(1) - body_start = fm.end() - next_func = ( - func_starts[idx + 1].start() - if idx + 1 < len(func_starts) - else len(content) - ) - end_tag = _TC_END_TAG_RE.search(content[body_start:]) - if end_tag: - body_end = body_start + end_tag.start() - else: - body_end = len(content) - body_end = min(body_end, next_func) - body = content[body_start:body_end] - body = _TC_FUNC_CLOSE_RE.sub("", body) - - arguments: dict = {} - param_starts = list(_TC_PARAM_START_RE.finditer(body)) - if len(param_starts) == 1: - # Single param: take everything to body end so - # embedded in code strings is preserved. - pm = param_starts[0] - val = body[pm.end() :] - val = _TC_PARAM_CLOSE_RE.sub("", val) - arguments[pm.group(1)] = val.strip() - else: - for pidx, pm in enumerate(param_starts): - param_name = pm.group(1) - val_start = pm.end() - next_param = ( - param_starts[pidx + 1].start() - if pidx + 1 < len(param_starts) - else len(body) - ) - val = body[val_start:next_param] - val = _TC_PARAM_CLOSE_RE.sub("", val) - arguments[param_name] = val.strip() + try: + arr = json.loads(body) + if isinstance(arr, list): + for obj in arr: + if isinstance(obj, dict): + _consume_mistral_call(json.dumps(obj), out, id_offset) + return out + except (json.JSONDecodeError, ValueError): + pass - tc = { - "id": f"call_{id_offset + len(tool_calls)}", + # Healing path for unclosed arrays: walk objects by hand. + for m in re.finditer(r"\{", body): + end = _balanced_brace_end(body, m.start()) + if end is None: + continue + _consume_mistral_call(body[m.start() : end + 1], out, id_offset) + return out + + +def _consume_mistral_call(obj_text: str, out: list[dict], id_offset: int) -> None: + try: + obj = json.loads(obj_text) + except (json.JSONDecodeError, ValueError): + return + if not isinstance(obj, dict): + return + name = obj.get("name") or "" + args = obj.get("arguments") or {} + if isinstance(args, dict): + args_str = json.dumps(args) + elif isinstance(args, str): + args_str = args + else: + args_str = json.dumps({"value": args}) + if name: + out.append( + { + "id": obj.get("id") or f"call_{id_offset + len(out)}", "type": "function", - "function": { - "name": func_name, - "arguments": json.dumps(arguments), - }, + "function": {"name": name, "arguments": args_str}, } - tool_calls.append(tc) + ) - return tool_calls +def _parse_gemma_tool_calls(content: str, *, id_offset: int) -> list[dict]: + """Gemma 4: ``<|tool_call>call:NAME{k:<|"|>v<|"|>, ...}``.""" + out: list[dict] = [] + for m in _GEMMA_TC_RE.finditer(content): + name = m.group(1) + body_start = m.end() - 1 + end_marker = content.find(_GEMMA_TC_END, body_start) + scan_end = end_marker if end_marker >= 0 else len(content) + end = _gemma_balanced_brace_end(content, body_start, scan_end) + if end is None: + continue + body = content[body_start + 1 : end] + try: + args = _gemma_parse_mapping_body(body) + except Exception: + args = {} + out.append( + { + "id": f"call_{id_offset + len(out)}", + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ) + return out -def has_tool_signal(text: str) -> bool: - """Return True if ``text`` contains any tool-call XML signal.""" - return any(s in text for s in TOOL_XML_SIGNALS) + +def _balanced_brace_end(text: str, brace_pos: int) -> int | None: + """Index of `}` matching `{` at ``brace_pos``; ignores braces inside + JSON strings. None if unmatched.""" + if brace_pos >= len(text) or text[brace_pos] != "{": + return None + depth = 0 + in_string = False + esc = False + i = brace_pos + while i < len(text): + ch = text[i] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + elif ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _gemma_balanced_brace_end(text: str, brace_pos: int, hard_stop: int) -> int | None: + """Like ``_balanced_brace_end`` but skips ``<|"|>`` strings and + matches `{`/`[` symmetrically.""" + if brace_pos >= len(text) or text[brace_pos] != "{": + return None + depth = 0 + i = brace_pos + while i < hard_stop: + if text.startswith(_GEMMA_STR_BEGIN, i): + close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + return None + i = close + len(_GEMMA_STR_END) + continue + ch = text[i] + if ch == "{" or ch == "[": + depth += 1 + elif ch == "}" or ch == "]": + depth -= 1 + if depth == 0: + return i + i += 1 + return None + + +def _gemma_parse_value(text: str, i: int): + """Parse one Gemma arg value at ``i``; returns ``(value, next_index)``.""" + if text.startswith(_GEMMA_STR_BEGIN, i): + close = text.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + return text[i + len(_GEMMA_STR_BEGIN) :], len(text) + return text[i + len(_GEMMA_STR_BEGIN) : close], close + len(_GEMMA_STR_END) + if text[i] == "{": + end = _gemma_balanced_brace_end(text, i, len(text)) + if end is None: + return {}, len(text) + return _gemma_parse_mapping_body(text[i + 1 : end]), end + 1 + if text[i] == "[": + j, depth = i, 0 + while j < len(text): + if text.startswith(_GEMMA_STR_BEGIN, j): + k = text.find(_GEMMA_STR_END, j + len(_GEMMA_STR_BEGIN)) + if k < 0: + j = len(text) + break + j = k + len(_GEMMA_STR_END) + continue + ch = text[j] + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + if depth == 0: + break + j += 1 + body = text[i + 1 : j] + items: list[Any] = [] + k = 0 + while k < len(body): + if body[k] in " \t\n\r,": + k += 1 + continue + v, k = _gemma_parse_value(body, k) + items.append(v) + return items, j + 1 + # Primitive: number / true/false/null / bare identifier. + end = i + while ( + end < len(text) + and text[end] not in ",}]" + and not text.startswith(_GEMMA_STR_BEGIN, end) + ): + end += 1 + raw = text[i:end].strip() + if raw == "true": + return True, end + if raw == "false": + return False, end + if raw == "null": + return None, end + try: + return int(raw), end + except ValueError: + pass + try: + return float(raw), end + except ValueError: + pass + return raw, end + + +def _gemma_parse_mapping_body(body: str) -> dict[str, Any]: + """Parse a Gemma argument mapping (content between `{` and `}`).""" + out: dict[str, Any] = {} + i = 0 + n = len(body) + while i < n: + while i < n and body[i] in " \t\n\r,": + i += 1 + if i >= n: + break + if body.startswith(_GEMMA_STR_BEGIN, i): + close = body.find(_GEMMA_STR_END, i + len(_GEMMA_STR_BEGIN)) + if close < 0: + break + key = body[i + len(_GEMMA_STR_BEGIN) : close] + i = close + len(_GEMMA_STR_END) + else: + kstart = i + while i < n and body[i] != ":": + i += 1 + key = body[kstart:i].strip() + while i < n and body[i] in " \t\n\r": + i += 1 + if i < n and body[i] == ":": + i += 1 + while i < n and body[i] in " \t\n\r": + i += 1 + if i >= n: + out[key] = None + break + v, i = _gemma_parse_value(body, i) + out[key] = v + return out diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 259337616c..6ba878e488 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -260,16 +260,24 @@ def _detect_safetensors_features(backend, chat_template: Optional[str]) -> dict: "supports_tools": False, } ) - # Our safetensors loop only parses {json} - # and .... Llama uses <|python_tag|>, - # Mistral uses [TOOL_CALLS]; advertising tools for those would - # enable a pill the parser cannot honour. GGUF is unaffected -- - # llama-server normalises every format into structured deltas. + # Markers the safetensors / MLX parser recognises. If the template + # advertises tools but uses none of them, drop the pill (parser + # can't honour the emission). The two ``{"name":`` variants cover + # Llama-3.2 ``custom_tools`` whose template prompts the bare-JSON + # form without a ``<|python_tag|>`` prefix. + _PARSER_MARKERS = ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", + '{"name":', + '{\\"name\\":', + ) if ( flags.get("supports_tools") and chat_template - and "" not in chat_template - and " None: " Do NOT output code blocks -- use the python tool instead." ) -# Strip tool-call XML the speculative buffer in core/inference/llama_cpp.py -# split across the visible/DRAIN boundary. Four leak shapes: -# 1. well-formed `...` / `...` -# 2. orphan opening to EOF (close was DRAINED) -# 3. bare orphan close (open was DRAINED) -# 4. tail-only `` (outer close truncated by EOS); anchored to -# `\Z` so mid-text `` in user code samples survives. +# Strip leaked tool-call markup. Covers every shared-parser format AND +# the four leak shapes the speculative buffer in ``llama_cpp.py`` splits +# across the visible/DRAIN boundary (closed pair, orphan open to EOF, +# bare orphan close, tail-only ````). Mistral ``[TOOL_CALLS]`` +# is delegated to the parser's balanced-brace helper -- a non-greedy +# ``\{.*?\}`` here would truncate nested JSON at the first ``}``. _TOOL_XML_RE = _re.compile( - # Hyphen in the name char-class matches MCP tool names with dashes - # (mcp__srv__list-issues) which would otherwise leak past this strip. - r"<(?:tool_call|function=[\w-]+)>.*?(?:|\Z)" - r"|" - r"|\s*\Z", + "|".join( + [ + # Tool-call / function XML: closed pair OR orphan open to EOF. + # ``[\w-]+`` on the name accepts MCP tool ids with hyphens + # (mcp__srv__list-issues) that ``\w+`` alone would let leak. + r"<(?:tool_call|function=[\w-]+)>.*?(?:|\Z)", + # Bare orphan close (open was DRAINED upstream). + r"", + # Gemma 4. + r"<\|tool_call>.*?", + # Llama-3 ``<|python_tag|>...`` to the next ``<|`` sentinel + # or EOF. ``(?:[^<]|<(?!\|))*`` (not ``[^\n<]*`` or + # ``[^\n]*``) keeps literal ``<``, newlines, and embedded + # JSON inside the strip. + r"<\|python_tag\|>(?:[^<]|<(?!\|))*", + # Tail-only ```` (anchored so mid-text survives). + r"\s*\Z", + ] + ), _re.DOTALL, ) + + +def _strip_tool_xml(text: str) -> str: + """Combine the Mistral balanced-brace helper with ``_TOOL_XML_RE``.""" + from studio.backend.core.inference.tool_call_parser import ( + _strip_mistral_closed_calls, + ) + + return _TOOL_XML_RE.sub("", _strip_mistral_closed_calls(text)) + + logger = get_logger(__name__) @@ -2763,7 +2795,7 @@ async def audio_input_stream(): if _msg.get("role") == "assistant" and isinstance( _msg.get("content"), str ): - _msg["content"] = _TOOL_XML_RE.sub("", _msg["content"]).strip() + _msg["content"] = _strip_tool_xml(_msg["content"]).strip() def gguf_generate_with_tools(): return llama_backend.generate_chat_completion_with_tools( @@ -2864,7 +2896,7 @@ async def gguf_tool_stream(): # the last sanitized snapshot so cross-chunk XML # tags are handled correctly. raw_cumulative = event.get("text", "") - clean_cumulative = _TOOL_XML_RE.sub("", raw_cumulative) + clean_cumulative = _strip_tool_xml(raw_cumulative) new_text = clean_cumulative[len(prev_text) :] prev_text = clean_cumulative if not new_text: @@ -3268,7 +3300,7 @@ async def gguf_stream_chunks(): _sf_chat_messages.append( { **_msg, - "content": _TOOL_XML_RE.sub("", _msg["content"]).strip(), + "content": _strip_tool_xml(_msg["content"]).strip(), } ) else: @@ -3355,7 +3387,7 @@ async def sf_tool_stream(): # Diff cumulative cleaned text against last snapshot. raw_cumulative = event.get("text", "") - clean_cumulative = _TOOL_XML_RE.sub("", raw_cumulative) + clean_cumulative = _strip_tool_xml(raw_cumulative) new_text = clean_cumulative[len(prev_text) :] prev_text = clean_cumulative if not new_text: @@ -3427,7 +3459,7 @@ def _drain_to_text(): if cancel_event.is_set(): break if event.get("type") == "content": - full_text = _TOOL_XML_RE.sub("", event.get("text", "")) + full_text = _strip_tool_xml(event.get("text", "")) return full_text content_text = await asyncio.to_thread(_drain_to_text) @@ -4959,7 +4991,7 @@ async def anthropic_messages( # Strip stale tool-call XML from conversation for _msg in openai_messages: if _msg.get("role") == "assistant" and isinstance(_msg.get("content"), str): - _msg["content"] = _TOOL_XML_RE.sub("", _msg["content"]).strip() + _msg["content"] = _strip_tool_xml(_msg["content"]).strip() def _run_tool_gen(): return llama_backend.generate_chat_completion_with_tools( @@ -5051,7 +5083,7 @@ async def _stream(): # Strip leaked tool-call XML from content events if event.get("type") == "content": event = dict(event) - event["text"] = _TOOL_XML_RE.sub("", event["text"]) + event["text"] = _strip_tool_xml(event["text"]) for line in emitter.feed(event): yield line except Exception as e: @@ -5142,7 +5174,7 @@ async def _anthropic_tool_non_streaming(run_gen, message_id, model_name): etype = event.get("type", "") if etype == "content": # Strip leaked tool-call XML - clean = _TOOL_XML_RE.sub("", event["text"]) + clean = _strip_tool_xml(event["text"]) new = clean[len(prev_text) :] prev_text = clean if new: @@ -5479,7 +5511,7 @@ async def _anthropic_passthrough_non_streaming( content_blocks = [] text = message.get("content") or "" if text: - text = _TOOL_XML_RE.sub("", text).strip() + text = _strip_tool_xml(text).strip() if text: content_blocks.append(AnthropicResponseTextBlock(text = text)) diff --git a/studio/backend/tests/test_cpu_threads.py b/studio/backend/tests/test_cpu_threads.py index 0dbcbdb74b..1224941622 100644 --- a/studio/backend/tests/test_cpu_threads.py +++ b/studio/backend/tests/test_cpu_threads.py @@ -63,16 +63,18 @@ def test_cpu_thread_cap_is_opt_in(raw): # Anything that is not a positive integer raises a clear ValueError. -@pytest.mark.parametrize("raw", ["zero", "0", "-3", "1.5", "abc", "8a", "0x4", "1e3", "4 0"]) +@pytest.mark.parametrize( + "raw", ["zero", "0", "-3", "1.5", "abc", "8a", "0x4", "1e3", "4 0"] +) def test_cpu_thread_cap_requires_positive_integer(raw): - with pytest.raises(ValueError, match="must be a positive integer"): + with pytest.raises(ValueError, match = "must be a positive integer"): configure_cpu_threads({"UNSLOTH_CPU_THREADS": raw}) # env=None path uses real os.environ (production call from run.py / main.py). def test_cpu_thread_cap_uses_os_environ_when_env_is_none(monkeypatch): for variable in (*_THREAD_POOL_ENV_VARS, "UNSLOTH_CPU_THREADS"): - monkeypatch.delenv(variable, raising=False) + monkeypatch.delenv(variable, raising = False) monkeypatch.setenv("UNSLOTH_CPU_THREADS", "3") configure_cpu_threads() @@ -84,7 +86,7 @@ def test_cpu_thread_cap_uses_os_environ_when_env_is_none(monkeypatch): # Calling twice must not flip any seeded value. def test_cpu_thread_cap_idempotent(monkeypatch): for variable in (*_THREAD_POOL_ENV_VARS, "UNSLOTH_CPU_THREADS"): - monkeypatch.delenv(variable, raising=False) + monkeypatch.delenv(variable, raising = False) monkeypatch.setenv("UNSLOTH_CPU_THREADS", "5") configure_cpu_threads() @@ -138,9 +140,9 @@ def test_invalid_cpu_thread_cap_exits_without_traceback(entry_point): result = subprocess.run( [sys.executable, str(entry_point)], - env=env, - capture_output=True, - text=True, + env = env, + capture_output = True, + text = True, ) assert result.returncode == 1 diff --git a/studio/backend/tests/test_safetensors_capability_advertise.py b/studio/backend/tests/test_safetensors_capability_advertise.py index c3ee5b9ff1..b63e835d2a 100644 --- a/studio/backend/tests/test_safetensors_capability_advertise.py +++ b/studio/backend/tests/test_safetensors_capability_advertise.py @@ -129,11 +129,11 @@ def test_detect_safetensors_features_gptoss_disables_tools(): assert flags["supports_tools"] is False -# Llama-3 / Mistral templates advertise tool handling but the model emits -# tool calls in <|python_tag|> / [TOOL_CALLS] format -- not the -# / , [TOOL_CALLS], and +# <|tool_call>). The route helper must surface supports_tools=True for +# all of them so the UI enables the pill. Only templates whose tool +# format is NONE of the five known markers should be suppressed. LLAMA3_TEMPLATE = """ {%- if tools %} @@ -165,27 +165,88 @@ def test_detect_safetensors_features_gptoss_disables_tools(): {%- endfor %} """ +GEMMA4_TEMPLATE = """ +{%- if tools %} + {{- 'Tools available. Emit calls as ' }} + {{- '<|tool_call>call:NAME{key:<|"|>val<|"|>}' }} + {%- for tool in tools %} + {{- tool | tojson }} + {%- endfor %} +{%- endif %} +""" + -def test_detect_safetensors_features_llama3_template_suppresses_tools(): - """Llama-3 emits <|python_tag|>; safetensors loop cannot parse it.""" +def test_detect_safetensors_features_llama3_template_keeps_tools_on(): + """Llama-3 emits <|python_tag|>; parser now supports it.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct") flags = _detect_safetensors_features(backend, LLAMA3_TEMPLATE) - assert flags["supports_tools"] is False + assert flags["supports_tools"] is True -def test_detect_safetensors_features_mistral_template_suppresses_tools(): - """Mistral emits [TOOL_CALLS]; safetensors loop cannot parse it.""" +def test_detect_safetensors_features_mistral_template_keeps_tools_on(): + """Mistral emits [TOOL_CALLS]; parser now supports it.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/mistral-7b-instruct-v0.3") flags = _detect_safetensors_features(backend, MISTRAL_TEMPLATE) + assert flags["supports_tools"] is True + + +def test_detect_safetensors_features_gemma4_template_keeps_tools_on(): + """Gemma 4 emits <|tool_call>; parser now supports it.""" + from routes.inference import _detect_safetensors_features + + backend = SimpleNamespace(active_model_name = "unsloth/gemma-4-E2B-it-UD-MLX-4bit") + flags = _detect_safetensors_features(backend, GEMMA4_TEMPLATE) + assert flags["supports_tools"] is True + + +LLAMA3_2_BARE_JSON_TEMPLATE = """ +{%- if tools %} + {{- 'Given the following functions, respond with JSON for a function call.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary}.' }} + {%- for tool in tools %} + {{- tool | tojson }} + {%- endfor %} +{%- endif %} +{%- for message in messages %} + {%- if 'tool_calls' in message %} + {{- '{"name": "' + message.tool_calls[0].function.name + '", '}} + {{- '"parameters": ' + (message.tool_calls[0].function.arguments | tojson) + '}' }} + {%- endif %} +{%- endfor %} +""" + + +def test_detect_safetensors_features_llama3_2_bare_json_keeps_tools_on(): + """Llama-3.2 emits bare JSON ``{"name":..., "parameters":...}`` -- the + parser now handles that path, so the pill must stay enabled.""" + from routes.inference import _detect_safetensors_features + + backend = SimpleNamespace(active_model_name = "unsloth/Llama-3.2-3B-Instruct") + flags = _detect_safetensors_features(backend, LLAMA3_2_BARE_JSON_TEMPLATE) + assert flags["supports_tools"] is True + + +def test_detect_safetensors_features_unknown_format_suppresses_tools(): + """A template that advertises tools but uses no known marker must + be suppressed so the UI does not enable an unsupported pill.""" + from routes.inference import _detect_safetensors_features + + tpl = ( + "{%- if tools %}<|im_start|>system\n" + "Emit tool calls as JSON-RPC notifications inside the response." + "<|im_end|>{%- endif %}" + ) + backend = SimpleNamespace(active_model_name = "custom/unknown-tool-format") + flags = _detect_safetensors_features(backend, tpl) assert flags["supports_tools"] is False def test_detect_safetensors_features_qwen_tool_call_keeps_tools_on(): - """Sanity check: gate only suppresses non-Qwen formats.""" + """Sanity check: Qwen marker still flips supports_tools.""" from routes.inference import _detect_safetensors_features backend = SimpleNamespace(active_model_name = "unsloth/Qwen3-0.6B") diff --git a/studio/backend/tests/test_safetensors_tool_loop.py b/studio/backend/tests/test_safetensors_tool_loop.py index 923af87c4f..5bfc98eeb9 100644 --- a/studio/backend/tests/test_safetensors_tool_loop.py +++ b/studio/backend/tests/test_safetensors_tool_loop.py @@ -130,6 +130,383 @@ def test_strip_markup_unclosed_final(self): assert "partial" in strip_tool_markup(text) +class TestParserMultiFormat: + """Parser coverage for Llama-3 / Mistral / Gemma 4 emission formats. + + Each model family upstream of GGUF emits a different tool-call + shape. The shared parser must turn all of them into the same + OpenAI ``{name, arguments}`` shape so the safetensors / MLX + agentic loop is family-agnostic. + """ + + # ── Llama-3 ──────────────────────────────────────────────────── + + def test_llama3_python_tag_dot_call(self): + # Llama-3 built-in tools: <|python_tag|>NAME.call(k="v", ...). + import json + + text = '<|python_tag|>brave_search.call(query="weather in Tokyo")' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "brave_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "weather in Tokyo"} + + def test_llama3_python_tag_dot_call_multi_arg(self): + import json + + text = ( + "<|python_tag|>get_weather.call(" + 'location="Tokyo", units="celsius", days=5)' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"location": "Tokyo", "units": "celsius", "days": 5} + + def test_llama3_python_tag_json_form(self): + import json + + text = ( + '<|python_tag|>{"name":"web_search",' '"parameters":{"query":"hi","n":5}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "hi", "n": 5} + + def test_llama3_python_tag_json_form_with_eom(self): + # Llama-3 emits ``<|eom_id|>`` after the JSON; must not break parsing. + import json + + text = ( + '<|python_tag|>{"name":"python",' + '"parameters":{"code":"print(2+2)"}}<|eom_id|>' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"code": "print(2+2)"} + + def test_llama3_strip_markup_final(self): + text = '<|python_tag|>brave_search.call(query="x")' + assert strip_tool_markup(text, final = True) == "" + + # ── Llama-3.2 bare JSON ``custom_tools`` ───────────────────── + + def test_llama3_2_bare_json_parameters(self): + # Llama-3.2-Instruct emits bare JSON directly as content; no + # <|python_tag|> prefix per its training template. + import json + + text = '{"name":"web_search","parameters":{"query":"Tokyo weather"}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"query": "Tokyo weather"} + + def test_llama3_2_bare_json_arguments_key(self): + import json + + text = '{"name":"add","arguments":{"a":1,"b":2}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"a": 1, "b": 2} + + def test_llama3_2_bare_json_multi_call(self): + # Llama-3 may chain calls with ``; `` per training template. + text = '{"name":"a","parameters":{}}; ' '{"name":"b","parameters":{}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_llama3_2_bare_json_with_eom_sentinel(self): + text = '{"name":"x","parameters":{"y":1}}<|eom_id|>' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "x" + + def test_llama3_2_bare_json_leading_sentinel_skipped(self): + # Sometimes prior <|eot_id|> leaks into the next turn. + text = '<|eot_id|>{"name":"x","parameters":{}}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "x" + + def test_llama3_2_bare_json_plain_prose_does_not_fire(self): + # Defensive: must NOT fire on plain assistant prose. + text = "Hello world, how are you today?" + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_embedded_in_prose_does_not_fire(self): + # Defensive: JSON embedded in prose must NOT fire (parser is + # strict about content STARTING with `{`). + text = 'The tool result was: {"name":"foo"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_missing_name_does_not_fire(self): + text = '{"result":"ok","data":[1,2,3]}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_missing_args_does_not_fire(self): + text = '{"name":"x"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_args_not_dict_does_not_fire(self): + text = '{"name":"x","parameters":42}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_string_parameters_does_not_fire(self): + # Llama-3 spec: parameters must be a dict. Prose like + # ``{"name":"foo","parameters":"a sentence"}`` must NOT trigger. + text = '{"name":"foo","parameters":"this is a sentence"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_string_arguments_not_json_does_not_fire(self): + # OpenAI ``arguments`` may be a JSON-string of a dict, but a + # plain non-JSON string must not pass the guard. + text = '{"name":"foo","arguments":"not json"}' + assert parse_tool_calls_from_text(text) == [] + + def test_llama3_2_bare_json_string_arguments_json_dict_fires(self): + # OpenAI shape: arguments is a JSON-encoded string of a dict. + text = '{"name":"foo","arguments":"{\\"q\\":\\"x\\"}"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "foo" + # arguments stays as the original JSON-string. + assert result[0]["function"]["arguments"] == '{"q":"x"}' + + def test_llama3_2_bare_json_string_arguments_json_non_dict_does_not_fire(self): + # JSON-string that parses to a list / scalar / null must NOT fire. + for bad in ( + '{"name":"foo","arguments":"[1,2,3]"}', + '{"name":"foo","arguments":"\\"plain\\""}', + '{"name":"foo","arguments":"null"}', + '{"name":"foo","arguments":"42"}', + ): + assert parse_tool_calls_from_text(bad) == [], bad + + # ── Mistral pre-v11 ─────────────────────────────────────────── + + def test_mistral_pre_v11_array(self): + import json + + text = ( + '[TOOL_CALLS] [{"name":"web_search",' + '"arguments":{"query":"hello"},"id":"abc"}]' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + # Mistral provides its own id; preserve it. + assert result[0]["id"] == "abc" + assert json.loads(result[0]["function"]["arguments"]) == {"query": "hello"} + + def test_mistral_pre_v11_array_multi(self): + text = ( + '[TOOL_CALLS] [{"name":"a","arguments":{"x":1},"id":"id1"},' + '{"name":"b","arguments":{"y":2},"id":"id2"}]' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_mistral_pre_v11_unclosed_array(self): + # Closing ``]`` truncated -- parser must heal off individual objects. + text = '[TOOL_CALLS] [{"name":"web_search","arguments":{"q":"x"},"id":"id"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "web_search" + + # ── Mistral v11+ ─────────────────────────────────────────────── + + def test_mistral_v11_single(self): + # Magistral / Mistral Small 3.1: bare ``name{json}`` after trigger. + import json + + text = '[TOOL_CALLS]add{"a":3.5,"b":4}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "add" + assert json.loads(result[0]["function"]["arguments"]) == {"a": 3.5, "b": 4} + + def test_mistral_v11_parallel(self): + # v11+ parallel: ``[TOOL_CALLS]a{...}[TOOL_CALLS]b{...}``. + text = '[TOOL_CALLS]add{"a":1}[TOOL_CALLS]sub{"b":2}' + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "add" + assert result[1]["function"]["name"] == "sub" + + def test_mistral_v11_with_args_marker(self): + # Ministral / Mistral Large 3: ``[TOOL_CALLS]name[ARGS]{json}``. + import json + + text = '[TOOL_CALLS]add[ARGS]{"a":1,"b":2}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "add" + assert json.loads(result[0]["function"]["arguments"]) == {"a": 1, "b": 2} + + def test_mistral_strip_markup_v11(self): + text = '[TOOL_CALLS]add{"a":1}' + assert strip_tool_markup(text, final = True) == "" + + def test_mistral_call_id_form(self): + # Mistral Small 3.2: ``[TOOL_CALLS]name[CALL_ID][ARGS]{json}``. + # The ``[CALL_ID]`` segment must be skipped, not treated as a stop + # (llama.cpp test-chat.cpp:4785 parses this to one call). + import json + + text = '[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{"arg1": 1}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "special_function" + assert json.loads(result[0]["function"]["arguments"]) == {"arg1": 1} + + def test_mistral_call_id_form_parallel(self): + text = ( + '[TOOL_CALLS]special_function[CALL_ID]000000001[ARGS]{"arg1": 1}' + "[TOOL_CALLS]special_function_with_opt[CALL_ID]000000002" + '[ARGS]{"arg1": 1, "arg2": 2}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "special_function" + assert result[1]["function"]["name"] == "special_function_with_opt" + + def test_mistral_call_id_form_stripped(self): + text = '[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{"arg1": 1}' + assert strip_tool_markup(text, final = True) == "" + + def test_mistral_think_reasoning_ignored(self): + # Magistral wraps reasoning in ``[THINK]...[/THINK]``. A ``[TOOL_CALLS]`` + # inside the reasoning is chain-of-thought, not a real call; only the + # call after ``[/THINK]`` counts (llama.cpp test-chat.cpp:2285). + import json + + text = ( + '[THINK]Let me think about [TOOL_CALLS]fake[ARGS]{"x":1} ' + 'and more[/THINK][TOOL_CALLS]real_fn[ARGS]{"y":2}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "real_fn" + assert json.loads(result[0]["function"]["arguments"]) == {"y": 2} + + def test_mistral_think_reasoning_no_real_call(self): + # Reasoning that merely mentions a tool call but does not emit one + # after ``[/THINK]`` yields no calls. + text = '[THINK]I might call [TOOL_CALLS]fake[ARGS]{"x":1}[/THINK]Done.' + assert parse_tool_calls_from_text(text) == [] + + def test_mistral_think_literal_in_argument_preserved(self): + # A literal ``[THINK]`` inside a real tool argument (after the call) + # must not be stripped or corrupt the parse. + import json + + text = '[TOOL_CALLS]search[ARGS]{"q":"explain the [THINK] token"}' + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == { + "q": "explain the [THINK] token" + } + + # ── Gemma 4 ─────────────────────────────────────────────────── + + def test_gemma4_simple_call(self): + import json + + text = ( + "<|tool_call>call:get_weather{" + 'location:<|"|>Tokyo<|"|>,units:<|"|>celsius<|"|>}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"location": "Tokyo", "units": "celsius"} + + def test_gemma4_with_primitives(self): + import json + + text = ( + "<|tool_call>call:set_pref{" + "enabled:true,attempts:5,threshold:1.5,nickname:null}" + ) + result = parse_tool_calls_from_text(text) + args = json.loads(result[0]["function"]["arguments"]) + assert args == { + "enabled": True, + "attempts": 5, + "threshold": 1.5, + "nickname": None, + } + + def test_gemma4_nested_args(self): + # Gemma 4 nests dicts / lists with bare keys and ``<|"|>`` strings. + import json + + text = ( + "<|tool_call>call:search{" + 'query:<|"|>foo<|"|>,filters:{site:<|"|>example.com<|"|>,recent:true},' + 'tags:[<|"|>a<|"|>,<|"|>b<|"|>]}' + ) + result = parse_tool_calls_from_text(text) + args = json.loads(result[0]["function"]["arguments"]) + assert args["query"] == "foo" + assert args["filters"] == {"site": "example.com", "recent": True} + assert args["tags"] == ["a", "b"] + + def test_gemma4_multi_call(self): + text = ( + "<|tool_call>call:a{x:1}" "<|tool_call>call:b{y:2}" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 2 + assert result[0]["function"]["name"] == "a" + assert result[1]["function"]["name"] == "b" + + def test_gemma4_unclosed_does_not_raise(self): + # Truncated mid-stream; must not raise. + text = '<|tool_call>call:foo{x:<|"|>bar<|"|>' + result = parse_tool_calls_from_text(text) + assert isinstance(result, list) + + def test_gemma4_strip_markup_final(self): + text = "<|tool_call>call:foo{x:1}" + assert strip_tool_markup(text, final = True) == "" + + # ── Cross-format sentinels ──────────────────────────────────── + + def test_all_markers_in_tool_xml_signals(self): + # Streaming buffer wakes up on every emission marker. + from core.inference.tool_call_parser import TOOL_XML_SIGNALS + + for marker in ( + "", + "", + "[TOOL_CALLS]", + "<|tool_call>", + ): + assert ( + marker in TOOL_XML_SIGNALS + ), f"streaming loop would not wake on {marker!r}" + + def test_has_tool_signal_for_all_formats(self): + assert has_tool_signal('<|python_tag|>brave_search.call(q="x")') + assert has_tool_signal('[TOOL_CALLS] [{"name":"x"}]') + assert has_tool_signal('[TOOL_CALLS]add{"a":1}') + assert has_tool_signal("<|tool_call>call:foo{}") + + # ──────────────────────────────────────────────────────────────────── # run_safetensors_tool_loop # ──────────────────────────────────────────────────────────────────── @@ -280,6 +657,90 @@ def test_function_xml_form(self): contents = [e for e in events if e["type"] == "content"] assert "Result: 1" in contents[-1]["text"] + def test_llama3_python_tag_form(self): + # The agentic loop must recognise Llama-3's <|python_tag|> + # marker, drain the rest of the turn, and execute the call. + loop, exec_fn = _make_loop( + turns = [ + [ + "<|python_tag|>web_search.call(", + 'query="weather in Tokyo"', + ")", + ], + ["The weather is sunny."], + ], + exec_results = ["Sunny, 22C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather in Tokyo"})] + contents = [e for e in events if e["type"] == "content"] + assert "sunny" in contents[-1]["text"].lower() + + def test_llama3_bare_json_form_fires_tool(self): + # Llama-3.1 / 3.2 emit a bare-JSON tool call + # ``{"name":..,"parameters":..}`` with NO XML signal. The loop's + # safety-net parse must still fire the tool instead of treating the + # turn as "planned without calling tools" and re-prompting the model + # into giving up. Regression for the has_tool_signal gate that + # dropped these; GGUF's llama-server parses them natively. + loop, exec_fn = _make_loop( + turns = [ + ['{"name": "web_search", "parameters": {"query": "weather in SF"}}'], + ["The weather is sunny."], + ], + exec_results = ["Sunny, 18C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather in SF"})] + contents = [e for e in events if e["type"] == "content"] + assert "sunny" in contents[-1]["text"].lower() + + def test_mistral_pre_v11_form(self): + # Pre-v11 Mistral emission: ``[TOOL_CALLS] [{...}]``. + loop, exec_fn = _make_loop( + turns = [ + [ + '[TOOL_CALLS] [{"name":"web_search",', + '"arguments":{"query":"hi"},"id":"abc"}]', + ], + ["done"], + ], + exec_results = ["ok"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hi"})] + # Mistral-provided ids must propagate to tool_start events. + tool_start = next(e for e in events if e["type"] == "tool_start") + assert tool_start["tool_call_id"] == "abc" + + def test_mistral_v11_form(self): + # v11+ Mistral emission: bare ``name{json}`` after the trigger. + loop, exec_fn = _make_loop( + turns = [ + ['[TOOL_CALLS]web_search{"query":"hi"}'], + ["done"], + ], + exec_results = ["ok"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hi"})] + + def test_gemma4_form(self): + # Gemma 4 emission: ``<|tool_call>call:NAME{...}``. + loop, exec_fn = _make_loop( + turns = [ + [ + "<|tool_call>call:web_search{", + 'query:<|"|>weather<|"|>', + "}", + ], + ["sunny"], + ], + exec_results = ["Sunny, 22C"], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather"})] + def test_truncated_unclosed_tool_call(self): loop, exec_fn = _make_loop( turns = [ @@ -455,6 +916,259 @@ def test_exception_in_executor_does_not_raise(self): assert "boom" in tool_end["result"] +class TestLoopRePrompt: + """Re-prompt-on-plan-without-action parity with the GGUF path. + + When the model emits forward-looking intent ("Let me search for + that") without actually calling a tool, the loop must nudge it to + act instead of silently terminating. Up to ``_MAX_REPROMPTS`` (3) + re-prompts per request, drawn from extra iteration slots so the + caller's tool-call budget is preserved. + """ + + def test_intent_signal_triggers_reprompt(self): + # Turn 1: intent signal, no tool call. + # Turn 2 (re-prompt): proper tool call -> executes. + # Turn 3: final answer. + loop, exec_fn = _make_loop( + turns = [ + ["Let me search for that."], + [ + '{"name":"web_search","arguments":' + '{"query":"sky color"}}' + ], + ["The sky is blue."], + ], + exec_results = ["Blue (Rayleigh scattering)"], + ) + events = _collect_events(loop) + # web_search must have been called once (after the re-prompt). + assert exec_fn.calls == [("web_search", {"query": "sky color"})] + contents = [e for e in events if e["type"] == "content"] + assert contents and "blue" in contents[-1]["text"].lower() + + def test_intent_signal_without_tools_does_not_reprompt(self): + # Same intent signal but no tools enabled -- must NOT re-prompt. + loop, exec_fn = _make_loop( + turns = [["Let me think about that for a moment."]], + exec_results = [], + ) + # _make_loop hard-codes three tools; rebuild without tools. + from core.inference.safetensors_agentic import run_safetensors_tool_loop + + def _gen(_messages): + yield "Let me think about that for a moment." + + exec_fn = FakeExecuteTool([]) + events = _collect_events( + run_safetensors_tool_loop( + single_turn = _gen, + messages = [{"role": "user", "content": "hi"}], + tools = [], + execute_tool = exec_fn, + ) + ) + assert exec_fn.calls == [] + contents = [e for e in events if e["type"] == "content"] + assert contents and "think" in contents[-1]["text"].lower() + + def test_direct_answer_does_not_trigger_reprompt(self): + # Plain answer with no intent words: do NOT re-prompt. + loop, exec_fn = _make_loop( + turns = [["4"]], + exec_results = [], + ) + events = _collect_events(loop) + assert exec_fn.calls == [] + contents = [e for e in events if e["type"] == "content"] + assert contents and contents[-1]["text"].strip() == "4" + + def test_max_reprompts_capped_at_three(self): + # Model keeps stalling with intent -- after 3 re-prompts the + # loop must give up rather than burn forever. + turns = [["Let me search for that."]] * 6 # well over the cap + loop, exec_fn = _make_loop( + turns = turns, + exec_results = [], + ) + events = _collect_events(loop, max_events = 500) + # No tool ever ran, but the loop terminated cleanly. + assert exec_fn.calls == [] + statuses = [e for e in events if e["type"] == "status"] + assert statuses and statuses[-1]["text"] == "" + + def test_short_intent_below_buffer_threshold_triggers_reprompt(self): + # Short emission that never exits BUFFERING (< 32 chars + no + # marker prefix). The unified buffer-end path must still + # trigger the intent re-prompt, not silently terminate. + loop, exec_fn = _make_loop( + turns = [ + ["Let me check."], + [ + '{"name":"web_search","arguments":' + '{"query":"x"}}' + ], + ["found"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "x"})] + + def test_reprompt_does_not_consume_tool_budget(self): + # max_tool_iterations=1: one re-prompt, then one real tool call, + # then the budget-exhausted final answer must still fire. If the + # re-prompt ate the slot the tool call would never run. + loop, exec_fn = _make_loop( + turns = [ + # 1. Intent stall (re-prompt 1/3). + ["Let me search for that."], + # 2. Real tool call (uses the budget slot). + [ + '{"name":"web_search","arguments":' + '{"query":"weather"}}' + ], + # 3. Budget exhausted -> nudged final answer. + ["Final: it is sunny"], + ], + exec_results = ["sunny"], + max_tool_iterations = 1, + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "weather"})] + contents = [e for e in events if e["type"] == "content"] + assert contents and "sunny" in contents[-1]["text"].lower() + + +class TestLoopCanonicalHealKey: + """Per-tool canonical heal key (``code`` for python, ``command`` for + terminal, ``query`` for everything else). Mirrors GGUF after the + PR-5615 follow-up that ported this mapping over.""" + + def test_python_bare_string_heals_to_code(self): + loop, exec_fn = _make_loop( + turns = [ + ['{"name":"python","arguments":"print(1)"}' ""], + ["done"], + ], + exec_results = ["1\n"], + ) + events = _collect_events(loop) + # The bare string must heal to {"code": "print(1)"}, not + # {"query": ...}, so the python sandbox actually executes it. + assert exec_fn.calls == [("python", {"code": "print(1)"})] + + def test_terminal_bare_string_heals_to_command(self): + loop, exec_fn = _make_loop( + turns = [ + ['{"name":"terminal","arguments":"ls -la"}' ""], + ["done"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("terminal", {"command": "ls -la"})] + + def test_unknown_tool_bare_string_heals_to_query(self): + loop, exec_fn = _make_loop( + turns = [ + ['{"name":"web_search","arguments":"hello"}' ""], + ["ok"], + ], + exec_results = ["..."], + ) + events = _collect_events(loop) + assert exec_fn.calls == [("web_search", {"query": "hello"})] + + +class TestGGUFSafetensorsHealingParity: + """Pin parity between the GGUF agentic loop and the safetensors / + MLX loop so a regression on either side breaks CI.""" + + def test_gguf_imports_shared_signal_markers(self): + # The GGUF BUFFERING state machine must wake on every emission + # marker the shared parser knows -- otherwise Llama-3 / Mistral + # / Gemma 4 emissions slip past as plain prose when the + # llama-server structured channel fails. + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) + assert "_SHARED_TOOL_XML_SIGNALS" in src, ( + "GGUF agentic loop must reuse the shared TOOL_XML_SIGNALS " + "tuple so it wakes on all five emission formats" + ) + + def test_gguf_uses_shared_strip_helper(self): + # The GGUF stream-cleanup function must delegate to the shared + # strip_tool_markup so closed-pair markup is removed for every + # emission family (Llama-3 <|python_tag|>, Mistral [TOOL_CALLS], + # Gemma 4 <|tool_call>...). + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) + assert "_shared_strip_tool_markup" in src, ( + "GGUF stream cleanup must delegate to the shared " + "strip_tool_markup helper" + ) + + def test_gguf_uses_canonical_heal_keys(self): + # GGUF must heal a bare-string ``arguments`` to the same per-tool + # canonical key as safetensors -- ``code`` for python, ``command`` + # for terminal, ``query`` for everything else. + import inspect + + from core.inference.llama_cpp import LlamaCppBackend + + src = inspect.getsource(LlamaCppBackend.generate_chat_completion_with_tools) + # The canonical key dict literal must be present in the heal + # path so a Llama-3 / Mistral / Gemma 4 bare-string emission + # for python doesn't get routed as {"query": "print(1)"}. + assert '"python": "code"' in src + assert '"terminal": "command"' in src + + def test_intent_regex_matches_same_phrases_as_gguf(self): + # The intent re-prompt regex must match the SAME forward-looking + # phrases on both backends so behaviour is the same on Mac (MLX + # / safetensors) and on Linux (GGUF). + from core.inference.llama_cpp import _INTENT_SIGNAL as gguf_re + from core.inference.safetensors_agentic import ( + _INTENT_SIGNAL as sf_re, + ) + + for phrase in ( + "I'll search for that", + "I will look it up", + "Let me check", + "I am going to call the tool", + "First, I will explore", + "Here's my plan", + "Now I need to call web_search", + ): + assert gguf_re.search(phrase), f"GGUF missed {phrase!r}" + assert sf_re.search(phrase), f"safetensors missed {phrase!r}" + + for plain in ( + "4", + "Hello!", + "The sky is blue.", + "I can help with that.", + "I should mention", + "Let's go.", + ): + assert not gguf_re.search(plain), f"GGUF wrongly fired on {plain!r}" + assert not sf_re.search(plain), f"safetensors wrongly fired on {plain!r}" + + def test_max_reprompts_equal_on_both_backends(self): + from core.inference.llama_cpp import _MAX_REPROMPTS as gguf_cap + from core.inference.safetensors_agentic import _MAX_REPROMPTS as sf_cap + + assert gguf_cap == sf_cap == 3 + + class TestLoopControl: def test_cancel_event_breaks_loop(self): cancel = threading.Event() @@ -784,5 +1498,245 @@ def test_empty_or_none_returns_false(self): assert is_gpt_oss_model_name(None) is False +# ──────────────────────────────────────────────────────────────────── +# Routes-level python_tag strip (multi-line; stop on next sentinel) +# ──────────────────────────────────────────────────────────────────── + + +class TestRoutesPythonTagStrip: + """Earlier revisions of ``_TOOL_XML_RE`` in + ``studio.backend.routes.inference`` used either ``[^\\n<]*`` (5615 -- + leaked the tail of any tool call whose argument contained a literal + ``<`` like ``code="if x < 10"``) or ``[^\\n]*`` (5620 round one -- + single-line only, so the second line of + ``python.call(code="line1\\nline2")`` leaked). The current pattern + ``(?:[^<]|<(?!\\|))*`` consumes any character that is not a Llama-3 + ``<|`` sentinel start, so multi-line code, embedded JSON, and bare + ``<`` characters in code all stay inside the strip. + + The fully resolved strip is also exposed via + ``strip_tool_markup(text, final=True)`` in the parser; the + streaming path's routes-level strip is the regression-prone one + because it runs on every cumulative emission while content is + still arriving. + """ + + def _strip(self, text: str) -> str: + # Import inside the test so a routes-module import error does + # not blow up the entire test file at collection time. + from routes.inference import _strip_tool_xml + + return _strip_tool_xml(text) + + def test_single_line_python_tag_stripped(self): + # Floor: the original 5620 single-line behaviour still works. + text = '<|python_tag|>brave_search.call(query="weather")' + assert self._strip(text) == "" + + def test_python_tag_with_less_than_in_code(self): + # 5615 regression: literal ``<`` inside code must NOT terminate + # the strip early. + text = '<|python_tag|>python.call(code="if x < 10: pass")' + assert self._strip(text) == "" + + def test_python_tag_multiline_code_stripped(self): + # 5620 round-1 regression: multi-line code's second line leaked. + text = '<|python_tag|>python.call(code="line1\nline2\nline3")' + assert self._strip(text) == "" + + def test_python_tag_multiline_with_less_than(self): + # Combined: multi-line code AND literal ``<`` in code. + text = ( + '<|python_tag|>python.call(code="for i in range(10):\n' + " if i < 5:\n" + ' print(i)")' + ) + assert self._strip(text) == "" + + def test_python_tag_stops_at_eom_sentinel(self): + # Strip stops at the next Llama-3 ``<|`` sentinel so any + # trailing assistant content survives. + text = ( + '<|python_tag|>python.call(code="multi\nline")' + "<|eom_id|>final answer text" + ) + assert self._strip(text) == "<|eom_id|>final answer text" + + def test_python_tag_stops_at_eot_sentinel(self): + text = '<|python_tag|>brave_search.call(query="x")' "<|eot_id|>after" + assert self._strip(text) == "<|eot_id|>after" + + def test_python_tag_json_form_multiline_stripped(self): + # The JSON form of python_tag with newlines inside string args. + text = ( + '<|python_tag|>{"name":"python",' + '"parameters":{"code":"a = 1\nb = 2\nprint(a+b)"}}' + ) + assert self._strip(text) == "" + + def test_python_tag_with_eom_then_trailing_python_tag(self): + # Two python_tag emissions back-to-back across a sentinel: both + # should strip independently. + text = ( + '<|python_tag|>brave_search.call(query="a")' + "<|eom_id|>" + '<|python_tag|>python.call(code="x=1")' + ) + # ``<|eom_id|>`` between the two strips remains; both + # python_tag blocks are fully consumed. + assert self._strip(text) == "<|eom_id|>" + + +# ──────────────────────────────────────────────────────────────────── +# Robustness fixes uncovered while validating against vLLM / sglang. +# ──────────────────────────────────────────────────────────────────── + + +class TestParserRobustness: + def test_tool_call_json_accepts_parameters_key(self): + # Hermes wrapper around a Llama-3.2 bare-JSON object that uses + # ``parameters`` instead of ``arguments``. The bare-JSON and + # python_tag paths already accept both keys; this path now does + # too. Was extracting name only and silently dropping the args. + import json + + text = ( + "\n" + '{"name": "search", "parameters": {"q": "ramen"}}\n' + "" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "search" + assert json.loads(result[0]["function"]["arguments"]) == {"q": "ramen"} + + def test_function_xml_attribute_form(self): + # MiniCPM-5 / MiniMax-M2 attribute syntax: + # ``v``. + import json + + text = ( + '' + 'Tokyo' + "" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_function_xml_attribute_form_multi_param(self): + import json + + text = ( + '' + 'Tokyo' + 'celsius' + "" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + args = json.loads(result[0]["function"]["arguments"]) + assert args == {"city": "Tokyo", "unit": "celsius"} + + def test_function_xml_legacy_equals_form_still_works(self): + # Regression guard: the old ``v`` + # syntax must keep parsing after the regex broadening. + import json + + text = ( + "" "Tokyo" "" + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_function_attribute_form_has_tool_signal(self): + # The standalone ```` attribute form must flip + # the streaming buffer; otherwise the end-of-turn safety-net parse in + # the agentic loop is gated off and the real call is dropped. + assert has_tool_signal('') is True + + def test_function_attribute_form_strip_markup(self): + # The attribute form must also be stripped from displayed text, like + # the legacy ```` form. + text = 'result X' + assert strip_tool_markup(text, final = True) == "result" + + def test_llama3_chat_template_round_trip(self): + # Meta's official Llama-3.x chat template prefixes every + # assistant turn with + # ``<|start_header_id|>assistant<|end_header_id|>\n\n``. The + # sentinel-strip in ``_parse_llama3_bare_json`` must reach past + # the role label to the JSON body, else every round-tripped + # tool call in history silently drops. + import json + + text = ( + "<|start_header_id|>assistant<|end_header_id|>\n\n" + '{"name": "get_weather", "parameters": {"city": "Tokyo"}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_llama3_round_trip_all_roles(self): + # Same logic must work for every role the chat template inserts. + import json + + for role in ("assistant", "user", "system", "tool", "ipython"): + text = ( + f"<|start_header_id|>{role}<|end_header_id|>\n\n" + '{"name": "f", "parameters": {"x": 1}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1, f"failed for role={role}" + assert json.loads(result[0]["function"]["arguments"]) == {"x": 1} + + def test_llama3_round_trip_with_eot_prefix(self): + # Prior assistant turn closes with ``<|eot_id|>``, then the + # new header opens. Both sentinels + the role must be consumed. + import json + + text = ( + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + '{"name": "f", "parameters": {}}' + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert result[0]["function"]["name"] == "f" + + def test_function_xml_followed_by_prose(self): + # Models routinely follow a tool call with explanatory prose. + # Body must terminate at ```` even without a + # ```` wrapper, else trailing prose leaks into the + # last parameter value. + import json + + text = ( + "" + "Tokyo" + "\n\nHere is what I found." + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + def test_function_attribute_xml_followed_by_prose(self): + # Same expectation for the MiniCPM-5 attribute form. + import json + + text = ( + '' + 'Tokyo' + "\n\nLet me know if you need anything else." + ) + result = parse_tool_calls_from_text(text) + assert len(result) == 1 + assert json.loads(result[0]["function"]["arguments"]) == {"city": "Tokyo"} + + if __name__ == "__main__": pytest.main([__file__, "-v"])