Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions studio/backend/core/inference/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import asyncio
import json
import os
import shlex
import sys
from typing import Any, Optional

from loggers import get_logger
Expand All @@ -16,7 +19,55 @@
_oauth_token_store = None


def is_stdio(address: str) -> bool:
    """Return True when *address* is a local stdio command rather than a URL.

    Any value that does not begin with ``http://`` or ``https://`` (after
    trimming surrounding whitespace and ignoring letter case) is treated as a
    command line, e.g. 'npx -y @modelcontextprotocol/server-filesystem /path'.
    """
    normalized = address.strip().lower()
    looks_like_http = normalized.startswith(("http://", "https://"))
    return not looks_like_http


def parse_stdio_command(address: str) -> list[str]:
    """Tokenize a stdio command line into an argv list.

    Route validation and the stdio transport both go through this function so
    they agree on quoting rules (notably for Windows backslash paths).
    Malformed quoting (e.g. an unterminated quote) surfaces as the ValueError
    raised by shlex.
    """
    if sys.platform != "win32":
        return shlex.split(address, posix = True)

    # Non-POSIX splitting leaves backslashes in paths alone, but it also leaves
    # the enclosing quotes attached to a quoted token. Peel off one matched
    # pair so the subprocess receives a clean argument
    # ('"C:\\Program Files\\node"' -> C:\\Program Files\\node).
    argv = []
    for token in shlex.split(address, posix = False):
        wrapped = (
            len(token) >= 2
            and token[0] in ('"', "'")
            and token[-1] == token[0]
        )
        argv.append(token[1:-1] if wrapped else token)
    return argv


def stdio_mcp_enabled() -> bool:
    """Report whether stdio MCP servers may be launched on this host.

    A stdio server is spawned as a local process under the backend user (and
    bypasses the python/terminal sandbox), so it is only meant to run when the
    backend host is the user's own machine. The gate is the environment
    variable UNSLOTH_STUDIO_ALLOW_STDIO_MCP: only the exact value "1" enables
    it. The Tauri desktop app sets it (see main.py), and advanced localhost /
    self-hosted users can opt in with the same variable.

    This function only inspects the variable; it does not itself detect Colab
    or a network (0.0.0.0) bind. Those deployments stay off by leaving the
    variable unset.
    """
    flag = os.getenv("UNSLOTH_STUDIO_ALLOW_STDIO_MCP")
    return flag == "1"


# Probe timeouts (in seconds) for discovering a server's tool list. The value
# is chosen per transport/auth mode:
#   - HTTP fails fast.
#   - OAuth needs minutes for first-connect/expired-token browser sign-in.
#   - stdio allows for first-run package download (e.g. `npx -y ...`).
_HTTP_PROBE_TIMEOUT = 8.0
_OAUTH_PROBE_TIMEOUT = 305.0
_STDIO_PROBE_TIMEOUT = 60.0


def probe_timeout(address: str, use_oauth: bool) -> float:
    """Pick the tool-list probe timeout for a server.

    OAuth takes precedence over the transport: when *use_oauth* is true the
    OAuth timeout is returned regardless of *address*. Otherwise a stdio
    command gets the stdio timeout and an HTTP(S) URL gets the HTTP timeout.
    """
    if use_oauth:
        return _OAUTH_PROBE_TIMEOUT
    if is_stdio(address):
        return _STDIO_PROBE_TIMEOUT
    return _HTTP_PROBE_TIMEOUT


def parse_server_headers(server: dict) -> Optional[dict]:
"""Parsed headers_json. For stdio servers this dict is the process
environment instead of HTTP headers (see _client)."""
raw = server.get("headers_json")
if not raw:
return None
Expand Down Expand Up @@ -63,6 +114,19 @@ async def clear_oauth_tokens_async(url: str) -> None:

def _client(url: str, headers: Optional[dict], use_oauth: bool = False):
from fastmcp import Client

if is_stdio(url):
from fastmcp.client.transports import StdioTransport

parts = parse_stdio_command(url)
if not parts:
raise ValueError(f"Empty stdio command: {url!r}")
# stdio env vars ride the (HTTP-only) headers field. The MCP SDK merges
# them over its default safe env (PATH etc.), so pass them through as-is.
return Client(
StdioTransport(command = parts[0], args = parts[1:], env = headers or None)
)
Comment thread
oobabooga marked this conversation as resolved.

from fastmcp.client.transports import SSETransport, StreamableHttpTransport
from fastmcp.mcp_config import infer_transport_type_from_url

Expand Down
13 changes: 10 additions & 3 deletions studio/backend/core/inference/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
from core.inference.mcp_client import (
MCP_TOOL_PREFIX,
call_tool_sync,
is_stdio,
list_tools_async,
parse_server_headers,
probe_timeout,
stdio_mcp_enabled,
)
from storage import mcp_servers_db

Expand Down Expand Up @@ -568,17 +571,19 @@ def _mcp_specs_for_server(server: dict, mcp_tools: list[dict]) -> list[dict]:

