Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions studio/backend/core/data_recipe/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,21 @@ def build_mcp_providers(
) -> list:
from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider # pyright: ignore[reportMissingImports]

# Same gate as the chat MCP path: stdio providers spawn a local subprocess,
# so only build them when this host allows it (desktop / explicit opt-in).
# Skip them otherwise so a recipe carried onto a hosted host cannot spawn.
from core.inference.mcp_client import stdio_mcp_enabled

stdio_allowed = stdio_mcp_enabled()

providers: list[MCPProvider | LocalStdioMCPProvider] = []
for provider in recipe.get("mcp_providers", []):
if not isinstance(provider, dict):
continue
provider_type = provider.get("provider_type")
if provider_type == "stdio":
if not stdio_allowed:
continue
env = provider.get("env")
if not isinstance(env, dict):
env = {}
Expand Down
15 changes: 12 additions & 3 deletions studio/backend/core/inference/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,24 @@ def _client(url: str, headers: Optional[dict], use_oauth: bool = False):
from fastmcp import Client

if is_stdio(url):
# Belt-and-suspenders: never spawn unless stdio is enabled on this host.
if not stdio_mcp_enabled():
raise PermissionError("stdio MCP servers are disabled on this host")
from fastmcp.client.transports import StdioTransport

parts = parse_stdio_command(url)
if not parts:
raise ValueError(f"Empty stdio command: {url!r}")
# stdio env vars ride the (HTTP-only) headers field. The MCP SDK merges
# them over its default safe env (PATH etc.), so pass them through as-is.
# env vars ride the headers field (merged over the SDK's safe default env).
# keep_alive=False tears the subprocess down on exit, so a one-shot
# probe/tool call never leaves an orphan process.
return Client(
StdioTransport(command = parts[0], args = parts[1:], env = headers or None)
StdioTransport(
command = parts[0],
args = parts[1:],
env = headers or None,
keep_alive = False,
)
)

from fastmcp.client.transports import SSETransport, StreamableHttpTransport
Expand Down
10 changes: 10 additions & 0 deletions studio/backend/routes/data_recipe/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,18 @@ def list_mcp_tools(payload: McpToolsListRequest) -> McpToolsListResponse:
providers: list[McpToolsProviderResult] = []
tool_to_providers: dict[str, list[str]] = defaultdict(list)

from core.inference.mcp_client import stdio_mcp_enabled

