Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions studio/backend/core/inference/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5076,13 +5076,30 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str:
_effective_timeout = (
None if tool_call_timeout >= 9999 else tool_call_timeout
)
result = execute_tool(
tool_name,
arguments,
cancel_event = cancel_event,
timeout = _effective_timeout,
session_id = session_id,
)
# Guard against the model emitting a tool not in the
# per-request advertised set: filtered MCP names, a
# built-in the caller opted out of, or a stale name
# from a prior turn. Mirrors the safetensors loop's
# allowed_tool_names check.
_allowed = {
(t.get("function") or {}).get("name")
for t in (tools or [])
if (t.get("function") or {}).get("name")
}
if _allowed and tool_name not in _allowed:
result = (
f"Error: tool '{tool_name}' is not enabled "
"for this request. Use one of the enabled "
"tools or provide a final answer."
)
else:
result = execute_tool(
tool_name,
arguments,
cancel_event = cancel_event,
timeout = _effective_timeout,
session_id = session_id,
)

yield {
"type": "tool_end",
Expand Down
181 changes: 181 additions & 0 deletions studio/backend/core/inference/mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

from __future__ import annotations

import asyncio
import json
from typing import Any, Optional

from loggers import get_logger

logger = get_logger(__name__)

# Prefix that marks a tool name as belonging to an MCP server.
MCP_TOOL_PREFIX = "mcp__"

# Lazily-created OAuth token store; stays None until the first call to
# _oauth_store(), which builds it once and reuses it afterwards.
_oauth_token_store = None


def parse_server_headers(server: dict) -> Optional[dict]:
    """Decode the ``headers_json`` field of an MCP server row.

    Args:
        server: Mapping describing the server; only ``headers_json`` is read.

    Returns:
        The decoded JSON object as a dict, or ``None`` when the field is
        missing or empty, is not valid JSON, is not a type ``json.loads``
        accepts, or decodes to anything other than a JSON object.
    """
    raw = server.get("headers_json")
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # ValueError already covers json.JSONDecodeError. TypeError is raised
        # when the stored value is not str/bytes/bytearray (for example an
        # already-decoded dict); treat it like any other unusable value
        # instead of letting it escape to the caller.
        return None
    return parsed if isinstance(parsed, dict) else None


def _oauth_store():
    """Return the module-level OAuth token store, building it on first use."""
    global _oauth_token_store
    if _oauth_token_store is not None:
        return _oauth_token_store

    from key_value.aio._utils.sanitization import AlwaysHashStrategy
    from key_value.aio.stores.filetree import FileTreeStore
    from utils.paths.storage_roots import ensure_dir, studio_root

    # fastmcp uses raw URLs such as https://x.com as keys, and FileTreeStore
    # would read the "://" as nested directories -- so both keys and
    # collections are hashed before they reach the filesystem.
    token_dir = ensure_dir(studio_root() / "mcp-oauth-tokens")
    _oauth_token_store = FileTreeStore(
        data_directory = token_dir,
        key_sanitization_strategy = AlwaysHashStrategy(),
        collection_sanitization_strategy = AlwaysHashStrategy(),
    )
    return _oauth_token_store


async def clear_oauth_tokens_async(url: str) -> None:
    """Best-effort removal of the OAuth tokens persisted for ``url``.

    fastmcp keys tokens by MCP URL, so a server delete, a URL change, or
    turning OAuth off must clear the old credentials explicitly; otherwise
    re-registering the same URL would silently reuse the previous account's
    token.

    Everything, including the import and the store / OAuth construction,
    sits inside the ``try`` so that a failure here is only logged and can
    never make the delete / update route return a 500.
    """
    try:
        from fastmcp.client.auth import OAuth

        store = _oauth_store()
        oauth = OAuth(mcp_url = url, token_storage = store)
        await oauth.token_storage_adapter.clear()
    except Exception as err:  # noqa: BLE001
        # Cleanup is best-effort; the row delete still wins.
        logger.warning("Failed to clear OAuth tokens for %s: %s", url, err)


def _client(url: str, headers: Optional[dict], use_oauth: bool = False):
    """Build a fastmcp ``Client`` for ``url``.

    The transport is SSE when fastmcp infers ``"sse"`` from the URL and
    streamable HTTP otherwise. ``headers`` is passed as ``None`` when empty,
    and an ``OAuth`` helper backed by the shared token store is attached only
    when ``use_oauth`` is true.
    """
    from fastmcp import Client
    from fastmcp.client.transports import SSETransport, StreamableHttpTransport
    from fastmcp.mcp_config import infer_transport_type_from_url

    if use_oauth:
        from fastmcp.client.auth import OAuth

        auth = OAuth(mcp_url = url, token_storage = _oauth_store())
    else:
        auth = None

    if infer_transport_type_from_url(url) == "sse":
        transport_cls = SSETransport
    else:
        transport_cls = StreamableHttpTransport

    transport = transport_cls(url = url, headers = headers or None, auth = auth)
    return Client(transport)


async def list_tools_async(
    url: str,
    headers: Optional[dict] = None,
    timeout: float = 5.0,
    use_oauth: bool = False,
) -> list[dict]:
    """List the tools an MCP server advertises, as plain dicts.

    Args:
        url: MCP server endpoint.
        headers: Optional HTTP headers forwarded to the transport.
        timeout: Seconds allowed for connect + listing; on expiry
            ``asyncio.wait_for`` raises its timeout error.
        use_oauth: Whether to attach OAuth to the client.

    Returns:
        One ``model_dump(exclude_none=True)`` dict per advertised tool.
    """

    async def _dump_tools() -> list[dict]:
        async with _client(url, headers, use_oauth) as session:
            advertised = await session.list_tools()
            return [tool.model_dump(exclude_none = True) for tool in advertised]

    listing = _dump_tools()
    return await asyncio.wait_for(listing, timeout = timeout)


def _flatten_result(result: Any) -> str:
    """Collapse an MCP tool result into the single string handed to the model.

    Text from every content block that carries a non-empty ``text`` attribute
    is joined with newlines. When no text is present the result's
    ``structured_content`` is used instead, serialized as JSON so the model
    sees ``{"ok": true}`` rather than the Python repr ``{'ok': True}``.

    Args:
        result: Object exposing optional ``content``, ``structured_content``
            and ``is_error`` attributes; missing attributes are tolerated.

    Returns:
        The flattened body. When ``is_error`` is truthy the body is prefixed
        with ``"Error: "`` (or replaced by a fixed message if it is empty).
    """
    parts = []
    for block in getattr(result, "content", None) or []:
        text = getattr(block, "text", None)
        if text:
            parts.append(str(text))
    body = "\n".join(parts)

    if not body:
        structured = getattr(result, "structured_content", None)
        if isinstance(structured, str):
            # Already text: pass through without adding JSON quotes.
            body = structured
        elif structured is not None:
            try:
                body = json.dumps(structured, ensure_ascii = False)
            except (TypeError, ValueError):
                # Not JSON-serializable; fall back to the plain repr.
                body = str(structured)

    if getattr(result, "is_error", False):
        # "Error: " prefix triggers tool_call_parser's TOOL_ERROR_PREFIXES nudge.
        return f"Error: {body}" if body else "Error: tool returned no content"
    return body


class _MCPCancelled(Exception):
    """Internal sentinel raised when cancel_event fires before the tool returns."""


def call_tool_sync(
    url: str,
    headers: Optional[dict],
    name: str,
    args: dict,
    timeout: Optional[float] = 300.0,
    use_oauth: bool = False,
    cancel_event = None,
) -> str:
    """Synchronously call an MCP tool and return its flattened text result.

    The call runs on a fresh event loop via ``asyncio.run``. Failures are
    reported as strings starting with ``"Error: "`` rather than raised.

    Args:
        url: MCP server endpoint.
        headers: Optional HTTP headers forwarded to the transport.
        name: Tool name as known to the MCP server.
        args: Arguments passed to the tool.
        timeout: Seconds to wait for the call; ``None`` waits indefinitely.
        use_oauth: Whether to attach OAuth to the client.
        cancel_event: Optional ``threading.Event``. When set, the in-flight
            HTTP call is cancelled and the function returns a cancellation
            Error. Polled in parallel with the tool call via
            ``asyncio.wait`` so a /cancel POST from the UI interrupts even
            mid-network-read.

    Returns:
        The flattened tool output, or an ``"Error: ..."`` message on
        cancellation, timeout, or any other exception.
    """

    async def _call() -> Any:
        async with _client(url, headers, use_oauth) as client:
            return await client.call_tool(name, args)

    async def _watch_cancel() -> None:
        # 50 ms cadence keeps cancellation responsive without busy-looping;
        # matches the cadence routes/inference.py uses for cancel watchers.
        while cancel_event is not None and not cancel_event.is_set():
            await asyncio.sleep(0.05)

    async def _race() -> Any:
        # Check cancellation before spawning the call task so a pre-set
        # event short-circuits before opening the transport / HTTP
        # connection (reviewer-reproduced race).
        if cancel_event is not None and cancel_event.is_set():
            raise _MCPCancelled
        call_task = asyncio.create_task(_call())
        if cancel_event is None:
            return await asyncio.wait_for(call_task, timeout = timeout)
        watch_task = asyncio.create_task(_watch_cancel())
        try:
            done, pending = await asyncio.wait(
                {call_task, watch_task},
                timeout = timeout,
                return_when = asyncio.FIRST_COMPLETED,
            )
        finally:
            # Whichever task lost the race (or both, on timeout) is cancelled
            # here; asyncio.run() awaits the cancelled tasks during shutdown.
            for t in (call_task, watch_task):
                if not t.done():
                    t.cancel()
        if not done:
            raise asyncio.TimeoutError
        if call_task in done:
            # The tool result wins even if the watcher finished as well.
            return call_task.result()
        raise _MCPCancelled

    try:
        result = asyncio.run(_race())
    except _MCPCancelled:
        return f"Error: MCP tool '{name}' cancelled"
    except asyncio.TimeoutError:
        # A timeout error can also come from inside the client while no
        # deadline was requested (timeout is None); formatting None with
        # ":g" would raise TypeError from within this handler.
        if timeout is None:
            return f"Error: MCP tool '{name}' timed out"
        return f"Error: MCP tool '{name}' timed out after {timeout:g}s"
    except Exception as exc:
        logger.exception("MCP call_tool failed for %s: %s", name, exc)
        return f"Error: MCP tool '{name}' failed: {exc}"

    return _flatten_result(result)
13 changes: 9 additions & 4 deletions studio/backend/core/inference/tool_call_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@

# _TOOL_CLOSED_PATS: closed pairs only. _TOOL_ALL_PATS: also trailing
# unclosed runs so truncated tails don't leak markup.
# Function-name char set tracks OpenAI's ^[a-zA-Z0-9_-]{1,64}$ so MCP
# tool names that contain a hyphen (e.g. mcp__srv__list-issues) parse
# the same as the built-in web_search/python/terminal names.
_TOOL_CLOSED_PATS = [
re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL),
re.compile(r"<function=\w+>.*?</function>", re.DOTALL),
re.compile(r"<function=[\w-]+>.*?</function>", re.DOTALL),
]
_TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [
re.compile(r"<tool_call>.*$", re.DOTALL),
re.compile(r"<function=\w+>.*$", re.DOTALL),
re.compile(r"<function=[\w-]+>.*$", re.DOTALL),
]