async def get_enabled_mcp_tools() -> list[dict]:
servers = [s for s in mcp_servers_db.list_servers() if s.get("is_enabled")]
# Never spawn stdio servers when stdio is disabled on this host (e.g. a DB
# carried over from a desktop install onto a Colab / network deployment).
if not stdio_mcp_enabled():
servers = [s for s in servers if not is_stdio(s["url"])]
if not servers:
return []

# OAuth probes need minutes for first-connect/expired-token browser
# sign-in; non-OAuth probes fail fast. Matches routes/mcp_servers.py.
results = await asyncio.gather(
*(
list_tools_async(
url = s["url"],
headers = parse_server_headers(s),
timeout = 305.0 if s.get("use_oauth") else 8.0,
timeout = probe_timeout(s["url"], bool(s.get("use_oauth"))),
use_oauth = bool(s.get("use_oauth")),
)
for s in servers
Expand Down Expand Up @@ -630,6 +635,8 @@ def execute_tool(
return f"Error: MCP server '{server_id}' not found"
if not server.get("is_enabled"):
return f"Error: MCP server '{server_id}' is disabled"
if is_stdio(server["url"]) and not stdio_mcp_enabled():
return f"Error: stdio MCP server '{server_id}' is disabled on this host"
return call_tool_sync(
url = server["url"],
headers = parse_server_headers(server),
Expand Down
5 changes: 5 additions & 0 deletions studio/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def _load_desktop_owner() -> dict[str, str] | None:

_DESKTOP_OWNER = _load_desktop_owner()

# The Tauri desktop app runs the backend on the owner's own machine, so local
# stdio MCP servers are safe there. setdefault only fills the variable in when
# it is not already set, so an explicit value such as "0" is left untouched
# and keeps stdio off.
if _DESKTOP_OWNER:
    os.environ.setdefault("UNSLOTH_STUDIO_ALLOW_STDIO_MCP", "1")


def _desktop_owner() -> dict[str, str] | None:
    """Return the module-level _DESKTOP_OWNER value (a dict or None), which is
    computed once at import time by _load_desktop_owner()."""
    return _DESKTOP_OWNER
Expand Down
36 changes: 24 additions & 12 deletions studio/backend/routes/mcp_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from auth.authentication import get_current_subject
from core.inference.mcp_client import (
clear_oauth_tokens_async,
is_stdio,
list_tools_async,
parse_server_headers,
parse_stdio_command,
probe_timeout,
stdio_mcp_enabled,
)
from models.mcp_servers import (
McpServerCreate,
Expand All @@ -28,16 +32,22 @@
router = APIRouter()


_PROBE_TIMEOUT_SECONDS = 8.0
# When OAuth probes need to open a browser, wait long enough for the user to
# sign in. Matches fastmcp's default OAuth callback_timeout (300 s) + slack.
_OAUTH_PROBE_TIMEOUT_SECONDS = 305.0


def _validate_url(url: str) -> str:
trimmed = (url or "").strip()
if not trimmed:
raise HTTPException(status_code = 400, detail = "url must not be empty")
# When stdio is enabled on this host, a non-HTTP value is a local command.
# Reuse this field so stdio servers ride the existing CRUD/storage with no
# schema change. When stdio is disabled the value falls through to the
# http-only validation below, so non-HTTP input is just a bad URL (400).
if stdio_mcp_enabled() and is_stdio(trimmed):
try:
parts = parse_stdio_command(trimmed)
except ValueError as exc:
raise HTTPException(status_code = 400, detail = f"Invalid command: {exc}")
if not parts or not parts[0].strip():
raise HTTPException(status_code = 400, detail = "command must not be empty")
return trimmed
parsed = urlparse(trimmed)
if parsed.scheme not in ("http", "https"):
raise HTTPException(
Expand Down Expand Up @@ -180,15 +190,19 @@ async def refresh_mcp_server_tools(
server = mcp_servers_db.get_server(server_id)
if not server:
raise HTTPException(status_code = 404, detail = "MCP server not found")
# Refresh uses the stored address, so re-check the stdio gate here too: a
# stdio row from a desktop DB must not spawn on a hosted/network host.
if is_stdio(server["url"]) and not stdio_mcp_enabled():
raise HTTPException(
status_code = 400, detail = "stdio MCP servers are disabled on this host"
)

use_oauth = bool(server.get("use_oauth"))
try:
tools = await list_tools_async(
url = server["url"],
headers = parse_server_headers(server),
timeout = _OAUTH_PROBE_TIMEOUT_SECONDS
if use_oauth
else _PROBE_TIMEOUT_SECONDS,
timeout = probe_timeout(server["url"], use_oauth),
use_oauth = use_oauth,
)
except Exception as exc: # noqa: BLE001 — surface transport+timeout errors to UI
Expand All @@ -212,9 +226,7 @@ async def test_mcp_server(
tools = await list_tools_async(
url = url,
headers = headers,
timeout = _OAUTH_PROBE_TIMEOUT_SECONDS
if payload.use_oauth
else _PROBE_TIMEOUT_SECONDS,
timeout = probe_timeout(url, payload.use_oauth),
use_oauth = payload.use_oauth,
)
except Exception as exc: # noqa: BLE001
Expand Down
Loading
Loading