-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Studio: add remote MCP server support #5750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d7d6f41
9174da0
ef9d341
662dc8a
49e1042
3029469
3fb2c8d
b88f458
602d75d
d03a7e2
4ddbbce
1dc0085
63e4444
e565e10
d537a6d
3726238
2c87ded
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| # SPDX-License-Identifier: AGPL-3.0-only | ||
| # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import json | ||
| from typing import Any, Optional | ||
|
|
||
| from loggers import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
| MCP_TOOL_PREFIX = "mcp__" | ||
|
|
||
| _oauth_token_store = None | ||
|
|
||
|
|
||
def parse_server_headers(server: dict) -> Optional[dict]:
    """Decode the custom HTTP headers stored on an MCP server record.

    Reads the ``headers_json`` entry of ``server`` and parses it as JSON.

    Args:
        server: Server record (mapping) that may carry a ``headers_json`` value.

    Returns:
        The decoded mapping when the value parses to a JSON object. ``None``
        when the value is missing or empty, is not valid JSON, is not a
        str/bytes payload, or decodes to anything other than a dict.
    """
    raw = server.get("headers_json")
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (a subclass) and undecodable
        # bytes; TypeError covers a stored value that is not str/bytes at all.
        return None
    return parsed if isinstance(parsed, dict) else None
|
|
||
|
|
||
def _oauth_store():
    """Return the shared OAuth token store, constructing it on first use.

    The instance is cached in the module-level ``_oauth_token_store`` so every
    later call returns the same object. The store's dependencies are imported
    lazily, only when the store actually has to be built.
    """
    global _oauth_token_store
    if _oauth_token_store is not None:
        return _oauth_token_store

    from key_value.aio._utils.sanitization import AlwaysHashStrategy
    from key_value.aio.stores.filetree import FileTreeStore
    from utils.paths.storage_roots import ensure_dir, studio_root

    token_dir = ensure_dir(studio_root() / "mcp-oauth-tokens")
    # Keys and collections are hashed: fastmcp uses raw URLs such as
    # https://x.com as keys, and FileTreeStore would otherwise treat the
    # "://" as nested directories.
    _oauth_token_store = FileTreeStore(
        data_directory = token_dir,
        key_sanitization_strategy = AlwaysHashStrategy(),
        collection_sanitization_strategy = AlwaysHashStrategy(),
    )
    return _oauth_token_store
|
|
||
|
|
||
async def clear_oauth_tokens_async(url: str) -> None:
    """Best-effort removal of the OAuth tokens persisted for ``url``.

    fastmcp keys tokens by MCP URL, so when a server is deleted, its URL
    changes, or OAuth is switched off, the old credentials must be cleared
    explicitly; otherwise re-registering the same URL would silently reuse
    the previous account's token.

    Everything, including the import and the store / ``OAuth`` construction,
    runs inside the ``try`` so that a failure here is only logged and never
    propagates to the delete / update route that triggered the cleanup.
    """
    try:
        from fastmcp.client.auth import OAuth

        store = _oauth_store()
        oauth = OAuth(mcp_url = url, token_storage = store)
        adapter = oauth.token_storage_adapter
        await adapter.clear()
    except Exception as err:  # noqa: BLE001
        # Cleanup is best-effort; the caller's own operation still proceeds.
        logger.warning("Failed to clear OAuth tokens for %s: %s", url, err)
|
|
||
|
|
||
def _client(url: str, headers: Optional[dict], use_oauth: bool = False):
    """Build a fastmcp ``Client`` for the MCP server at ``url``.

    Args:
        url: Endpoint of the remote MCP server.
        headers: Extra HTTP headers; an empty mapping is passed on as ``None``.
        use_oauth: When true, attach an ``OAuth`` helper backed by the shared
            token store returned by ``_oauth_store()``.

    Returns:
        A ``Client`` wrapping an ``SSETransport`` when fastmcp infers ``"sse"``
        from the URL, and a ``StreamableHttpTransport`` otherwise.
    """
    from fastmcp import Client
    from fastmcp.client.transports import SSETransport, StreamableHttpTransport
    from fastmcp.mcp_config import infer_transport_type_from_url

    if use_oauth:
        from fastmcp.client.auth import OAuth

        auth = OAuth(mcp_url = url, token_storage = _oauth_store())
    else:
        auth = None

    if infer_transport_type_from_url(url) == "sse":
        transport_cls = SSETransport
    else:
        transport_cls = StreamableHttpTransport

    transport = transport_cls(url = url, headers = headers or None, auth = auth)
    return Client(transport)
