-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Studio: harden stdio MCP gating and fix transport edge cases #5892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -47,6 +47,14 @@ def _validate_url(url: str) -> str: | |||||||||||||||||||||||||||||||||||||||||
| raise HTTPException(status_code = 400, detail = f"Invalid command: {exc}") | ||||||||||||||||||||||||||||||||||||||||||
| if not parts or not parts[0].strip(): | ||||||||||||||||||||||||||||||||||||||||||
| raise HTTPException(status_code = 400, detail = "command must not be empty") | ||||||||||||||||||||||||||||||||||||||||||
| if "://" in parts[0]: | ||||||||||||||||||||||||||||||||||||||||||
| # A URL-scheme first token is a mistyped URL, not a command. Reject | ||||||||||||||||||||||||||||||||||||||||||
| # it cleanly instead of exec-ing it (mirrors the frontend check). | ||||||||||||||||||||||||||||||||||||||||||
| raise HTTPException( | ||||||||||||||||||||||||||||||||||||||||||
| status_code = 400, | ||||||||||||||||||||||||||||||||||||||||||
| detail = "Enter an http(s):// URL, or a local command whose " | ||||||||||||||||||||||||||||||||||||||||||
| "first token is an executable (not a URL).", | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| return trimmed | ||||||||||||||||||||||||||||||||||||||||||
| parsed = urlparse(trimmed) | ||||||||||||||||||||||||||||||||||||||||||
| if parsed.scheme not in ("http", "https"): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -101,6 +109,9 @@ async def create_mcp_server( | |||||||||||||||||||||||||||||||||||||||||
| raise HTTPException(status_code = 400, detail = "display_name must not be empty") | ||||||||||||||||||||||||||||||||||||||||||
| url = _validate_url(payload.url) | ||||||||||||||||||||||||||||||||||||||||||
| headers = _normalize_headers(payload.headers) | ||||||||||||||||||||||||||||||||||||||||||
| # OAuth is HTTP-only; force it off for stdio commands so a stale flag can't | ||||||||||||||||||||||||||||||||||||||||||
| # push the probe onto the 305s OAuth timeout. Backend is the enforcer. | ||||||||||||||||||||||||||||||||||||||||||
| use_oauth = payload.use_oauth and not is_stdio(url) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| server_id = uuid.uuid4().hex[:16] | ||||||||||||||||||||||||||||||||||||||||||
| mcp_servers_db.create_server( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -109,7 +120,7 @@ async def create_mcp_server( | |||||||||||||||||||||||||||||||||||||||||
| url = url, | ||||||||||||||||||||||||||||||||||||||||||
| headers_json = json.dumps(headers) if headers else None, | ||||||||||||||||||||||||||||||||||||||||||
| is_enabled = payload.is_enabled, | ||||||||||||||||||||||||||||||||||||||||||
| use_oauth = payload.use_oauth, | ||||||||||||||||||||||||||||||||||||||||||
| use_oauth = use_oauth, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| return _row_to_response(mcp_servers_db.get_server(server_id)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -142,6 +153,9 @@ def _changes_from_payload(payload: McpServerUpdate) -> dict: | |||||||||||||||||||||||||||||||||||||||||
| status_code = 400, detail = "use_oauth must be true or false" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| changes["use_oauth"] = payload.use_oauth | ||||||||||||||||||||||||||||||||||||||||||
| # stdio is OAuth-less: drop a stale OAuth flag when switching to a command. | ||||||||||||||||||||||||||||||||||||||||||
| if "url" in changes and is_stdio(changes["url"]): | ||||||||||||||||||||||||||||||||||||||||||
| changes["use_oauth"] = False | ||||||||||||||||||||||||||||||||||||||||||
| return changes | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -157,6 +171,15 @@ async def update_mcp_server( | |||||||||||||||||||||||||||||||||||||||||
| changes = _changes_from_payload(payload) | ||||||||||||||||||||||||||||||||||||||||||
| if not changes: | ||||||||||||||||||||||||||||||||||||||||||
| raise HTTPException(status_code = 400, detail = "No fields to update") | ||||||||||||||||||||||||||||||||||||||||||
| # headers == HTTP headers (remote) or env vars (stdio). On a transport-type | ||||||||||||||||||||||||||||||||||||||||||
| # switch with no new headers, drop the old ones so env secrets are not | ||||||||||||||||||||||||||||||||||||||||||
| # re-sent as HTTP headers (or vice versa). | ||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||
| "url" in changes | ||||||||||||||||||||||||||||||||||||||||||
| and is_stdio(changes["url"]) != is_stdio(old["url"]) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
171
to
+179
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is an edge case where a user can update …
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| and "headers_json" not in changes | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| changes["headers_json"] = None | ||||||||||||||||||||||||||||||||||||||||||
| # Clear persisted OAuth tokens when the URL changes or OAuth is | ||||||||||||||||||||||||||||||||||||||||||
| # disabled; fastmcp keys tokens by URL and would otherwise let a | ||||||||||||||||||||||||||||||||||||||||||
| # re-pointed server silently inherit the old account's credentials. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,236 @@ | ||
| """Tests for the proposed PR #5863 improvements. | ||
|
|
||
| Covers: _client() self-gating + keep_alive, OAuth normalised off for stdio | ||
| (create + update), env/header dropped on a transport-type switch, and the | ||
| backend rejecting a command whose first token is a URL scheme. | ||
|
|
||
| Run from studio/backend: python -m pytest tests/test_mcp_stdio_improvements.py -q | ||
| """ | ||
|
|
||
| import asyncio | ||
|
|
||
| import pytest | ||
| from fastapi import HTTPException | ||
|
|
||
| from core.inference import mcp_client | ||
| from storage import mcp_servers_db | ||
|
|
||
|
|
||
| def _reset_db(tmp_path, monkeypatch): | ||
| monkeypatch.setenv("UNSLOTH_STUDIO_HOME", str(tmp_path)) | ||
| monkeypatch.setattr(mcp_servers_db, "_schema_ready", False) | ||
|
|
||
|
|
||
| def _enable(monkeypatch): | ||
| monkeypatch.setenv("UNSLOTH_STUDIO_ALLOW_STDIO_MCP", "1") | ||
|
|
||
|
|
||
| def _disable(monkeypatch): | ||
| monkeypatch.delenv("UNSLOTH_STUDIO_ALLOW_STDIO_MCP", raising = False) | ||
|
|
||
|
|
||
| # ── P1: _client() self-gates the stdio sink ───────────────────────── | ||
|
|
||
|
|
||
| def test_client_refuses_stdio_when_disabled(monkeypatch): | ||
| _disable(monkeypatch) | ||
| with pytest.raises(PermissionError): | ||
| mcp_client._client("npx -y server /tmp", None) | ||
|
|
||
|
|
||
| def test_client_builds_stdio_when_enabled_without_spawning(monkeypatch): | ||
| _enable(monkeypatch) | ||
| # Constructing the Client must not spawn the subprocess (spawn happens on | ||
| # __aenter__); we only assert it builds. | ||
| client = mcp_client._client("npx -y server /tmp", {"K": "v"}) | ||
| assert client is not None | ||
|
|
||
|
|
||
| def test_client_http_unaffected_by_gate(monkeypatch): | ||
| _disable(monkeypatch) | ||
| assert mcp_client._client("https://example.com/mcp", None) is not None | ||
|
|
||
|
|
||
| # ── P3: OAuth normalised off for stdio (create + update) ──────────── | ||
|
|
||
|
|
||
| def test_create_forces_oauth_off_for_stdio(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerCreate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| resp = asyncio.run( | ||
| routes_mcp.create_mcp_server( | ||
| McpServerCreate( | ||
| display_name = "FS", url = "npx -y server /tmp", use_oauth = True | ||
| ), | ||
| current_subject = "u", | ||
| ) | ||
| ) | ||
| assert resp.use_oauth is False | ||
| assert mcp_servers_db.get_server(resp.id)["use_oauth"] == 0 | ||
|
|
||
|
|
||
| def test_create_keeps_oauth_for_http(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerCreate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| resp = asyncio.run( | ||
| routes_mcp.create_mcp_server( | ||
| McpServerCreate(display_name = "GH", url = "https://gh/mcp", use_oauth = True), | ||
| current_subject = "u", | ||
| ) | ||
| ) | ||
| assert resp.use_oauth is True | ||
|
|
||
|
|
||
| def test_update_url_to_stdio_clears_oauth(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerUpdate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| monkeypatch.setattr(mcp_client, "_oauth_token_store", None) | ||
| monkeypatch.setattr( | ||
| routes_mcp, "clear_oauth_tokens_async", lambda *a, **k: asyncio.sleep(0) | ||
| ) | ||
| mcp_servers_db.create_server( | ||
| id = "s1", display_name = "A", url = "https://a/mcp", use_oauth = True | ||
| ) | ||
| resp = asyncio.run( | ||
| routes_mcp.update_mcp_server( | ||
| "s1", McpServerUpdate(url = "npx -y server /tmp"), current_subject = "u" | ||
| ) | ||
| ) | ||
| assert resp.use_oauth is False | ||
|
|
||
|
|
||
| # ── P4: env/headers dropped on a transport-type switch ────────────── | ||
|
|
||
|
|
||
| def test_switch_stdio_to_http_drops_env(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerUpdate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| mcp_servers_db.create_server( | ||
| id = "s1", | ||
| display_name = "A", | ||
| url = "npx server", | ||
| headers_json = '{"API_KEY": "secret"}', | ||
| ) | ||
| resp = asyncio.run( | ||
| routes_mcp.update_mcp_server( | ||
| "s1", McpServerUpdate(url = "https://remote/mcp"), current_subject = "u" | ||
| ) | ||
| ) | ||
| # the stdio env must NOT survive as HTTP headers on the remote endpoint | ||
| assert resp.headers == {} | ||
| assert mcp_servers_db.get_server("s1")["headers_json"] is None | ||
|
|
||
|
|
||
| def test_switch_keeps_explicitly_supplied_headers(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerUpdate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| mcp_servers_db.create_server( | ||
| id = "s1", | ||
| display_name = "A", | ||
| url = "npx server", | ||
| headers_json = '{"API_KEY": "secret"}', | ||
| ) | ||
| resp = asyncio.run( | ||
| routes_mcp.update_mcp_server( | ||
| "s1", | ||
| McpServerUpdate( | ||
| url = "https://remote/mcp", headers = {"Authorization": "Bearer new"} | ||
| ), | ||
| current_subject = "u", | ||
| ) | ||
| ) | ||
| assert resp.headers == {"Authorization": "Bearer new"} | ||
|
|
||
|
|
||
| def test_same_transport_edit_keeps_headers(tmp_path, monkeypatch): | ||
| import routes.mcp_servers as routes_mcp | ||
| from models.mcp_servers import McpServerUpdate | ||
|
|
||
| _reset_db(tmp_path, monkeypatch) | ||
| _enable(monkeypatch) | ||
| mcp_servers_db.create_server( | ||
| id = "s1", | ||
| display_name = "A", | ||
| url = "npx server", | ||
| headers_json = '{"API_KEY": "secret"}', | ||
| ) | ||
| # editing only the display name (still stdio) must not wipe env vars | ||
| resp = asyncio.run( | ||
| routes_mcp.update_mcp_server( | ||
| "s1", McpServerUpdate(display_name = "B"), current_subject = "u" | ||
| ) | ||
| ) | ||
| assert resp.headers == {"API_KEY": "secret"} | ||
|
|
||
|
|
||
| # ── P5: reject a command whose first token is a URL scheme ─────────── | ||
|
|
||
|
|
||
| def test_validate_url_rejects_url_scheme_command_when_enabled(monkeypatch): | ||
| from routes.mcp_servers import _validate_url | ||
|
|
||
| _enable(monkeypatch) | ||
| for bad in ["ftp://host/x", "file:///etc/passwd", "ws://h/y"]: | ||
| with pytest.raises(HTTPException) as exc: | ||
| _validate_url(bad) | ||
| assert exc.value.status_code == 400 | ||
|
|
||
|
|
||
| def test_validate_url_allows_url_in_argument(monkeypatch): | ||
| from routes.mcp_servers import _validate_url | ||
|
|
||
| _enable(monkeypatch) | ||
| # :// inside an ARGUMENT (not the first token) is still a valid command | ||
| assert _validate_url("npx server --url https://x/mcp") == ( | ||
| "npx server --url https://x/mcp" | ||
| ) | ||
|
|
||
|
|
||
| # ── P6: Data Recipe stdio path obeys the same host gate ───────────── | ||
| # build_mcp_providers needs the data_designer plugin, which is only installed in | ||
| # the Studio test job; skip there rather than fail the core matrix. | ||
|
|
||
| _STDIO_RECIPE = { | ||
| "mcp_providers": [ | ||
| { | ||
| "provider_type": "stdio", | ||
| "name": "fs", | ||
| "command": "npx", | ||
| "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], | ||
| "env": {}, | ||
| } | ||
| ] | ||
| } | ||
|
|
||
|
|
||
| def test_data_recipe_skips_stdio_when_disabled(monkeypatch): | ||
| pytest.importorskip("data_designer") | ||
| _disable(monkeypatch) | ||
| from core.data_recipe.service import build_mcp_providers | ||
|
|
||
| # gate off -> the stdio provider is dropped (no subprocess can be spawned) | ||
| assert build_mcp_providers(_STDIO_RECIPE) == [] | ||
|
|
||
|
|
||
| def test_data_recipe_builds_stdio_when_enabled(monkeypatch): | ||
| pytest.importorskip("data_designer") | ||
| _enable(monkeypatch) | ||
| from core.data_recipe.service import build_mcp_providers | ||
|
|
||
| built = build_mcp_providers(_STDIO_RECIPE) | ||
| assert len(built) == 1 # constructed (not spawned) only when enabled |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the stored row is already a stdio command, a client can still
`PUT` only `{"use_oauth": true}`; this guard only runs when the request also includes `url`, so `changes` preserves `use_oauth=True` and persists it. Subsequent refresh/discovery for that server pass the stored flag into `probe_timeout(..., use_oauth)`, so the stdio probe takes the 305s OAuth timeout path instead of the intended 60s stdio path. Normalize against the effective URL (old URL unless a new one was supplied) before saving. Useful? React with 👍 / 👎.