Expand Down Expand Up @@ -60,10 +63,12 @@

# Pre-compiled patterns reused by ``parse_tool_calls_from_text``.
_TC_JSON_START_RE = re.compile(r"<tool_call>\s*\{")
_TC_FUNC_START_RE = re.compile(r"<function=(\w+)>\s*")
_TC_FUNC_START_RE = re.compile(r"<function=([\w-]+)>\s*")
_TC_END_TAG_RE = re.compile(r"</tool_call>")
_TC_FUNC_CLOSE_RE = re.compile(r"\s*</function>\s*$")
_TC_PARAM_START_RE = re.compile(r"<parameter=(\w+)>\s*")
# Parameter names can carry hyphens too (e.g. MCP tool schemas with
# `issue-number`, `repo-name`); using `\w+` here dropped those keys.
_TC_PARAM_START_RE = re.compile(r"<parameter=([\w-]+)>\s*")
_TC_PARAM_CLOSE_RE = re.compile(r"\s*</parameter>\s*$")


Expand Down
114 changes: 114 additions & 0 deletions studio/backend/core/inference/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

os.environ["UNSLOTH_IS_PRESENT"] = "1"

import asyncio
import random
import re
import shlex
Expand All @@ -24,6 +25,14 @@
import threading
import urllib.request