|
|
||
|
|
||
async def list_tools_async(
    url: str,
    headers: Optional[dict] = None,
    timeout: float = 5.0,
    use_oauth: bool = False,
) -> list[dict]:
    """List the tools exposed by the MCP server at ``url``.

    Opens a client for the server, asks it for its tools, and returns each one
    as a plain dict (``model_dump`` with ``None`` fields omitted).

    Args:
        url: Endpoint of the remote MCP server.
        headers: Optional extra HTTP headers for the connection.
        timeout: Upper bound, in seconds, on connect + listing.
        use_oauth: Whether the client should authenticate via OAuth.

    Raises:
        asyncio.TimeoutError: If the whole operation exceeds ``timeout``.
    """

    async def _list_and_dump() -> list[dict]:
        async with _client(url, headers, use_oauth) as session:
            listed = await session.list_tools()
            dumped = []
            for tool in listed:
                dumped.append(tool.model_dump(exclude_none = True))
            return dumped

    pending = _list_and_dump()
    return await asyncio.wait_for(pending, timeout = timeout)
|
|
||
|
|
||
def _flatten_result(result: Any) -> str:
    """Collapse an MCP tool result into a single string.

    The text of every content block that has a non-empty ``text`` attribute is
    joined with newlines. If that yields nothing, ``structured_content`` (when
    present) is used instead, serialized as JSON so the consumer sees
    ``{"ok": true}`` rather than the Python repr ``{'ok': True}``. Values that
    cannot be JSON-encoded fall back to ``str()``.

    Args:
        result: Object with optional ``content``, ``structured_content`` and
            ``is_error`` attributes (missing attributes are tolerated).

    Returns:
        The flattened body. When ``is_error`` is truthy the body is prefixed
        with ``"Error: "``, or replaced by a fixed message if it is empty.
    """
    parts = [
        str(text)
        for block in getattr(result, "content", None) or []
        if (text := getattr(block, "text", None))
    ]
    body = "\n".join(parts)

    if not body:
        structured = getattr(result, "structured_content", None)
        if structured is not None:
            try:
                body = json.dumps(structured, ensure_ascii = False)
            except (TypeError, ValueError):
                # Not JSON-encodable (custom objects, circular refs, ...).
                body = str(structured)

    if getattr(result, "is_error", False):
        # "Error: " prefix triggers tool_call_parser's TOOL_ERROR_PREFIXES nudge.
        return f"Error: {body}" if body else "Error: tool returned no content"
    return body
|
|
||
|
|
||
def call_tool_sync(
    url: str,
    headers: Optional[dict],
    name: str,
    args: dict,
    timeout: Optional[float] = 300.0,
    use_oauth: bool = False,
    cancel_event = None,
) -> str:
    """Synchronously call an MCP tool and return its flattened text result.

    Runs the call on a fresh event loop via ``asyncio.run``, so it must be
    invoked from a thread that has no running loop.

    Args:
        url: Endpoint of the remote MCP server.
        headers: Optional extra HTTP headers for the connection.
        name: Tool name as known to the MCP server.
        args: Arguments passed to the tool.
        timeout: Overall deadline in seconds; ``None`` disables the deadline.
        use_oauth: Whether the client should authenticate via OAuth.
        cancel_event: Optional ``threading.Event``. When set, the in-flight
            HTTP call is cancelled and the function returns a cancellation
            Error. Polled in parallel with the tool call via ``asyncio.wait``
            so a /cancel POST from the UI interrupts even mid-network-read.

    Returns:
        The flattened tool output, or an ``"Error: ..."`` string on
        cancellation, timeout, or any other failure; ordinary exceptions are
        reported in the return value rather than raised.
    """

    async def _call() -> Any:
        async with _client(url, headers, use_oauth) as client:
            return await client.call_tool(name, args)

    async def _watch_cancel() -> None:
        # 50 ms cadence keeps cancellation responsive without busy-looping;
        # matches the cadence routes/inference.py uses for cancel watchers.
        while cancel_event is not None and not cancel_event.is_set():
            await asyncio.sleep(0.05)

    async def _race() -> Any:
        # Check cancellation before spawning the call task so a pre-set
        # event short-circuits before opening the transport / HTTP
        # connection.
        if cancel_event is not None and cancel_event.is_set():
            raise _MCPCancelled
        call_task = asyncio.create_task(_call())
        if cancel_event is None:
            return await asyncio.wait_for(call_task, timeout = timeout)
        watch_task = asyncio.create_task(_watch_cancel())
        try:
            done, _pending = await asyncio.wait(
                {call_task, watch_task},
                timeout = timeout,
                return_when = asyncio.FIRST_COMPLETED,
            )
        finally:
            # Whichever task lost the race (or both, on timeout) is cancelled;
            # asyncio.run() drains cancelled tasks while shutting the loop down.
            for task in (call_task, watch_task):
                if not task.done():
                    task.cancel()
        if not done:
            raise asyncio.TimeoutError
        if call_task in done:
            return call_task.result()
        raise _MCPCancelled

    try:
        result = asyncio.run(_race())
    except _MCPCancelled:
        return f"Error: MCP tool '{name}' cancelled"
    except asyncio.TimeoutError:
        # A TimeoutError can also come from inside the tool call when no outer
        # deadline was set; formatting None with ":g" would raise TypeError
        # from within this handler, so only mention the deadline if one exists.
        if timeout is None:
            return f"Error: MCP tool '{name}' timed out"
        return f"Error: MCP tool '{name}' timed out after {timeout:g}s"
    except Exception as exc:
        logger.exception("MCP call_tool failed for %s: %s", name, exc)
        return f"Error: MCP tool '{name}' failed: {exc}"

    return _flatten_result(result)