for provider_payload in payload.mcp_providers:
provider_name = str(provider_payload.get("name", "")).strip()
if provider_payload.get("provider_type") == "stdio" and not stdio_mcp_enabled():
providers.append(
McpToolsProviderResult(
name = provider_name,
error = "Local (stdio) MCP servers are disabled on this host.",
)
)
continue
built = build_mcp_providers({"mcp_providers": [provider_payload]})
if len(built) != 1:
providers.append(
Expand Down
25 changes: 24 additions & 1 deletion studio/backend/routes/mcp_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def _validate_url(url: str) -> str:
raise HTTPException(status_code = 400, detail = f"Invalid command: {exc}")
if not parts or not parts[0].strip():
raise HTTPException(status_code = 400, detail = "command must not be empty")
if "://" in parts[0]:
# A URL-scheme first token is a mistyped URL, not a command. Reject
# it cleanly instead of exec-ing it (mirrors the frontend check).
raise HTTPException(
status_code = 400,
detail = "Enter an http(s):// URL, or a local command whose "
"first token is an executable (not a URL).",
)
return trimmed
parsed = urlparse(trimmed)
if parsed.scheme not in ("http", "https"):
Expand Down Expand Up @@ -101,6 +109,9 @@ async def create_mcp_server(
raise HTTPException(status_code = 400, detail = "display_name must not be empty")
url = _validate_url(payload.url)
headers = _normalize_headers(payload.headers)
# OAuth is HTTP-only; force it off for stdio commands so a stale flag can't
# push the probe onto the 305s OAuth timeout. Backend is the enforcer.
use_oauth = payload.use_oauth and not is_stdio(url)

server_id = uuid.uuid4().hex[:16]
mcp_servers_db.create_server(
Expand All @@ -109,7 +120,7 @@ async def create_mcp_server(
url = url,
headers_json = json.dumps(headers) if headers else None,
is_enabled = payload.is_enabled,
use_oauth = payload.use_oauth,
use_oauth = use_oauth,
)
return _row_to_response(mcp_servers_db.get_server(server_id))

Expand Down Expand Up @@ -142,6 +153,9 @@ def _changes_from_payload(payload: McpServerUpdate) -> dict:
status_code = 400, detail = "use_oauth must be true or false"
)
changes["use_oauth"] = payload.use_oauth
# stdio is OAuth-less: drop a stale OAuth flag when switching to a command.
if "url" in changes and is_stdio(changes["url"]):
changes["use_oauth"] = False
Comment on lines +157 to +158

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep OAuth disabled for existing stdio updates

When the stored row is already a stdio command, a client can still PUT only {"use_oauth": true}; this guard only runs when the request also includes url, so changes preserves use_oauth=True and persists it. Subsequent refresh/discovery for that server pass the stored flag into probe_timeout(..., use_oauth), so the stdio probe takes the 305s OAuth timeout path instead of the intended 60s stdio path. Normalize against the effective URL (old URL unless a new one was supplied) before saving.

Useful? React with 👍 / 👎.

return changes


Expand All @@ -157,6 +171,15 @@ async def update_mcp_server(
changes = _changes_from_payload(payload)
if not changes:
raise HTTPException(status_code = 400, detail = "No fields to update")
# headers == HTTP headers (remote) or env vars (stdio). On a transport-type
# switch with no new headers, drop the old ones so env secrets are not
# re-sent as HTTP headers (or vice versa).
if (
"url" in changes
and is_stdio(changes["url"]) != is_stdio(old["url"])
Comment on lines 171 to +179

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is an edge case where a user can update use_oauth to True on an existing stdio server without changing the URL. Since _changes_from_payload only checks changes["url"], updating only use_oauth on a stdio server would bypass this check and persist use_oauth = True in the database. We should resolve the resulting URL using changes.get("url", old["url"]) and force use_oauth = False if it is a stdio server.

Suggested change
changes = _changes_from_payload(payload)
if not changes:
raise HTTPException(status_code = 400, detail = "No fields to update")
# headers == HTTP headers (remote) or env vars (stdio). On a transport-type
# switch with no new headers, drop the old ones so env secrets are not
# re-sent as HTTP headers (or vice versa).
if (
"url" in changes
and is_stdio(changes["url"]) != is_stdio(old["url"])
changes = _changes_from_payload(payload)
if not changes:
raise HTTPException(status_code = 400, detail = "No fields to update")
if is_stdio(changes.get("url", old["url"])):
changes["use_oauth"] = False
# headers == HTTP headers (remote) or env vars (stdio). On a transport-type
# switch with no new headers, drop the old ones so env secrets are not
# re-sent as HTTP headers (or vice versa).
if "url" in changes and is_stdio(changes["url"]) != is_stdio(old["url"]) \
and "headers_json" not in changes:
changes["headers_json"] = None

and "headers_json" not in changes
):
changes["headers_json"] = None
# Clear persisted OAuth tokens when the URL changes or OAuth is
# disabled; fastmcp keys tokens by URL and would otherwise let a
# re-pointed server silently inherit the old account's credentials.
Expand Down
236 changes: 236 additions & 0 deletions studio/backend/tests/test_mcp_stdio_improvements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""Tests for the proposed PR #5863 improvements.

Covers: _client() self-gating + keep_alive, OAuth normalised off for stdio
(create + update), env/header dropped on a transport-type switch, and the
backend rejecting a command whose first token is a URL scheme.

Run from studio/backend: python -m pytest tests/test_mcp_stdio_improvements.py -q
"""

import asyncio

import pytest
from fastapi import HTTPException

from core.inference import mcp_client
from storage import mcp_servers_db


def _reset_db(tmp_path, monkeypatch):
monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path))
monkeypatch.setattr(mcp_servers_db, "_schema_ready", False)


def _enable(monkeypatch):
monkeypatch.setenv("UNSLOTH_STUDIO_ALLOW_STDIO_MCP", "1")


def _disable(monkeypatch):
monkeypatch.delenv("UNSLOTH_STUDIO_ALLOW_STDIO_MCP", raising = False)


# ── P1: _client() self-gates the stdio sink ─────────────────────────


def test_client_refuses_stdio_when_disabled(monkeypatch):
_disable(monkeypatch)
with pytest.raises(PermissionError):
mcp_client._client("npx -y server /tmp", None)


def test_client_builds_stdio_when_enabled_without_spawning(monkeypatch):
_enable(monkeypatch)
# Constructing the Client must not spawn the subprocess (spawn happens on
# __aenter__); we only assert it builds.
client = mcp_client._client("npx -y server /tmp", {"K": "v"})
assert client is not None


def test_client_http_unaffected_by_gate(monkeypatch):
_disable(monkeypatch)
assert mcp_client._client("https://example.com/mcp", None) is not None


# ── P3: OAuth normalised off for stdio (create + update) ────────────


def test_create_forces_oauth_off_for_stdio(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerCreate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
resp = asyncio.run(
routes_mcp.create_mcp_server(
McpServerCreate(
display_name = "FS", url = "npx -y server /tmp", use_oauth = True
),
current_subject = "u",
)
)
assert resp.use_oauth is False
assert mcp_servers_db.get_server(resp.id)["use_oauth"] == 0


def test_create_keeps_oauth_for_http(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerCreate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
resp = asyncio.run(
routes_mcp.create_mcp_server(
McpServerCreate(display_name = "GH", url = "https://gh/mcp", use_oauth = True),
current_subject = "u",
)
)
assert resp.use_oauth is True


def test_update_url_to_stdio_clears_oauth(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerUpdate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
monkeypatch.setattr(mcp_client, "_oauth_token_store", None)
monkeypatch.setattr(
routes_mcp, "clear_oauth_tokens_async", lambda *a, **k: asyncio.sleep(0)
)
mcp_servers_db.create_server(
id = "s1", display_name = "A", url = "https://a/mcp", use_oauth = True
)
resp = asyncio.run(
routes_mcp.update_mcp_server(
"s1", McpServerUpdate(url = "npx -y server /tmp"), current_subject = "u"
)
)
assert resp.use_oauth is False


# ── P4: env/headers dropped on a transport-type switch ──────────────


def test_switch_stdio_to_http_drops_env(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerUpdate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
mcp_servers_db.create_server(
id = "s1",
display_name = "A",
url = "npx server",
headers_json = '{"API_KEY": "secret"}',
)
resp = asyncio.run(
routes_mcp.update_mcp_server(
"s1", McpServerUpdate(url = "https://remote/mcp"), current_subject = "u"
)
)
# the stdio env must NOT survive as HTTP headers on the remote endpoint
assert resp.headers == {}
assert mcp_servers_db.get_server("s1")["headers_json"] is None


def test_switch_keeps_explicitly_supplied_headers(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerUpdate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
mcp_servers_db.create_server(
id = "s1",
display_name = "A",
url = "npx server",
headers_json = '{"API_KEY": "secret"}',
)
resp = asyncio.run(
routes_mcp.update_mcp_server(
"s1",
McpServerUpdate(
url = "https://remote/mcp", headers = {"Authorization": "Bearer new"}
),
current_subject = "u",
)
)
assert resp.headers == {"Authorization": "Bearer new"}


def test_same_transport_edit_keeps_headers(tmp_path, monkeypatch):
import routes.mcp_servers as routes_mcp
from models.mcp_servers import McpServerUpdate

_reset_db(tmp_path, monkeypatch)
_enable(monkeypatch)
mcp_servers_db.create_server(
id = "s1",
display_name = "A",
url = "npx server",
headers_json = '{"API_KEY": "secret"}',
)
# editing only the display name (still stdio) must not wipe env vars
resp = asyncio.run(
routes_mcp.update_mcp_server(
"s1", McpServerUpdate(display_name = "B"), current_subject = "u"
)
)
assert resp.headers == {"API_KEY": "secret"}


# ── P5: reject a command whose first token is a URL scheme ───────────


def test_validate_url_rejects_url_scheme_command_when_enabled(monkeypatch):
from routes.mcp_servers import _validate_url

_enable(monkeypatch)
for bad in ["ftp://host/x", "file:///etc/passwd", "ws://h/y"]:
with pytest.raises(HTTPException) as exc:
_validate_url(bad)
assert exc.value.status_code == 400


def test_validate_url_allows_url_in_argument(monkeypatch):
from routes.mcp_servers import _validate_url

_enable(monkeypatch)
# :// inside an ARGUMENT (not the first token) is still a valid command
assert _validate_url("npx server --url https://x/mcp") == (
"npx server --url https://x/mcp"
)


# ── P6: Data Recipe stdio path obeys the same host gate ─────────────
# build_mcp_providers needs the data_designer plugin, which is only installed in
# the Studio test job; skip there rather than fail the core matrix.

_STDIO_RECIPE = {
"mcp_providers": [
{
"provider_type": "stdio",
"name": "fs",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
"env": {},
}
]
}


def test_data_recipe_skips_stdio_when_disabled(monkeypatch):
pytest.importorskip("data_designer")
_disable(monkeypatch)
from core.data_recipe.service import build_mcp_providers

# gate off -> the stdio provider is dropped (no subprocess can be spawned)
assert build_mcp_providers(_STDIO_RECIPE) == []


def test_data_recipe_builds_stdio_when_enabled(monkeypatch):
pytest.importorskip("data_designer")
_enable(monkeypatch)
from core.data_recipe.service import build_mcp_providers

built = build_mcp_providers(_STDIO_RECIPE)
assert len(built) == 1 # constructed (not spawned) only when enabled
Loading
Loading