from core.inference.mcp_client import (
MCP_TOOL_PREFIX,
call_tool_sync,
list_tools_async,
parse_server_headers,
)
from storage import mcp_servers_db

from loggers import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -505,6 +514,92 @@ def _get_workdir(session_id: str | None = None) -> str:
ALL_TOOLS = [WEB_SEARCH_TOOL, PYTHON_TOOL, TERMINAL_TOOL]


# OpenAI's function.name regex: ^[a-zA-Z0-9_-]{1,64}$ -- enforced before
# streaming starts. MCP servers can return tool names containing '.', '/',
# spaces, etc., which the prefix scheme would forward to OpenAI verbatim
# and 400 the whole request. Validate up front and skip with a warning.
_OPENAI_FN_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")


def _mcp_specs_for_server(server: dict, mcp_tools: list[dict]) -> list[dict]:
"""Convert an MCP server's tool list into OpenAI function specs.

Each kept tool is exposed under the composed name
``MCP_TOOL_PREFIX + server["id"] + "__" + <tool name>``. A tool is skipped
with a warning when its name is empty, when the composed name does not
match ``_OPENAI_FN_NAME_RE``, or when the composed name was already
produced for this server.

Args:
server: Server row. ``id`` is required; ``display_name`` is optional and
is used for log messages and as the description prefix.
mcp_tools: Tool dicts; ``name``, ``description`` and ``inputSchema`` are
read from each.

Returns:
One ``{"type": "function", "function": {...}}`` dict per kept tool, in
input order. A tool without ``inputSchema`` gets an empty object schema.
"""
display = server.get("display_name") or server["id"]
specs: list[dict] = []
seen_names: set[str] = set()
for tool in mcp_tools:
raw_name = tool.get("name") or ""
if not raw_name:
logger.warning("Skipping MCP tool on '%s': empty name.", display)
continue
name = f"{MCP_TOOL_PREFIX}{server['id']}__{raw_name}"
# OpenAI requires function.name ^[a-zA-Z0-9_-]{1,64}$; bad chars
# (., /, spaces, etc.) or oversized names would 400 the whole
# request. Skip + warn so the rest of the tools still ship.
if not _OPENAI_FN_NAME_RE.fullmatch(name):
logger.warning(
"Skipping MCP tool '%s' on '%s': composed name '%s' is not "
"valid OpenAI function.name (regex ^[a-zA-Z0-9_-]{1,64}$).",
raw_name,
display,
name,
)
continue
# Same MCP server returning duplicate tool names would also 400
# OpenAI ("tools[N].function.name duplicates ..."). Drop dupes.
if name in seen_names:
logger.warning(
"Skipping duplicate MCP tool '%s' on '%s'.", raw_name, display
)
continue
seen_names.add(name)
specs.append(
{
"type": "function",
"function": {
"name": name,
"description": f"[{display}] {tool.get('description') or ''}".strip(),
"parameters": tool.get("inputSchema")
or {"type": "object", "properties": {}},
},
}
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of fetching all servers and filtering in Python, use a dedicated database method to fetch only enabled servers. This improves efficiency, especially as the number of registered servers grows.

Suggested change
)
servers = mcp_servers_db.list_enabled_servers()
References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