class _MCPCancelled(Exception):
    """Internal sentinel raised when cancel_event fires before the tool returns."""
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
|
|
||
| os.environ["UNSLOTH_IS_PRESENT"] = "1" | ||
|
|
||
| import asyncio | ||
| import random | ||
| import re | ||
| import shlex | ||
|
|
@@ -24,6 +25,14 @@ | |
| import threading | ||
| import urllib.request | ||
|
|
||
| from core.inference.mcp_client import ( | ||
| MCP_TOOL_PREFIX, | ||
| call_tool_sync, | ||
| list_tools_async, | ||
| parse_server_headers, | ||
| ) | ||
| from storage import mcp_servers_db | ||
|
|
||
| from loggers import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
@@ -505,6 +514,92 @@ def _get_workdir(session_id: str | None = None) -> str: | |
| ALL_TOOLS = [WEB_SEARCH_TOOL, PYTHON_TOOL, TERMINAL_TOOL] | ||
|
|
||
|
|
||
| # OpenAI's function.name regex: ^[a-zA-Z0-9_-]{1,64}$ -- enforced before | ||
| # streaming starts. MCP servers can return tool names containing '.', '/', | ||
| # spaces, etc., which the prefix scheme would forward to OpenAI verbatim | ||
| # and 400 the whole request. Validate up front and skip with a warning. | ||
| _OPENAI_FN_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]{1,64}$") | ||
|
|
||
|
|
||
def _mcp_specs_for_server(server: dict, mcp_tools: list[dict]) -> list[dict]:
    """Convert an MCP server's tool list into OpenAI function specs.

    Each tool is exposed as ``{MCP_TOOL_PREFIX}{server id}__{tool name}``.
    Tools are skipped, with a warning, when their name is empty, when the
    composed name does not satisfy ``_OPENAI_FN_NAME_RE``, or when the same
    composed name was already emitted for this server.

    Args:
        server: Server record; must contain ``id`` and may contain
            ``display_name``.
        mcp_tools: Tool dicts with ``name`` and optional ``description`` /
            ``inputSchema`` entries.

    Returns:
        A list of ``{"type": "function", "function": {...}}`` specs. Empty
        when the server id itself contains ``"__"``.
    """
    server_id = server["id"]
    display = server.get("display_name") or server_id

    # execute_tool recovers the server id with name.split("__", 2), so an id
    # that itself contains "__" can never be mapped back to this server.
    # Advertising such tools would only produce calls that cannot be routed.
    if "__" in str(server_id):
        logger.warning(
            "Skipping all MCP tools on '%s': server id '%s' contains '__' "
            "and cannot be encoded in a tool name.",
            display,
            server_id,
        )
        return []

    specs: list[dict] = []
    seen_names: set[str] = set()
    for tool in mcp_tools:
        raw_name = tool.get("name") or ""
        if not raw_name:
            logger.warning("Skipping MCP tool on '%s': empty name.", display)
            continue
        name = f"{MCP_TOOL_PREFIX}{server_id}__{raw_name}"
        # OpenAI requires function.name ^[a-zA-Z0-9_-]{1,64}$; bad chars
        # (., /, spaces, etc.) or oversized names would 400 the whole
        # request. Skip + warn so the rest of the tools still ship.
        if not _OPENAI_FN_NAME_RE.fullmatch(name):
            logger.warning(
                "Skipping MCP tool '%s' on '%s': composed name '%s' is not "
                "valid OpenAI function.name (regex ^[a-zA-Z0-9_-]{1,64}$).",
                raw_name,
                display,
                name,
            )
            continue
        # Same MCP server returning duplicate tool names would also 400
        # OpenAI ("tools[N].function.name duplicates ..."). Drop dupes.
        if name in seen_names:
            logger.warning(
                "Skipping duplicate MCP tool '%s' on '%s'.", raw_name, display
            )
            continue
        seen_names.add(name)
        specs.append(
            {
                "type": "function",
                "function": {
                    "name": name,
                    "description": f"[{display}] {tool.get('description') or ''}".strip(),
                    # Tools without a schema are declared as taking no arguments.
                    "parameters": tool.get("inputSchema")
                    or {"type": "object", "properties": {}},
                },
            }
        )
    return specs
