diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 76234386aa..0325620b2d 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -5076,13 +5076,30 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: _effective_timeout = ( None if tool_call_timeout >= 9999 else tool_call_timeout ) - result = execute_tool( - tool_name, - arguments, - cancel_event = cancel_event, - timeout = _effective_timeout, - session_id = session_id, - ) + # Guard against the model emitting a tool not in the + # per-request advertised set: filtered MCP names, a + # built-in the caller opted out of, or a stale name + # from a prior turn. Mirrors the safetensors loop's + # allowed_tool_names check. + _allowed = { + (t.get("function") or {}).get("name") + for t in (tools or []) + if (t.get("function") or {}).get("name") + } + if _allowed and tool_name not in _allowed: + result = ( + f"Error: tool '{tool_name}' is not enabled " + "for this request. Use one of the enabled " + "tools or provide a final answer." + ) + else: + result = execute_tool( + tool_name, + arguments, + cancel_event = cancel_event, + timeout = _effective_timeout, + session_id = session_id, + ) yield { "type": "tool_end", diff --git a/studio/backend/core/inference/mcp_client.py b/studio/backend/core/inference/mcp_client.py new file mode 100644 index 0000000000..a5e614899d --- /dev/null +++ b/studio/backend/core/inference/mcp_client.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +from __future__ import annotations + +import asyncio +import json +from typing import Any, Optional + +from loggers import get_logger + +logger = get_logger(__name__) + +MCP_TOOL_PREFIX = "mcp__" + +_oauth_token_store = None + + +def parse_server_headers(server: dict) -> Optional[dict]: + raw = server.get("headers_json") + if not raw: + return None + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + + +def _oauth_store(): + global _oauth_token_store + if _oauth_token_store is None: + from key_value.aio._utils.sanitization import AlwaysHashStrategy + from key_value.aio.stores.filetree import FileTreeStore + from utils.paths.storage_roots import ensure_dir, studio_root + + # Hash keys/collections — fastmcp uses raw URLs like https://x.com as + # keys and FileTreeStore would treat the "://" as nested directories. + _oauth_token_store = FileTreeStore( + data_directory = ensure_dir(studio_root() / "mcp-oauth-tokens"), + key_sanitization_strategy = AlwaysHashStrategy(), + collection_sanitization_strategy = AlwaysHashStrategy(), + ) + return _oauth_token_store + + +async def clear_oauth_tokens_async(url: str) -> None: + """Drop any persisted OAuth tokens for ``url``. fastmcp keys tokens by + MCP URL, so on server delete / URL change / OAuth disable we have to + clear the old credentials explicitly. Otherwise re-registering the + same URL would silently reuse the old account's token. The entire + body runs inside the protected block -- store / OAuth construction + failing must not make the delete / update route 500.""" + try: + from fastmcp.client.auth import OAuth + + auth = OAuth(mcp_url = url, token_storage = _oauth_store()) + await auth.token_storage_adapter.clear() + except Exception as exc: # noqa: BLE001 + # Cleanup is best-effort; the row delete still wins. + logger.warning("Failed to clear OAuth tokens for %s: %s", url, exc) + + +def _client(url: str, headers: Optional[dict], use_oauth: bool = False): + from fastmcp import Client + from fastmcp.client.transports import SSETransport, StreamableHttpTransport + from fastmcp.mcp_config import infer_transport_type_from_url + + auth = None + if use_oauth: + from fastmcp.client.auth import OAuth + + auth = OAuth(mcp_url = url, token_storage = _oauth_store()) + + transport_cls = ( + SSETransport + if infer_transport_type_from_url(url) == "sse" + else StreamableHttpTransport + ) + return Client(transport_cls(url = url, headers = headers or None, auth = auth)) + + +async def list_tools_async( + url: str, + headers: Optional[dict] = None, + timeout: float = 5.0, + use_oauth: bool = False, +) -> list[dict]: + async def _fetch() -> list[dict]: + async with _client(url, headers, use_oauth) as client: + tools = await client.list_tools() + return [t.model_dump(exclude_none = True) for t in tools] + + return await asyncio.wait_for(_fetch(), timeout = timeout) + + +def _flatten_result(result: Any) -> str: + parts = [] + for block in getattr(result, "content", None) or []: + text = getattr(block, "text", None) + if text: + parts.append(str(text)) + body = "\n".join(parts) + if not body: + structured = getattr(result, "structured_content", None) + body = str(structured) if structured is not None else "" + + if getattr(result, "is_error", False): + # "Error: " prefix triggers tool_call_parser's TOOL_ERROR_PREFIXES nudge. + return f"Error: {body}" if body else "Error: tool returned no content" + return body + + +def call_tool_sync( + url: str, + headers: Optional[dict], + name: str, + args: dict, + timeout: Optional[float] = 300.0, + use_oauth: bool = False, + cancel_event = None, +) -> str: + """Synchronously call an MCP tool. + + ``cancel_event``: optional ``threading.Event``. When set, the in-flight + HTTP call is cancelled and the function returns a cancellation Error. + Polled in parallel with the tool call via ``asyncio.wait`` so a /cancel + POST from the UI interrupts even mid-network-read. + """ + + async def _call() -> Any: + async with _client(url, headers, use_oauth) as client: + return await client.call_tool(name, args) + + async def _watch_cancel() -> None: + # 50 ms cadence keeps cancellation responsive without busy-looping; + # matches the cadence routes/inference.py uses for cancel watchers. + while cancel_event is not None and not cancel_event.is_set(): + await asyncio.sleep(0.05) + + async def _race() -> Any: + # Check cancellation before spawning the call task so a pre-set + # event short-circuits before opening the transport / HTTP + # connection (reviewer-reproduced race). + if cancel_event is not None and cancel_event.is_set(): + raise _MCPCancelled + call_task = asyncio.create_task(_call()) + if cancel_event is None: + return await asyncio.wait_for(call_task, timeout = timeout) + watch_task = asyncio.create_task(_watch_cancel()) + try: + done, pending = await asyncio.wait( + {call_task, watch_task}, + timeout = timeout, + return_when = asyncio.FIRST_COMPLETED, + ) + finally: + for t in (call_task, watch_task): + if not t.done(): + t.cancel() + if not done: + raise asyncio.TimeoutError + if call_task in done: + return call_task.result() + raise _MCPCancelled + + try: + result = asyncio.run(_race()) + except _MCPCancelled: + return f"Error: MCP tool '{name}' cancelled" + except asyncio.TimeoutError: + return f"Error: MCP tool '{name}' timed out after {timeout:g}s" + except Exception as exc: + logger.exception("MCP call_tool failed for %s: %s", name, exc) + return f"Error: MCP tool '{name}' failed: {exc}" + + return _flatten_result(result) + + +class _MCPCancelled(Exception): + """Internal sentinel raised when cancel_event fires before the tool returns.""" diff --git a/studio/backend/core/inference/tool_call_parser.py b/studio/backend/core/inference/tool_call_parser.py index a0ab8a2a53..2f94990623 100644 --- a/studio/backend/core/inference/tool_call_parser.py +++ b/studio/backend/core/inference/tool_call_parser.py @@ -13,13 +13,16 @@ # _TOOL_CLOSED_PATS: closed pairs only. _TOOL_ALL_PATS: also trailing # unclosed runs so truncated tails don't leak markup. +# Function-name char set tracks OpenAI's ^[a-zA-Z0-9_-]{1,64}$ so MCP +# tool names that contain a hyphen (e.g. mcp__srv__list-issues) parse +# the same as the built-in web_search/python/terminal names. _TOOL_CLOSED_PATS = [ re.compile(r".*?", re.DOTALL), - re.compile(r".*?", re.DOTALL), + re.compile(r".*?", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), - re.compile(r".*$", re.DOTALL), + re.compile(r".*$", re.DOTALL), ] @@ -60,10 +63,12 @@ # Pre-compiled patterns reused by ``parse_tool_calls_from_text``. _TC_JSON_START_RE = re.compile(r"\s*\{") -_TC_FUNC_START_RE = re.compile(r"\s*") +_TC_FUNC_START_RE = re.compile(r"\s*") _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") -_TC_PARAM_START_RE = re.compile(r"\s*") +# Parameter names can carry hyphens too (e.g. MCP tool schemas with +# `issue-number`, `repo-name`); using `\w+` here dropped those keys. +_TC_PARAM_START_RE = re.compile(r"\s*") _TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") diff --git a/studio/backend/core/inference/tools.py b/studio/backend/core/inference/tools.py index 0e9cce7c3e..1b854eacf1 100644 --- a/studio/backend/core/inference/tools.py +++ b/studio/backend/core/inference/tools.py @@ -14,6 +14,7 @@ os.environ["UNSLOTH_IS_PRESENT"] = "1" +import asyncio import random import re import shlex @@ -24,6 +25,14 @@ import threading import urllib.request +from core.inference.mcp_client import ( + MCP_TOOL_PREFIX, + call_tool_sync, + list_tools_async, + parse_server_headers, +) +from storage import mcp_servers_db + from loggers import get_logger logger = get_logger(__name__) @@ -505,6 +514,92 @@ def _get_workdir(session_id: str | None = None) -> str: ALL_TOOLS = [WEB_SEARCH_TOOL, PYTHON_TOOL, TERMINAL_TOOL] +# OpenAI's function.name regex: ^[a-zA-Z0-9_-]{1,64}$ -- enforced before +# streaming starts. MCP servers can return tool names containing '.', '/', +# spaces, etc., which the prefix scheme would forward to OpenAI verbatim +# and 400 the whole request. Validate up front and skip with a warning. +_OPENAI_FN_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$") + + +def _mcp_specs_for_server(server: dict, mcp_tools: list[dict]) -> list[dict]: + """Convert an MCP server's tool list into OpenAI function specs.""" + display = server.get("display_name") or server["id"] + specs: list[dict] = [] + seen_names: set[str] = set() + for tool in mcp_tools: + raw_name = tool.get("name") or "" + if not raw_name: + logger.warning("Skipping MCP tool on '%s': empty name.", display) + continue + name = f"{MCP_TOOL_PREFIX}{server['id']}__{raw_name}" + # OpenAI requires function.name ^[a-zA-Z0-9_-]{1,64}$; bad chars + # (., /, spaces, etc.) or oversized names would 400 the whole + # request. Skip + warn so the rest of the tools still ship. + if not _OPENAI_FN_NAME_RE.fullmatch(name): + logger.warning( + "Skipping MCP tool '%s' on '%s': composed name '%s' is not " + "valid OpenAI function.name (regex ^[a-zA-Z0-9_-]{1,64}$).", + raw_name, + display, + name, + ) + continue + # Same MCP server returning duplicate tool names would also 400 + # OpenAI ("tools[N].function.name duplicates ..."). Drop dupes. + if name in seen_names: + logger.warning( + "Skipping duplicate MCP tool '%s' on '%s'.", raw_name, display + ) + continue + seen_names.add(name) + specs.append( + { + "type": "function", + "function": { + "name": name, + "description": f"[{display}] {tool.get('description') or ''}".strip(), + "parameters": tool.get("inputSchema") + or {"type": "object", "properties": {}}, + }, + } + ) + return specs + + +async def get_enabled_mcp_tools() -> list[dict]: + servers = [s for s in mcp_servers_db.list_servers() if s.get("is_enabled")] + if not servers: + return [] + + # OAuth probes need minutes for first-connect/expired-token browser + # sign-in; non-OAuth probes fail fast. Matches routes/mcp_servers.py. + results = await asyncio.gather( + *( + list_tools_async( + url = s["url"], + headers = parse_server_headers(s), + timeout = 305.0 if s.get("use_oauth") else 8.0, + use_oauth = bool(s.get("use_oauth")), + ) + for s in servers + ), + return_exceptions = True, + ) + + specs: list[dict] = [] + for server, payload in zip(servers, results): + if isinstance(payload, BaseException): + logger.warning( + "MCP server '%s' (%s) discovery failed: %s", + server.get("display_name") or server["id"], + server.get("url"), + payload, + ) + continue + specs.extend(_mcp_specs_for_server(server, payload)) + return specs + + _TIMEOUT_UNSET = object() @@ -525,6 +620,25 @@ def execute_tool( f"execute_tool: name={name}, session_id={session_id}, timeout={timeout}" ) effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout + if name.startswith(MCP_TOOL_PREFIX): + try: + _, server_id, tool_name = name.split("__", 2) + except ValueError: + return f"Error: malformed MCP tool name '{name}'" + server = mcp_servers_db.get_server(server_id) + if not server: + return f"Error: MCP server '{server_id}' not found" + if not server.get("is_enabled"): + return f"Error: MCP server '{server_id}' is disabled" + return call_tool_sync( + url = server["url"], + headers = parse_server_headers(server), + name = tool_name, + args = arguments, + timeout = effective_timeout, + use_oauth = bool(server.get("use_oauth")), + cancel_event = cancel_event, + ) if name == "web_search": return _web_search( arguments.get("query", ""), diff --git a/studio/backend/core/tool_healing.py b/studio/backend/core/tool_healing.py index bb61965764..e8bd7d9ea4 100644 --- a/studio/backend/core/tool_healing.py +++ b/studio/backend/core/tool_healing.py @@ -17,22 +17,25 @@ import json import re -# Pre-compiled patterns for tool XML stripping. +# Pre-compiled patterns for tool XML stripping. Hyphen in the +# function/parameter name char-class tracks OpenAI's allowed set so +# MCP tool names with dashes (mcp__srv__list-issues) and parameter +# names with dashes (`issue-number`) parse alongside the built-ins. _TOOL_CLOSED_PATS = [ re.compile(r".*?", re.DOTALL), - re.compile(r".*?", re.DOTALL), + re.compile(r".*?", re.DOTALL), ] _TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [ re.compile(r".*$", re.DOTALL), - re.compile(r".*$", re.DOTALL), + re.compile(r".*$", re.DOTALL), ] # Pre-compiled patterns for tool-call XML parsing. _TC_JSON_START_RE = re.compile(r"\s*\{") -_TC_FUNC_START_RE = re.compile(r"\s*") +_TC_FUNC_START_RE = re.compile(r"\s*") _TC_END_TAG_RE = re.compile(r"") _TC_FUNC_CLOSE_RE = re.compile(r"\s*\s*$") -_TC_PARAM_START_RE = re.compile(r"\s*") +_TC_PARAM_START_RE = re.compile(r"\s*") _TC_PARAM_CLOSE_RE = re.compile(r"\s*\s*$") diff --git a/studio/backend/main.py b/studio/backend/main.py index 689241b915..c220c1ba3c 100644 --- a/studio/backend/main.py +++ b/studio/backend/main.py @@ -122,6 +122,7 @@ def _studio_root_id() -> str: export_router, inference_router, inference_studio_router, + mcp_servers_router, models_router, providers_router, training_history_router, @@ -524,6 +525,7 @@ async def _recipes_redirect(rest: str = ""): # standard /v1/chat/completions path. app.include_router(inference_router, prefix = "/v1", tags = ["openai-compat"]) app.include_router(providers_router, prefix = "/api/providers", tags = ["providers"]) +app.include_router(mcp_servers_router, prefix = "/api/mcp/servers", tags = ["mcp"]) app.include_router(datasets_router, prefix = "/api/datasets", tags = ["datasets"]) app.include_router(data_recipe_router, prefix = "/api/data-recipe", tags = ["data-recipe"]) app.include_router(export_router, prefix = "/api/export", tags = ["export"]) diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 0af9425fdc..a2af9af21e 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -708,6 +708,10 @@ class ChatCompletionRequest(BaseModel): "all local tools are enabled and no server-side tools are forwarded." ), ) + mcp_enabled: Optional[bool] = Field( + None, + description = "[x-unsloth] When true, append tools from every enabled MCP server to this request's tool list.", + ) auto_heal_tool_calls: Optional[bool] = Field( True, description = "[x-unsloth] Auto-detect and fix malformed tool calls from model output.", diff --git a/studio/backend/models/mcp_servers.py b/studio/backend/models/mcp_servers.py new file mode 100644 index 0000000000..c696eb0faa --- /dev/null +++ b/studio/backend/models/mcp_servers.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +from typing import Optional + +from pydantic import BaseModel, Field + + +class McpServerCreate(BaseModel): + display_name: str + url: str + headers: Optional[dict[str, str]] = None + is_enabled: bool = True + use_oauth: bool = False + + +class McpServerUpdate(BaseModel): + display_name: Optional[str] = None + url: Optional[str] = None + # Absent in request body = leave as-is; null = drop all headers; dict = set. + headers: Optional[dict[str, str]] = None + is_enabled: Optional[bool] = None + use_oauth: Optional[bool] = None + + +class McpServerResponse(BaseModel): + id: str + display_name: str + url: str + headers: dict[str, str] = Field(default_factory = dict) + is_enabled: bool = True + use_oauth: bool = False + created_at: str + updated_at: str + + +class McpServerTestRequest(BaseModel): + url: str + headers: Optional[dict[str, str]] = None + use_oauth: bool = False + + +class McpServerProbeResult(BaseModel): + ok: bool + tool_count: int = 0 + error: Optional[str] = None diff --git a/studio/backend/requirements/extras.txt b/studio/backend/requirements/extras.txt index d783975a4f..daa8982ea5 100644 --- a/studio/backend/requirements/extras.txt +++ b/studio/backend/requirements/extras.txt @@ -52,6 +52,5 @@ addict easydict einops tabulate -fastmcp>=3.0.2 openai>=2.7.2 websockets>=15.0.1 diff --git a/studio/backend/requirements/studio.txt b/studio/backend/requirements/studio.txt index 96f8816b57..d6eba73245 100644 --- a/studio/backend/requirements/studio.txt +++ b/studio/backend/requirements/studio.txt @@ -18,3 +18,4 @@ diceware ddgs cryptography>=42.0.0 httpx>=0.27.0 +fastmcp>=3.0.2 diff --git a/studio/backend/routes/__init__.py b/studio/backend/routes/__init__.py index 6bb5d15e8e..ee3ab61b6e 100644 --- a/studio/backend/routes/__init__.py +++ b/studio/backend/routes/__init__.py @@ -16,6 +16,7 @@ from routes.training_history import router as training_history_router from routes.chat_history import router as chat_history_router from routes.providers import router as providers_router +from routes.mcp_servers import router as mcp_servers_router __all__ = [ "training_router", @@ -29,4 +30,5 @@ "training_history_router", "chat_history_router", "providers_router", + "mcp_servers_router", ] diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index a156f2397c..952e786553 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -435,7 +435,9 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: # 4. tail-only `` (outer close truncated by EOS); anchored to # `\Z` so mid-text `` in user code samples survives. _TOOL_XML_RE = _re.compile( - r"<(?:tool_call|function=\w+)>.*?(?:|\Z)" + # Hyphen in the name char-class matches MCP tool names with dashes + # (mcp__srv__list-issues) which would otherwise leak past this strip. + r"<(?:tool_call|function=[\w-]+)>.*?(?:|\Z)" r"|" r"|\s*\Z", _re.DOTALL, @@ -2438,17 +2440,29 @@ async def audio_input_stream(): # ── Tool-calling path (agentic loop) ────────────────── # `_effective_enable_tools` lets `unsloth run --enable-tools/--disable-tools` # hard-override the per-request value. Without a CLI override, falls - # back to `payload.enable_tools` (existing behavior). + # back to `payload.enable_tools` (existing behavior). `mcp_enabled=true` + # also opens the tool loop so MCP-only callers do not have to flip a + # second flag, BUT must still honor a CLI `--disable-tools` policy -- + # checking the raw policy here keeps `mcp_enabled` from re-enabling + # tools that the operator explicitly forbade. + from state.tool_policy import get_tool_policy as _get_tool_policy_g + + _cli_policy = _get_tool_policy_g() + _tools_on = _effective_enable_tools(payload) + _mcp_allowed = bool(payload.mcp_enabled) and _cli_policy is not False use_tools = ( - _effective_enable_tools(payload) + (_tools_on or _mcp_allowed) and llama_backend.supports_tools and not has_gguf_image ) if use_tools: - from core.inference.tools import ALL_TOOLS + from core.inference.tools import ALL_TOOLS, get_enabled_mcp_tools - if payload.enabled_tools is not None: + if not _tools_on: + # MCP-only request: skip built-ins, leave room for MCP tools. + tools_to_use = [] + elif payload.enabled_tools is not None: tools_to_use = [ t for t in ALL_TOOLS @@ -2457,6 +2471,19 @@ async def audio_input_stream(): else: tools_to_use = ALL_TOOLS + if _mcp_allowed: + tools_to_use = tools_to_use + await get_enabled_mcp_tools() + + # Skip the tool loop when no tool actually survived, so the + # safetensors loop's "empty = allow all" semantic cannot reach + # built-in tools the caller did not opt into. Existing callers + # who omit enabled_tools still get ALL_TOOLS here, so this + # only suppresses the loop when discovery + opt-in left it + # genuinely empty. + if not tools_to_use: + use_tools = False + + if use_tools: # ── Tool-use system prompt nudge ────────────────────── _tool_names = {t["function"]["name"] for t in tools_to_use} _has_web = "web_search" in _tool_names @@ -2932,8 +2959,15 @@ async def gguf_stream_chunks(): else 25 ) + # Match the GGUF path: mcp_enabled also opens the tool loop on its own + # but must still honor a CLI `--disable-tools` policy. + from state.tool_policy import get_tool_policy as _get_tool_policy_sf + + _sf_cli_policy = _get_tool_policy_sf() + _sf_tools_on = _effective_enable_tools(payload) + _sf_mcp_allowed = bool(payload.mcp_enabled) and _sf_cli_policy is not False _sf_use_tools = ( - _effective_enable_tools(payload) + (_sf_tools_on or _sf_mcp_allowed) and _sf_features.get("supports_tools", False) and image is None and not _sf_is_gptoss @@ -2941,15 +2975,27 @@ async def gguf_stream_chunks(): ) if _sf_use_tools: - from core.inference.tools import ALL_TOOLS + from core.inference.tools import ALL_TOOLS, get_enabled_mcp_tools - if payload.enabled_tools is not None: + if not _sf_tools_on: + _sf_tools_to_use = [] + elif payload.enabled_tools is not None: _sf_tools_to_use = [ t for t in ALL_TOOLS if t["function"]["name"] in payload.enabled_tools ] else: _sf_tools_to_use = ALL_TOOLS + if _sf_mcp_allowed: + _sf_tools_to_use = _sf_tools_to_use + await get_enabled_mcp_tools() + + # Mirror the GGUF path: refuse to enter the tool loop when nothing + # survived, so a model-emitted built-in call cannot piggy-back on + # the empty allow-list. + if not _sf_tools_to_use: + _sf_use_tools = False + + if _sf_use_tools: _sf_tool_names = {t["function"]["name"] for t in _sf_tools_to_use} _sf_has_web = "web_search" in _sf_tool_names _sf_has_code = "python" in _sf_tool_names or "terminal" in _sf_tool_names diff --git a/studio/backend/routes/mcp_servers.py b/studio/backend/routes/mcp_servers.py new file mode 100644 index 0000000000..a7501d1691 --- /dev/null +++ b/studio/backend/routes/mcp_servers.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import json +import uuid +from urllib.parse import urlparse + +import structlog +from fastapi import APIRouter, Depends, HTTPException + +from auth.authentication import get_current_subject +from core.inference.mcp_client import ( + clear_oauth_tokens_async, + list_tools_async, + parse_server_headers, +) +from models.mcp_servers import ( + McpServerCreate, + McpServerProbeResult, + McpServerResponse, + McpServerTestRequest, + McpServerUpdate, +) +from storage import mcp_servers_db + +logger = structlog.get_logger(__name__) + +router = APIRouter() + + +_PROBE_TIMEOUT_SECONDS = 8.0 +# When OAuth probes need to open a browser, wait long enough for the user to +# sign in. Matches fastmcp's default OAuth callback_timeout (300 s) + slack. +_OAUTH_PROBE_TIMEOUT_SECONDS = 305.0 + + +def _validate_url(url: str) -> str: + trimmed = (url or "").strip() + if not trimmed: + raise HTTPException(status_code = 400, detail = "url must not be empty") + parsed = urlparse(trimmed) + if parsed.scheme not in ("http", "https"): + raise HTTPException( + status_code = 400, + detail = "url must start with http:// or https://", + ) + if not parsed.netloc: + raise HTTPException(status_code = 400, detail = "url is missing a host") + return trimmed + + +def _normalize_headers(headers: dict[str, str] | None) -> dict[str, str] | None: + """Trim header names, drop empties, coerce values to str. None if nothing left.""" + if not headers: + return None + out: dict[str, str] = {} + for raw_key, value in headers.items(): + key = str(raw_key).strip() + if key: + out[key] = str(value) + return out or None + + +def _row_to_response(row: dict) -> McpServerResponse: + return McpServerResponse( + id = row["id"], + display_name = row["display_name"], + url = row["url"], + headers = parse_server_headers(row) or {}, + is_enabled = bool(row["is_enabled"]), + use_oauth = bool(row.get("use_oauth")), + created_at = row["created_at"], + updated_at = row["updated_at"], + ) + + +@router.get("/", response_model = list[McpServerResponse]) +async def list_mcp_servers( + current_subject: str = Depends(get_current_subject), +): + return [_row_to_response(row) for row in mcp_servers_db.list_servers()] + + +@router.post("/", response_model = McpServerResponse, status_code = 201) +async def create_mcp_server( + payload: McpServerCreate, + current_subject: str = Depends(get_current_subject), +): + display_name = (payload.display_name or "").strip() + if not display_name: + raise HTTPException(status_code = 400, detail = "display_name must not be empty") + url = _validate_url(payload.url) + headers = _normalize_headers(payload.headers) + + server_id = uuid.uuid4().hex[:16] + mcp_servers_db.create_server( + id = server_id, + display_name = display_name, + url = url, + headers_json = json.dumps(headers) if headers else None, + is_enabled = payload.is_enabled, + use_oauth = payload.use_oauth, + ) + return _row_to_response(mcp_servers_db.get_server(server_id)) + + +def _changes_from_payload(payload: McpServerUpdate) -> dict: + sent = payload.model_fields_set + changes: dict = {} + + if "display_name" in sent: + name = (payload.display_name or "").strip() + if not name: + raise HTTPException( + status_code = 400, detail = "display_name must not be empty" + ) + changes["display_name"] = name + if "url" in sent: + changes["url"] = _validate_url(payload.url or "") + if "headers" in sent: + headers = _normalize_headers(payload.headers) + changes["headers_json"] = json.dumps(headers) if headers else None + if "is_enabled" in sent: + if payload.is_enabled is None: + raise HTTPException( + status_code = 400, detail = "is_enabled must be true or false" + ) + changes["is_enabled"] = payload.is_enabled + if "use_oauth" in sent: + if payload.use_oauth is None: + raise HTTPException( + status_code = 400, detail = "use_oauth must be true or false" + ) + changes["use_oauth"] = payload.use_oauth + return changes + + +@router.put("/{server_id}", response_model = McpServerResponse) +async def update_mcp_server( + server_id: str, + payload: McpServerUpdate, + current_subject: str = Depends(get_current_subject), +): + old = mcp_servers_db.get_server(server_id) + if not old: + raise HTTPException(status_code = 404, detail = "MCP server not found") + changes = _changes_from_payload(payload) + if not changes: + raise HTTPException(status_code = 400, detail = "No fields to update") + # Clear persisted OAuth tokens when the URL changes or OAuth is + # disabled; fastmcp keys tokens by URL and would otherwise let a + # re-pointed server silently inherit the old account's credentials. + if bool(old.get("use_oauth")) and ( + ("url" in changes and changes["url"] != old["url"]) + or changes.get("use_oauth") is False + ): + await clear_oauth_tokens_async(old["url"]) + mcp_servers_db.update_server(server_id, changes) + return _row_to_response(mcp_servers_db.get_server(server_id)) + + +@router.delete("/{server_id}", status_code = 204) +async def delete_mcp_server( + server_id: str, + current_subject: str = Depends(get_current_subject), +): + old = mcp_servers_db.get_server(server_id) + if not old: + raise HTTPException(status_code = 404, detail = "MCP server not found") + if old.get("use_oauth"): + await clear_oauth_tokens_async(old["url"]) + mcp_servers_db.delete_server(server_id) + + +@router.post("/{server_id}/refresh", response_model = McpServerProbeResult) +async def refresh_mcp_server_tools( + server_id: str, + current_subject: str = Depends(get_current_subject), +): + server = mcp_servers_db.get_server(server_id) + if not server: + raise HTTPException(status_code = 404, detail = "MCP server not found") + + use_oauth = bool(server.get("use_oauth")) + try: + tools = await list_tools_async( + url = server["url"], + headers = parse_server_headers(server), + timeout = _OAUTH_PROBE_TIMEOUT_SECONDS + if use_oauth + else _PROBE_TIMEOUT_SECONDS, + use_oauth = use_oauth, + ) + except Exception as exc: # noqa: BLE001 — surface transport+timeout errors to UI + logger.warning("MCP refresh failed", server_id = server_id, error = str(exc)) + return McpServerProbeResult(ok = False, error = str(exc)) + + return McpServerProbeResult(ok = True, tool_count = len(tools)) + + +@router.post("/test", response_model = McpServerProbeResult) +async def test_mcp_server( + payload: McpServerTestRequest, + current_subject: str = Depends(get_current_subject), +): + # URL/header validation must surface as 400 like create/update so the + # frontend's create-form pre-flight gets the same error semantics as + # the actual save call. Only catch transport/timeout errors below. + url = _validate_url(payload.url) + headers = _normalize_headers(payload.headers) + try: + tools = await list_tools_async( + url = url, + headers = headers, + timeout = _OAUTH_PROBE_TIMEOUT_SECONDS + if payload.use_oauth + else _PROBE_TIMEOUT_SECONDS, + use_oauth = payload.use_oauth, + ) + except Exception as exc: # noqa: BLE001 + return McpServerProbeResult(ok = False, error = str(exc)) + + return McpServerProbeResult(ok = True, tool_count = len(tools)) diff --git a/studio/backend/storage/mcp_servers_db.py b/studio/backend/storage/mcp_servers_db.py new file mode 100644 index 0000000000..da2fa15423 --- /dev/null +++ b/studio/backend/storage/mcp_servers_db.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import sqlite3 +import threading +from datetime import datetime, timezone +from typing import Optional + +from utils.paths import studio_db_path, ensure_dir + +_schema_lock = threading.Lock() +_schema_ready = False + + +def _ensure_schema(conn: sqlite3.Connection) -> None: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS mcp_servers ( + id TEXT NOT NULL PRIMARY KEY, + display_name TEXT NOT NULL, + url TEXT NOT NULL, + headers_json TEXT, + is_enabled INTEGER NOT NULL DEFAULT 1, + use_oauth INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """ + ) + # use_oauth was added after the first release; backfill for pre-existing DBs. + cols = { + r["name"] for r in conn.execute("PRAGMA table_info(mcp_servers)").fetchall() + } + if "use_oauth" not in cols: + conn.execute( + "ALTER TABLE mcp_servers ADD COLUMN use_oauth INTEGER NOT NULL DEFAULT 0" + ) + + +def get_connection() -> sqlite3.Connection: + global _schema_ready + db_path = studio_db_path() + ensure_dir(db_path.parent) + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + if not _schema_ready: + with _schema_lock: + if not _schema_ready: + try: + _ensure_schema(conn) + _schema_ready = True + except Exception: + conn.close() + raise + return conn + + +def create_server( + id: str, + display_name: str, + url: str, + headers_json: Optional[str] = None, + is_enabled: bool = True, + use_oauth: bool = False, +) -> None: + now = datetime.now(timezone.utc).isoformat() + conn = get_connection() + try: + conn.execute( + """ + INSERT INTO mcp_servers + (id, display_name, url, headers_json, + is_enabled, use_oauth, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + id, + display_name, + url, + headers_json, + int(is_enabled), + int(use_oauth), + now, + now, + ), + ) + conn.commit() + finally: + conn.close() + + +def update_server(id: str, changes: dict) -> bool: + """Apply column updates and bump ``updated_at``. Returns True on a hit.""" + if not changes: + return False + bool_cols = {"is_enabled", "use_oauth"} + sets, params = [], [] + for col, value in changes.items(): + sets.append(f"{col} = ?") + params.append(int(value) if col in bool_cols else value) + sets.append("updated_at = ?") + params.extend([datetime.now(timezone.utc).isoformat(), id]) + + conn = get_connection() + try: + cursor = conn.execute( + f"UPDATE mcp_servers SET {', '.join(sets)} WHERE id = ?", + params, + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def delete_server(id: str) -> bool: + conn = get_connection() + try: + cursor = conn.execute("DELETE FROM mcp_servers WHERE id = ?", (id,)) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_server(id: str) -> Optional[dict]: + conn = get_connection() + try: + row = conn.execute("SELECT * FROM mcp_servers WHERE id = ?", (id,)).fetchone() + return dict(row) if row else None + finally: + conn.close() + + +def list_servers() -> list[dict]: + conn = get_connection() + try: + rows = conn.execute("SELECT * FROM mcp_servers ORDER BY created_at").fetchall() + return [dict(row) for row in rows] + finally: + conn.close() diff --git a/studio/backend/tests/test_desktop_auth.py b/studio/backend/tests/test_desktop_auth.py index bc8788fd90..b1522dd382 100644 --- a/studio/backend/tests/test_desktop_auth.py +++ b/studio/backend/tests/test_desktop_auth.py @@ -437,6 +437,7 @@ def test_health_response_reports_desktop_capability_fields(monkeypatch): export_router = APIRouter(), inference_router = APIRouter(), inference_studio_router = APIRouter(), + mcp_servers_router = APIRouter(), models_router = APIRouter(), providers_router = APIRouter(), training_history_router = APIRouter(), diff --git a/studio/backend/tests/test_mcp_servers.py b/studio/backend/tests/test_mcp_servers.py new file mode 100644 index 0000000000..10a6eb012b --- /dev/null +++ b/studio/backend/tests/test_mcp_servers.py @@ -0,0 +1,632 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import pytest +from fastapi import HTTPException + +from storage import mcp_servers_db + + +def _reset_db(tmp_path, monkeypatch): + monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) + monkeypatch.setattr(mcp_servers_db, "_schema_ready", False) + + +# ── storage: mcp_servers_db ───────────────────────────────────────── + + +def test_create_and_get_server(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server( + id = "srv1", + display_name = "GitHub", + url = "https://example.com/mcp", + headers_json = '{"Authorization": "Bearer x"}', + is_enabled = True, + use_oauth = False, + ) + row = mcp_servers_db.get_server("srv1") + assert row["id"] == "srv1" + assert row["display_name"] == "GitHub" + assert row["url"] == "https://example.com/mcp" + assert row["headers_json"] == '{"Authorization": "Bearer x"}' + assert row["is_enabled"] == 1 + assert row["use_oauth"] == 0 + + +def test_list_servers_ordered_by_created_at(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server(id = "a", display_name = "A", url = "https://a/m") + mcp_servers_db.create_server(id = "b", display_name = "B", url = "https://b/m") + rows = mcp_servers_db.list_servers() + assert [r["id"] for r in rows] == ["a", "b"] + + +def test_update_server_coerces_bools(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server(id = "srv1", display_name = "A", url = "https://a/m") + assert mcp_servers_db.update_server( + "srv1", {"is_enabled": False, "use_oauth": True} + ) + row = mcp_servers_db.get_server("srv1") + assert row["is_enabled"] == 0 + assert row["use_oauth"] == 1 + + +def test_update_server_empty_changes_returns_false(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server(id = "srv1", display_name = "A", url = "https://a/m") + assert mcp_servers_db.update_server("srv1", {}) is False + + +def test_delete_server_roundtrip(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server(id = "srv1", display_name = "A", url = "https://a/m") + assert mcp_servers_db.delete_server("srv1") is True + assert mcp_servers_db.delete_server("srv1") is False + assert mcp_servers_db.get_server("srv1") is None + + +# ── routes/mcp_servers: pure helpers ──────────────────────────────── + + +def test_validate_url_accepts_http_and_https(): + from routes.mcp_servers import _validate_url + + assert _validate_url("http://example.com/mcp") == "http://example.com/mcp" + assert _validate_url("https://example.com/mcp") == "https://example.com/mcp" + assert _validate_url(" https://example.com/mcp ") == "https://example.com/mcp" + + +@pytest.mark.parametrize("bad", ["", " ", "ftp://x", "http://", "noscheme.com"]) +def test_validate_url_rejects_bad(bad): + from routes.mcp_servers import _validate_url + + with pytest.raises(HTTPException) as exc: + _validate_url(bad) + assert exc.value.status_code == 400 + + +def test_normalize_headers(): + from routes.mcp_servers import _normalize_headers + + assert _normalize_headers({" Auth ": "Bearer x", "": "ignored"}) == { + "Auth": "Bearer x" + } + assert _normalize_headers({"X": 42}) == {"X": "42"} + assert _normalize_headers({}) is None + assert _normalize_headers(None) is None + assert _normalize_headers({" ": "x"}) is None + + +def test_changes_from_payload_tristate_headers(): + from routes.mcp_servers import _changes_from_payload + from models.mcp_servers import McpServerUpdate + + # omitted → key absent + assert "headers_json" not in _changes_from_payload( + McpServerUpdate(display_name = "x") + ) + # null → stored as None (clear all headers) + assert _changes_from_payload(McpServerUpdate(headers = None))["headers_json"] is None + # dict → serialised JSON + assert ( + _changes_from_payload(McpServerUpdate(headers = {"a": "1"}))["headers_json"] + == '{"a": "1"}' + ) + + +# ── core/inference/tools: MCP wiring ──────────────────────────────── + + +def test_mcp_specs_skip_oversized_names(): + from core.inference.tools import _mcp_specs_for_server + + server = {"id": "s" * 30, "display_name": "S"} + tools = [ + {"name": "ok", "description": "fine"}, + {"name": "x" * 40, "description": "too long"}, + ] + specs = _mcp_specs_for_server(server, tools) + assert len(specs) == 1 + assert specs[0]["function"]["name"].endswith("__ok") + assert len(specs[0]["function"]["name"]) <= 64 + + +def test_execute_tool_malformed_mcp_name(): + from core.inference.tools import execute_tool + + out = execute_tool("mcp__no_double_underscore", {}) + assert out.startswith("Error: malformed MCP tool name") + + +def test_execute_tool_unknown_server(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + from core.inference.tools import execute_tool + + assert ( + execute_tool("mcp__missing__do_thing", {}) + == "Error: MCP server 'missing' not found" + ) + + +def test_execute_tool_disabled_server(tmp_path, monkeypatch): + _reset_db(tmp_path, monkeypatch) + mcp_servers_db.create_server( + id = "srv1", + display_name = "A", + url = "https://a/m", + is_enabled = False, + ) + from core.inference.tools import execute_tool + + assert ( + execute_tool("mcp__srv1__do_thing", {}) + == "Error: MCP server 'srv1' is disabled" + ) + + +def test_mcp_specs_skip_invalid_openai_function_names(): + """OpenAI requires function.name ^[a-zA-Z0-9_-]{1,64}$; tools whose + names contain '.', '/', spaces, etc. would 400 the whole request.""" + from core.inference.tools import _mcp_specs_for_server + + server = {"id": "srv", "display_name": "S"} + tools = [ + {"name": "ok"}, + {"name": "with.dot"}, + {"name": "weird/slash"}, + {"name": "has space"}, + {"name": "good-dash_ok"}, + ] + specs = _mcp_specs_for_server(server, tools) + names = {s["function"]["name"] for s in specs} + assert {"mcp__srv__ok", "mcp__srv__good-dash_ok"} == names + + +def test_mcp_specs_skip_empty_tool_name(): + from core.inference.tools import _mcp_specs_for_server + + server = {"id": "srv", "display_name": "S"} + specs = _mcp_specs_for_server(server, [{"name": "", "description": "x"}]) + assert specs == [] + + +def test_mcp_specs_drops_duplicate_names(): + """Same tool name twice from one MCP server -> OpenAI rejects the + request as 'duplicates'. Drop the duplicate before forwarding.""" + from core.inference.tools import _mcp_specs_for_server + + server = {"id": "srv", "display_name": "S"} + tools = [{"name": "echo"}, {"name": "echo"}] + specs = _mcp_specs_for_server(server, tools) + assert len(specs) == 1 + + +def test_call_tool_sync_respects_pre_set_cancel_event(monkeypatch): + """cancel_event already set before the call -> immediate Error: cancelled + without making a network round-trip.""" + import threading + from core.inference import mcp_client + + # Stub _client so the test doesn't need a real MCP server. + class _StubClient: + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + + async def call_tool(self, name, args): + import asyncio as _asyncio + + await _asyncio.sleep(30) # never finishes within the test + + monkeypatch.setattr(mcp_client, "_client", lambda *a, **kw: _StubClient()) + + cancel = threading.Event() + cancel.set() + out = mcp_client.call_tool_sync( + url = "https://example/mcp", + headers = None, + name = "slow", + args = {}, + timeout = 30.0, + cancel_event = cancel, + ) + assert "cancelled" in out.lower() + + +def test_clear_oauth_tokens_async_no_op_safe(tmp_path, monkeypatch): + """clear_oauth_tokens_async on a URL with no stored token must not raise -- + the delete + update handlers call it best-effort regardless of prior state.""" + import asyncio + + monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) + from core.inference import mcp_client + + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + asyncio.run(mcp_client.clear_oauth_tokens_async("https://example.com/mcp")) + + +def test_delete_server_calls_oauth_cleanup_when_oauth_was_on(tmp_path, monkeypatch): + """delete_mcp_server route helper should invoke clear_oauth_tokens_async + when the deleted row had use_oauth=true.""" + import asyncio + + _reset_db(tmp_path, monkeypatch) + from core.inference import mcp_client + + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + mcp_servers_db.create_server( + id = "oauth1", + display_name = "GH", + url = "https://gh-mcp.example/mcp", + is_enabled = True, + use_oauth = True, + ) + + calls: list[str] = [] + + async def fake_clear(url): + calls.append(url) + + monkeypatch.setattr(mcp_client, "clear_oauth_tokens_async", fake_clear) + # Re-import the route's binding through the module so the patch is seen. + import routes.mcp_servers as routes_mcp + + monkeypatch.setattr(routes_mcp, "clear_oauth_tokens_async", fake_clear) + asyncio.run(routes_mcp.delete_mcp_server("oauth1", current_subject = "u")) + assert calls == ["https://gh-mcp.example/mcp"] + assert mcp_servers_db.get_server("oauth1") is None + + +def test_delete_server_skips_oauth_cleanup_when_oauth_off(tmp_path, monkeypatch): + """No OAuth token cleanup when the deleted server never had OAuth.""" + import asyncio + + _reset_db(tmp_path, monkeypatch) + from core.inference import mcp_client + import routes.mcp_servers as routes_mcp + + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + mcp_servers_db.create_server( + id = "noauth", + display_name = "Plain", + url = "https://plain/mcp", + is_enabled = True, + use_oauth = False, + ) + calls: list[str] = [] + + async def fake_clear(url): + calls.append(url) + + monkeypatch.setattr(routes_mcp, "clear_oauth_tokens_async", fake_clear) + asyncio.run(routes_mcp.delete_mcp_server("noauth", current_subject = "u")) + assert calls == [] + + +def test_update_server_clears_oauth_on_url_change(tmp_path, monkeypatch): + """Changing the URL on an OAuth server must drop the old URL's tokens + so the new URL doesn't silently inherit credentials.""" + import asyncio + + _reset_db(tmp_path, monkeypatch) + from core.inference import mcp_client + from models.mcp_servers import McpServerUpdate + import routes.mcp_servers as routes_mcp + + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + mcp_servers_db.create_server( + id = "s1", + display_name = "A", + url = "https://old/mcp", + is_enabled = True, + use_oauth = True, + ) + calls: list[str] = [] + + async def fake_clear(url): + calls.append(url) + + monkeypatch.setattr(routes_mcp, "clear_oauth_tokens_async", fake_clear) + asyncio.run( + routes_mcp.update_mcp_server( + "s1", + McpServerUpdate(url = "https://new/mcp"), + current_subject = "u", + ) + ) + assert calls == ["https://old/mcp"] + row = mcp_servers_db.get_server("s1") + assert row["url"] == "https://new/mcp" + + +def test_update_server_clears_oauth_when_oauth_disabled(tmp_path, monkeypatch): + """Flipping use_oauth false must drop the old URL's tokens.""" + import asyncio + + _reset_db(tmp_path, monkeypatch) + from core.inference import mcp_client + from models.mcp_servers import McpServerUpdate + import routes.mcp_servers as routes_mcp + + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + mcp_servers_db.create_server( + id = "s1", + display_name = "A", + url = "https://u/mcp", + is_enabled = True, + use_oauth = True, + ) + calls: list[str] = [] + + async def fake_clear(url): + calls.append(url) + + monkeypatch.setattr(routes_mcp, "clear_oauth_tokens_async", fake_clear) + asyncio.run( + routes_mcp.update_mcp_server( + "s1", + McpServerUpdate(use_oauth = False), + current_subject = "u", + ) + ) + assert calls == ["https://u/mcp"] + + +def test_changes_from_payload_rejects_null_is_enabled(): + """Explicit null for is_enabled used to hit int(None) -> TypeError 500.""" + from routes.mcp_servers import _changes_from_payload + from models.mcp_servers import McpServerUpdate + + with pytest.raises(HTTPException) as exc: + _changes_from_payload(McpServerUpdate(is_enabled = None)) + assert exc.value.status_code == 400 + + +def test_changes_from_payload_rejects_null_use_oauth(): + """Explicit null for use_oauth used to hit int(None) -> TypeError 500.""" + from routes.mcp_servers import _changes_from_payload + from models.mcp_servers import McpServerUpdate + + with pytest.raises(HTTPException) as exc: + _changes_from_payload(McpServerUpdate(use_oauth = None)) + assert exc.value.status_code == 400 + + +def test_test_endpoint_surfaces_url_validation_as_400(tmp_path, monkeypatch): + """POST /api/mcp/servers/test must 400 on invalid URL like create/update; + previously the same input returned 200 with {"ok": false}.""" + import asyncio + + _reset_db(tmp_path, monkeypatch) + from routes.mcp_servers import test_mcp_server + from models.mcp_servers import McpServerTestRequest + + with pytest.raises(HTTPException) as exc: + asyncio.run( + test_mcp_server( + McpServerTestRequest(url = "ftp://nope"), + current_subject = "u", + ) + ) + assert exc.value.status_code == 400 + + +def test_tool_xml_parser_handles_hyphenated_parameter_names(): + """MCP tool schemas commonly use hyphenated property names like + `issue-number` / `repo-name`; the XML parser's `` regex + dropped those keys. Verify hyphenated parameter names round-trip.""" + from core.inference.tool_call_parser import parse_tool_calls_from_text + import json as _json + + calls = parse_tool_calls_from_text( + "" + "Bug report" + "octocat/hello" + "" + ) + assert len(calls) == 1 + args = _json.loads(calls[0]["function"]["arguments"]) + assert args == {"issue-title": "Bug report", "repo-name": "octocat/hello"} + + +def test_tool_healing_strip_handles_hyphenated_function_names(): + """GGUF's core/tool_healing.py has its own copy of the XML strip + regex; the round-4 fix to the shared parser missed this file.""" + from core.tool_healing import strip_tool_call_markup + + out = strip_tool_call_markup( + "before " + "x after" + ) + assert out == "before after" + + +def test_gguf_allow_list_blocks_unadvertised_tool(monkeypatch): + """When the model emits a tool call not in the per-request tool list + the GGUF agentic loop must refuse to dispatch -- mirroring the + safetensors path. Previously execute_tool ran the call regardless.""" + from core.inference import tools as tools_mod + + captured: list[str] = [] + + def fake_execute(name, args, **kw): + captured.append(name) + return "executed" + + monkeypatch.setattr(tools_mod, "execute_tool", fake_execute) + + # Re-create the allow-list check inline so we can unit-test the + # behavior without spinning up llama-server. + def _gate(tools_advertised, called_name, args): + allowed = { + (t.get("function") or {}).get("name") + for t in (tools_advertised or []) + if (t.get("function") or {}).get("name") + } + if allowed and called_name not in allowed: + return "Error: tool '" + called_name + "' is not enabled" + return fake_execute(called_name, args) + + # Built-in not in advertised list -> blocked. + out = _gate( + [{"function": {"name": "mcp__srv__echo"}}], + "terminal", + {"command": "echo x"}, + ) + assert "not enabled" in out + assert captured == [] + # Tool in advertised list -> runs. + out = _gate( + [{"function": {"name": "mcp__srv__echo"}}], + "mcp__srv__echo", + {"text": "hi"}, + ) + assert out == "executed" + assert captured == ["mcp__srv__echo"] + + +def test_call_tool_sync_short_circuits_on_pre_set_cancel(monkeypatch): + """cancel_event set BEFORE call_tool_sync runs -> no HTTP request + is made. Previously the call task was created before the cancel + check, opening a transport that the watcher then had to cancel.""" + from core.inference import mcp_client + + opened: list[str] = [] + + class _StubClient: + async def __aenter__(self): + opened.append("opened") + return self + + async def __aexit__(self, *args): + return False + + async def call_tool(self, name, args): + return "ran" + + monkeypatch.setattr(mcp_client, "_client", lambda *a, **kw: _StubClient()) + + import threading + + ev = threading.Event() + ev.set() + out = mcp_client.call_tool_sync( + url = "https://example/mcp", + headers = None, + name = "x", + args = {}, + timeout = 5.0, + cancel_event = ev, + ) + assert "cancelled" in out.lower() + # The client must NOT have been opened. + assert opened == [] + + +def test_clear_oauth_tokens_swallows_constructor_errors(tmp_path, monkeypatch): + """clear_oauth_tokens_async is best-effort; an OAuth constructor + failure (e.g. missing fastmcp.client.auth) must not bubble out into + a 500 from the delete / update routes.""" + import asyncio + from core.inference import mcp_client + + monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) + monkeypatch.setattr(mcp_client, "_oauth_token_store", None) + + # Patch the OAuth import path to raise so the entire body fails. + class _BoomOAuth: + def __init__(self, *a, **kw): + raise RuntimeError("simulated") + + import sys as _sys + + fake_mod = type(_sys)("fastmcp.client.auth") + fake_mod.OAuth = _BoomOAuth + monkeypatch.setitem(_sys.modules, "fastmcp.client.auth", fake_mod) + # Must not raise. + asyncio.run(mcp_client.clear_oauth_tokens_async("https://x/mcp")) + + +def test_tool_xml_parser_handles_hyphenated_function_names(): + """MCP tool names are advertised as `mcp__srv__list-issues` (the regex + fix allows '-'); the XML tool-call parser must parse them too, + otherwise the model can call the tool but Studio cannot dispatch.""" + from core.inference.tool_call_parser import parse_tool_calls_from_text + + calls = parse_tool_calls_from_text( + "" + "octocat/hello" + "" + ) + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "mcp__srv__list-issues" + import json as _json + + args = _json.loads(calls[0]["function"]["arguments"]) + assert args == {"repo": "octocat/hello"} + + +def test_tool_xml_strip_handles_hyphenated_function_names(): + """routes/inference.py:_TOOL_XML_RE must strip a `` + block; otherwise hyphenated MCP tool-call XML leaks into chat history.""" + import re as _re + from pathlib import Path + + src = (Path(__file__).resolve().parent.parent / "routes/inference.py").read_text() + m = _re.search(r"_TOOL_XML_RE = _re\.compile\((.*?)\n\)", src, _re.DOTALL) + assert m, "could not extract _TOOL_XML_RE" + ns: dict = {"_re": _re} + exec(f"_TOOL_XML_RE = _re.compile({m.group(1)})", ns) + rx = ns["_TOOL_XML_RE"] + stripped = rx.sub( + "", + "before " + "x after", + ) + assert stripped == "before after" + + +def test_safetensors_agentic_empty_allowlist_still_means_allow_all(): + """Document existing contract: at the safetensors_agentic layer, + tools=[] is still treated as "no constraint" (so existing callers + work unchanged). The real fix for the MCP-only-no-discovery case + lives at the route level in inference.py, which refuses to enter + use_tools when the resolved tool list is empty.""" + import threading + from core.inference.safetensors_agentic import run_safetensors_tool_loop + + calls: list[str] = [] + + def fake_execute(name, args, **kw): + calls.append(name) + return "ran" + + iteration = {"n": 0} + + def fake_single_turn(messages): + iteration["n"] += 1 + if iteration["n"] == 1: + txt = '{"name":"python","arguments":{"code":"1"}}' + buf = "" + for ch in txt: + buf += ch + yield buf + else: + yield "done" + + list( + run_safetensors_tool_loop( + single_turn = fake_single_turn, + messages = [{"role": "user", "content": "x"}], + tools = [], + execute_tool = fake_execute, + cancel_event = threading.Event(), + max_tool_iterations = 1, + ) + ) + # Empty allow-list = run anything (preserved contract). + assert calls == [("python", {"code": "1"})] or len(calls) >= 1 diff --git a/studio/frontend/src/components/assistant-ui/tool-fallback.tsx b/studio/frontend/src/components/assistant-ui/tool-fallback.tsx index f0d6262a45..2ea87d63f3 100644 --- a/studio/frontend/src/components/assistant-ui/tool-fallback.tsx +++ b/studio/frontend/src/components/assistant-ui/tool-fallback.tsx @@ -101,6 +101,16 @@ const statusIconMap: Record = { "requires-action": AlertCircleIcon, }; +const MCP_TOOL_PREFIX = "mcp__"; + +function formatToolNameForDisplay(toolName: string): string { + if (!toolName.startsWith(MCP_TOOL_PREFIX)) return toolName; + const rest = toolName.slice(MCP_TOOL_PREFIX.length); + const sep = rest.indexOf("__"); + if (sep <= 0) return toolName; + return `${rest.slice(0, sep)} · ${rest.slice(sep + 2)}`; +} + function ToolFallbackTrigger({ toolName, status, @@ -119,6 +129,7 @@ function ToolFallbackTrigger({ const StatusIcon = statusIconMap[statusType]; const label = isCancelled ? "Cancelled tool" : "Used tool"; + const displayName = formatToolNameForDisplay(toolName); return ( {label}:{" "} - {toolName} + {displayName} {isRunning && ( {label}:{" "} - {toolName} + {displayName} )} diff --git a/studio/frontend/src/features/chat/api/chat-adapter.ts b/studio/frontend/src/features/chat/api/chat-adapter.ts index ce59429762..16f7327aeb 100644 --- a/studio/frontend/src/features/chat/api/chat-adapter.ts +++ b/studio/frontend/src/features/chat/api/chat-adapter.ts @@ -1053,6 +1053,7 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { toolsEnabled, codeToolsEnabled, imageToolsEnabled, + mcpEnabledForChat, webFetchToolsEnabled, } = runtime; const externalSelection = parseExternalModelId(params.checkpoint); @@ -1846,13 +1847,14 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { ...(supportsPreserveThinking ? { preserve_thinking: preserveThinking } : {}), - ...(supportsTools && (toolsEnabled || codeToolsEnabled) + ...(supportsTools && (toolsEnabled || codeToolsEnabled || mcpEnabledForChat) ? { enable_tools: true, enabled_tools: [ ...(toolsEnabled ? ["web_search"] : []), ...(codeToolsEnabled ? ["python", "terminal"] : []), ], + mcp_enabled: mcpEnabledForChat, auto_heal_tool_calls: useChatRuntimeStore.getState().autoHealToolCalls, max_tool_calls_per_message: diff --git a/studio/frontend/src/features/chat/api/mcp-servers-api.ts b/studio/frontend/src/features/chat/api/mcp-servers-api.ts new file mode 100644 index 0000000000..be88d12664 --- /dev/null +++ b/studio/frontend/src/features/chat/api/mcp-servers-api.ts @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import { authFetch } from "@/features/auth"; +import { formatFastApiDetail } from "@/lib/format-fastapi-error"; + +export interface McpServerConfig { + id: string; + display_name: string; + url: string; + headers: Record; + is_enabled: boolean; + use_oauth: boolean; + created_at: string; + updated_at: string; +} + +export interface McpServerProbeResult { + ok: boolean; + tool_count: number; + error: string | null; +} + +function parseErrorText(status: number, body: unknown): string { + if (body && typeof body === "object") { + const { detail, message } = body as { detail?: unknown; message?: unknown }; + const formatted = formatFastApiDetail(detail); + if (formatted) return formatted; + if (typeof message === "string" && message) return message; + } + return `Request failed (${status})`; +} + +async function mcpRequest( + path: string, + init?: { method?: string; body?: object }, +): Promise { + const response = await authFetch(`/api/mcp/servers${path}`, { + method: init?.method, + headers: init?.body ? { "Content-Type": "application/json" } : undefined, + body: init?.body ? JSON.stringify(init.body) : undefined, + }); + // 204 No Content (DELETE) has no body — calling .json() would throw. + if (response.status === 204) return undefined as T; + const json = await response.json().catch(() => null); + if (!response.ok) throw new Error(parseErrorText(response.status, json)); + return json as T; +} + +export function listMcpServers(): Promise { + return mcpRequest("/"); +} + +export function createMcpServer(payload: { + displayName: string; + url: string; + headers?: Record; + isEnabled?: boolean; + useOauth?: boolean; +}): Promise { + return mcpRequest("/", { + method: "POST", + body: { + display_name: payload.displayName, + url: payload.url, + headers: payload.headers ?? null, + is_enabled: payload.isEnabled ?? true, + use_oauth: payload.useOauth ?? false, + }, + }); +} + +export function updateMcpServer( + serverId: string, + payload: { + displayName?: string; + url?: string; + /** null = drop stored headers; omit to leave as-is */ + headers?: Record | null; + isEnabled?: boolean; + useOauth?: boolean; + }, +): Promise { + const body: Record = {}; + if (payload.displayName !== undefined) body.display_name = payload.displayName; + if (payload.url !== undefined) body.url = payload.url; + if (payload.headers !== undefined) body.headers = payload.headers; + if (payload.isEnabled !== undefined) body.is_enabled = payload.isEnabled; + if (payload.useOauth !== undefined) body.use_oauth = payload.useOauth; + return mcpRequest(`/${serverId}`, { method: "PUT", body }); +} + +export function deleteMcpServer(serverId: string): Promise { + return mcpRequest(`/${serverId}`, { method: "DELETE" }); +} + +export function refreshMcpServerTools( + serverId: string, +): Promise { + return mcpRequest(`/${serverId}/refresh`, { method: "POST" }); +} + +export function testMcpServer(payload: { + url: string; + headers?: Record; + useOauth?: boolean; +}): Promise { + return mcpRequest("/test", { + method: "POST", + body: { + url: payload.url, + headers: payload.headers ?? null, + use_oauth: payload.useOauth ?? false, + }, + }); +} diff --git a/studio/frontend/src/features/chat/chat-mcp-servers-dialog.tsx b/studio/frontend/src/features/chat/chat-mcp-servers-dialog.tsx new file mode 100644 index 0000000000..35b5aca64c --- /dev/null +++ b/studio/frontend/src/features/chat/chat-mcp-servers-dialog.tsx @@ -0,0 +1,497 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import { useCallback, useEffect, useState } from "react"; +import { toast } from "sonner"; +import { Delete02Icon, Edit03Icon, PlusSignIcon } from "@hugeicons/core-free-icons"; +import { HugeiconsIcon } from "@hugeicons/react"; +import { RefreshCwIcon } from "lucide-react"; + +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Spinner } from "@/components/ui/spinner"; +import { Switch } from "@/components/ui/switch"; +import { + type McpServerConfig, + createMcpServer, + deleteMcpServer, + listMcpServers, + refreshMcpServerTools, + testMcpServer, + updateMcpServer, +} from "./api/mcp-servers-api"; +type HeaderRow = { id: string; key: string; value: string }; + +type FormState = { + displayName: string; + url: string; + headers: HeaderRow[]; + useOauth: boolean; +}; + +const EMPTY_FORM: FormState = { + displayName: "", + url: "", + headers: [], + useOauth: false, +}; + +function newRowId(): string { + return `r_${Math.random().toString(36).slice(2, 10)}`; +} + +function headersFromObject(headers: Record): HeaderRow[] { + return Object.entries(headers).map(([k, v]) => ({ + id: newRowId(), + key: k, + value: v, + })); +} + +function headersToObject(rows: HeaderRow[]): Record | undefined { + const out: Record = {}; + for (const row of rows) { + const key = row.key.trim(); + if (!key) continue; + out[key] = row.value; + } + return Object.keys(out).length > 0 ? out : undefined; +} + +function isValidUrl(url: string): boolean { + const trimmed = url.trim(); + if (!trimmed) return false; + try { + const parsed = new URL(trimmed); + return parsed.protocol === "http:" || parsed.protocol === "https:"; + } catch { + return false; + } +} + +function HeadersEditor({ + rows, + onChange, +}: { + rows: HeaderRow[]; + onChange: (rows: HeaderRow[]) => void; +}) { + const update = (id: string, patch: Partial) => + onChange(rows.map((row) => (row.id === id ? { ...row, ...patch } : row))); + const add = () => + onChange([...rows, { id: newRowId(), key: "", value: "" }]); + const remove = (id: string) => + onChange(rows.filter((row) => row.id !== id)); + + return ( + <> +
+ + +
+ {rows.length === 0 ? ( +
+ Optional. Add an Authorization header here for servers + that require auth. +
+ ) : ( +
+ {rows.map((row) => ( +
+ update(row.id, { key: e.target.value })} + /> + update(row.id, { value: e.target.value })} + /> + +
+ ))} +
+ )} + + ); +} + +export interface ChatMcpServersDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; +} + +type View = + | { kind: "list" } + | { kind: "create" } + | { kind: "edit"; id: string }; + +export function ChatMcpServersDialog({ + open, + onOpenChange, +}: ChatMcpServersDialogProps) { + const [servers, setServers] = useState([]); + const [loading, setLoading] = useState(false); + const [view, setView] = useState({ kind: "list" }); + const [form, setForm] = useState(EMPTY_FORM); + const [saving, setSaving] = useState(false); + const [testing, setTesting] = useState(false); + const [refreshingId, setRefreshingId] = useState(null); + + const refresh = useCallback(async () => { + setLoading(true); + try { + const rows = await listMcpServers(); + setServers(rows); + } catch (err) { + toast.error("Failed to load MCP servers", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + if (!open) return; + refresh(); + }, [open, refresh]); + + function startCreate() { + setView({ kind: "create" }); + setForm(EMPTY_FORM); + } + + function startEdit(server: McpServerConfig) { + setView({ kind: "edit", id: server.id }); + setForm({ + displayName: server.display_name, + url: server.url, + headers: headersFromObject(server.headers ?? {}), + useOauth: server.use_oauth ?? false, + }); + } + + function cancelForm() { + setView({ kind: "list" }); + setForm(EMPTY_FORM); + } + + async function testConnection() { + const trimmedUrl = form.url.trim(); + if (!isValidUrl(trimmedUrl)) { + toast.error("Enter a valid http:// or https:// URL first"); + return; + } + setTesting(true); + try { + const result = await testMcpServer({ + url: trimmedUrl, + headers: headersToObject(form.headers), + useOauth: form.useOauth, + }); + if (result.ok) { + toast.success( + `Connected (${result.tool_count} tool${result.tool_count === 1 ? "" : "s"})`, + ); + } else { + toast.error("Connection failed", { + description: result.error ?? "Unknown error", + }); + } + } catch (err) { + toast.error("Connection test failed", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setTesting(false); + } + } + + async function submitForm() { + const trimmedName = form.displayName.trim(); + const trimmedUrl = form.url.trim(); + if (!trimmedName) { + toast.error("Display name is required"); + return; + } + if (!trimmedUrl) { + toast.error("URL is required"); + return; + } + if (!isValidUrl(trimmedUrl)) { + toast.error("URL must start with http:// or https://"); + return; + } + setSaving(true); + try { + const headers = headersToObject(form.headers); + if (view.kind === "edit") { + await updateMcpServer(view.id, { + displayName: trimmedName, + url: trimmedUrl, + headers: headers ?? null, + useOauth: form.useOauth, + }); + toast.success("MCP server updated"); + } else { + await createMcpServer({ + displayName: trimmedName, + url: trimmedUrl, + headers: headers, + useOauth: form.useOauth, + }); + toast.success("MCP server added"); + } + cancelForm(); + await refresh(); + } catch (err) { + toast.error("Save failed", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setSaving(false); + } + } + + async function removeServer(server: McpServerConfig) { + const ok = window.confirm(`Delete MCP server "${server.display_name}"?`); + if (!ok) return; + try { + await deleteMcpServer(server.id); + await refresh(); + } catch (err) { + toast.error("Delete failed", { + description: err instanceof Error ? err.message : String(err), + }); + } + } + + async function toggleEnabled(server: McpServerConfig, next: boolean) { + // Optimistic update so the switch doesn't snap back during the round-trip. + setServers((rows) => + rows.map((row) => + row.id === server.id ? { ...row, is_enabled: next } : row, + ), + ); + try { + await updateMcpServer(server.id, { isEnabled: next }); + } catch (err) { + setServers((rows) => + rows.map((row) => + row.id === server.id ? { ...row, is_enabled: !next } : row, + ), + ); + toast.error("Update failed", { + description: err instanceof Error ? err.message : String(err), + }); + } + } + + async function refreshTools(server: McpServerConfig) { + setRefreshingId(server.id); + try { + const result = await refreshMcpServerTools(server.id); + if (result.ok) { + toast.success( + `Refreshed "${server.display_name}" (${result.tool_count} tool${result.tool_count === 1 ? "" : "s"})`, + ); + } else { + toast.error(`Refresh failed for "${server.display_name}"`, { + description: result.error ?? "Unknown error", + }); + } + } catch (err) { + toast.error("Refresh failed", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setRefreshingId(null); + } + } + + const showForm = view.kind !== "list"; + + return ( + + + + MCP Servers + + Register remote MCP servers. + + + + {showForm ? ( +
+
+ + + setForm((prev) => ({ ...prev, displayName: e.target.value })) + } + placeholder="e.g. GitHub MCP" + /> +
+
+ + + setForm((prev) => ({ ...prev, url: e.target.value })) + } + placeholder="https://example.com/mcp" + /> +
+ +
+
+ + + For servers that require browser-based authentication + (GitHub, Linear, etc.). A browser window will open on first + connect. + +
+ + setForm((prev) => ({ ...prev, useOauth })) + } + /> +
+ + setForm((prev) => ({ ...prev, headers }))} + /> + +
+ +
+ + +
+
+
+ ) : ( +
+
+ +
+ {loading ? ( +
+ +
+ ) : servers.length === 0 ? ( +
+ No MCP servers configured yet. +
+ ) : ( +
    + {servers.map((server) => ( +
  • +
    +
    + {server.display_name} +
    +
    + {server.url} +
    +
    +
    + toggleEnabled(server, next)} + aria-label="Enable server" + /> + + + +
    +
  • + ))} +
+ )} +
+ )} +
+
+ ); +} diff --git a/studio/frontend/src/features/chat/chat-settings-sheet.tsx b/studio/frontend/src/features/chat/chat-settings-sheet.tsx index 3714fb128a..8c852b8189 100644 --- a/studio/frontend/src/features/chat/chat-settings-sheet.tsx +++ b/studio/frontend/src/features/chat/chat-settings-sheet.tsx @@ -91,6 +91,8 @@ import { providerSupportsFastMode, } from "./provider-capabilities"; import { useChatRuntimeStore } from "./stores/chat-runtime-store"; +import { ChatMcpServersDialog } from "./chat-mcp-servers-dialog"; +import { listMcpServers } from "./api/mcp-servers-api"; import type { InferenceParams } from "./types/runtime"; export { defaultInferenceParams, type Preset } from "./presets/preset-policy"; @@ -1341,6 +1343,12 @@ export function ChatSettingsPanel({ ) : null} + + {!isExternalModel ? ( + + + + ) : null} s.mcpEnabledForChat); + const setMcpEnabledForChat = useChatRuntimeStore( + (s) => s.setMcpEnabledForChat, + ); + const [enabledServerCount, setEnabledServerCount] = useState( + null, + ); + const [dialogOpen, setDialogOpen] = useState(false); + const [refreshTick, setRefreshTick] = useState(0); + + useEffect(() => { + let cancelled = false; + listMcpServers() + .then((rows) => { + if (cancelled) return; + setEnabledServerCount(rows.filter((row) => row.is_enabled).length); + }) + .catch(() => { + if (!cancelled) setEnabledServerCount(0); + }); + return () => { + cancelled = true; + }; + }, [refreshTick]); + + return ( +
+
+
+ + Use MCP Servers + + + When on, every server marked enabled in the manage dialog is + attached to this chat's tool list. + +
+ +
+
+ + {enabledServerCount === null + ? "Loading…" + : enabledServerCount === 0 + ? "No servers configured" + : `${enabledServerCount} server${enabledServerCount === 1 ? "" : "s"} enabled`} + + +
+ { + setDialogOpen(next); + if (!next) setRefreshTick((tick) => tick + 1); + }} + /> +
+ ); +} + function ChatTemplateFields() { const defaultTemplate = useChatRuntimeStore((s) => s.defaultChatTemplate); const override = useChatRuntimeStore((s) => s.chatTemplateOverride); diff --git a/studio/frontend/src/features/chat/stores/chat-runtime-store.ts b/studio/frontend/src/features/chat/stores/chat-runtime-store.ts index a71e4127a2..09c06f92d8 100644 --- a/studio/frontend/src/features/chat/stores/chat-runtime-store.ts +++ b/studio/frontend/src/features/chat/stores/chat-runtime-store.ts @@ -27,6 +27,7 @@ export const CHAT_REASONING_ENABLED_KEY = "unsloth_chat_reasoning_enabled"; export const CHAT_TOOLS_ENABLED_KEY = "unsloth_chat_tools_enabled"; export const CHAT_CODE_TOOLS_ENABLED_KEY = "unsloth_chat_code_tools_enabled"; export const CHAT_IMAGE_TOOLS_ENABLED_KEY = "unsloth_chat_image_tools_enabled"; +export const CHAT_MCP_ENABLED_KEY = "unsloth_chat_mcp_enabled"; export const CHAT_WEB_FETCH_TOOLS_ENABLED_KEY = "unsloth_chat_web_fetch_tools_enabled"; @@ -282,6 +283,7 @@ type ChatRuntimeStore = { toolsEnabled: boolean; codeToolsEnabled: boolean; imageToolsEnabled: boolean; + mcpEnabledForChat: boolean; /** * Fetch pill state, independent of `toolsEnabled` (Search). Only * consulted when `providerSupportsBuiltinWebFetch` is true. @@ -349,6 +351,7 @@ type ChatRuntimeStore = { setToolsEnabled: (enabled: boolean, options?: { persist?: boolean }) => void; setCodeToolsEnabled: (enabled: boolean) => void; setImageToolsEnabled: (enabled: boolean) => void; + setMcpEnabledForChat: (enabled: boolean) => void; setWebFetchToolsEnabled: (enabled: boolean) => void; setToolStatus: (status: string | null) => void; setGeneratingStatus: (status: string | null) => void; @@ -599,6 +602,7 @@ export const useChatRuntimeStore = create((set, get) => ({ toolsEnabled: loadBool(CHAT_TOOLS_ENABLED_KEY, false), codeToolsEnabled: loadBool(CHAT_CODE_TOOLS_ENABLED_KEY, false), imageToolsEnabled: loadBool(CHAT_IMAGE_TOOLS_ENABLED_KEY, false), + mcpEnabledForChat: loadBool(CHAT_MCP_ENABLED_KEY, false), webFetchToolsEnabled: loadBool(CHAT_WEB_FETCH_TOOLS_ENABLED_KEY, false), toolStatus: null, generatingStatus: null, @@ -875,6 +879,11 @@ export const useChatRuntimeStore = create((set, get) => ({ saveBool(CHAT_IMAGE_TOOLS_ENABLED_KEY, imageToolsEnabled); return { imageToolsEnabled }; }), + setMcpEnabledForChat: (mcpEnabledForChat) => + set(() => { + saveBool(CHAT_MCP_ENABLED_KEY, mcpEnabledForChat); + return { mcpEnabledForChat }; + }), setWebFetchToolsEnabled: (webFetchToolsEnabled) => set(() => { saveBool(CHAT_WEB_FETCH_TOOLS_ENABLED_KEY, webFetchToolsEnabled);