return specs


async def get_enabled_mcp_tools() -> list[dict]:
"""Discover the tools of every enabled MCP server and return their specs.

Servers whose ``is_enabled`` flag is truthy are probed concurrently. A
server whose probe raises is logged and skipped, so the remaining servers
still contribute their tools.

Returns:
The concatenated ``_mcp_specs_for_server`` output for each server that
answered, in server order, or ``[]`` when no server is enabled.
"""
servers = [s for s in mcp_servers_db.list_servers() if s.get("is_enabled")]
if not servers:
return []

# OAuth probes need minutes for first-connect/expired-token browser
# sign-in; non-OAuth probes fail fast. Matches routes/mcp_servers.py.
results = await asyncio.gather(
*(
list_tools_async(
url = s["url"],
headers = parse_server_headers(s),
timeout = 305.0 if s.get("use_oauth") else 8.0,
use_oauth = bool(s.get("use_oauth")),
)
Comment thread
NilayYadav marked this conversation as resolved.
for s in servers
),
# Failures come back as exception objects in ``results`` instead of
# aborting the whole gather.
return_exceptions = True,
)

specs: list[dict] = []
for server, payload in zip(servers, results):
if isinstance(payload, BaseException):
logger.warning(
"MCP server '%s' (%s) discovery failed: %s",
server.get("display_name") or server["id"],
server.get("url"),
payload,
)
continue
specs.extend(_mcp_specs_for_server(server, payload))
return specs


_TIMEOUT_UNSET = object()


Expand All @@ -525,6 +620,25 @@ def execute_tool(
f"execute_tool: name={name}, session_id={session_id}, timeout={timeout}"
)
effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout
if name.startswith(MCP_TOOL_PREFIX):
try:
_, server_id, tool_name = name.split("__", 2)
except ValueError:
return f"Error: malformed MCP tool name '{name}'"
server = mcp_servers_db.get_server(server_id)
if not server:
return f"Error: MCP server '{server_id}' not found"
if not server.get("is_enabled"):
return f"Error: MCP server '{server_id}' is disabled"
return call_tool_sync(
url = server["url"],
headers = parse_server_headers(server),
name = tool_name,
args = arguments,
timeout = effective_timeout,
use_oauth = bool(server.get("use_oauth")),
Comment on lines +633 to +639
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Propagate cancellation into MCP tool execution

This new MCP branch ignores cancel_event, so an in-flight remote tool call cannot be interrupted when the user cancels/disconnects; the worker thread stays blocked until timeout (default up to 300s). In the tool-streaming paths, cancellation is polled between next() calls, so this blocking call delays teardown and can tie up worker capacity under slow/hung MCP servers. Please thread cancellation through the MCP call path (or use shorter cancellable waits) to match existing tool behavior.

Useful? React with 👍 / 👎.

cancel_event = cancel_event,
)
if name == "web_search":
return _web_search(
arguments.get("query", ""),
Expand Down
Loading
Loading