|
|
||
|
|
||
async def get_enabled_mcp_tools() -> list[dict]:
    """Collect OpenAI function specs for every enabled MCP server.

    All enabled servers are probed concurrently. A server whose discovery
    fails is logged and skipped, so one bad server does not hide the tools of
    the others.

    Returns:
        The concatenated specs of all servers that answered, in server order;
        an empty list when no server is enabled.
    """
    enabled = [row for row in mcp_servers_db.list_servers() if row.get("is_enabled")]
    if not enabled:
        return []

    # OAuth probes need minutes for first-connect/expired-token browser
    # sign-in; non-OAuth probes fail fast. Matches routes/mcp_servers.py.
    probes = []
    for row in enabled:
        probes.append(
            list_tools_async(
                url = row["url"],
                headers = parse_server_headers(row),
                timeout = 305.0 if row.get("use_oauth") else 8.0,
                use_oauth = bool(row.get("use_oauth")),
            )
        )
    outcomes = await asyncio.gather(*probes, return_exceptions = True)

    specs: list[dict] = []
    for row, outcome in zip(enabled, outcomes):
        if not isinstance(outcome, BaseException):
            specs.extend(_mcp_specs_for_server(row, outcome))
            continue
        logger.warning(
            "MCP server '%s' (%s) discovery failed: %s",
            row.get("display_name") or row["id"],
            row.get("url"),
            outcome,
        )
    return specs
|
|
||
|
|
||
| _TIMEOUT_UNSET = object() | ||
|
|
||
|
|
||
|
|
@@ -525,6 +620,25 @@ def execute_tool( | |
| f"execute_tool: name={name}, session_id={session_id}, timeout={timeout}" | ||
| ) | ||
| effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout | ||
| if name.startswith(MCP_TOOL_PREFIX): | ||
| try: | ||
| _, server_id, tool_name = name.split("__", 2) | ||
| except ValueError: | ||
| return f"Error: malformed MCP tool name '{name}'" | ||
| server = mcp_servers_db.get_server(server_id) | ||
| if not server: | ||
| return f"Error: MCP server '{server_id}' not found" | ||
| if not server.get("is_enabled"): | ||
| return f"Error: MCP server '{server_id}' is disabled" | ||
| return call_tool_sync( | ||
| url = server["url"], | ||
| headers = parse_server_headers(server), | ||
| name = tool_name, | ||
| args = arguments, | ||
| timeout = effective_timeout, | ||
| use_oauth = bool(server.get("use_oauth")), | ||
|
Comment on lines
+633
to
+639
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This new MCP branch ignores Useful? React with 👍 / 👎. |
||
| cancel_event = cancel_event, | ||
| ) | ||
| if name == "web_search": | ||
| return _web_search( | ||
| arguments.get("query", ""), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of fetching all servers and filtering in Python, use a dedicated database method to fetch only enabled servers. This improves efficiency, especially as the number of registered servers grows.
References