diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f31dd2c..2ecfb78 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,6 +8,18 @@ on: jobs: + test: + name: Run tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.11" + - name: Run pytest + run: uv run --group dev pytest tests + build-macos: name: Build jarvis binary + macOS app (arm64) runs-on: macos-26 diff --git a/.gitignore b/.gitignore index a797dd1..18235bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ .venv/ .tokens/ __pycache__/ +.pytest_cache/ +.coverage macOs/Jarvis/Jarvis.xcodeproj/xcuserdata/ macOs/Jarvis/Jarvis.xcodeproj/project.xcworkspace/xcuserdata/ macOs/Jarvis/Jarvis/Resources/jarvis diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..88c8260 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,74 @@ +# AGENTS.md + +## What this is + +Jarvis is an MCP proxy that aggregates multiple MCP servers behind 2 synthetic tools (`search_tools` + `call_tool`). Python 3.11+, managed with **uv**. + +## Layout + +``` +src/jarvis/ # the package (6 modules) + __main__.py # CLI entrypoint — arg parsing, server startup + config.py # DATA_DIR, presets, config loading, OAuth wiring + proxy.py # builds FastMCP proxy (stdio vs HTTP client selection) + api.py # REST management API (runs on port+1) + probe.py # server/tool discovery + tui.py # Textual TUIs (mcp manager, auth manager) +tests/unit/ # pure unit tests +tests/integration/ # API endpoint + TUI tests +scripts/ # PyInstaller build scripts +macOs/ # Xcode project for the menu bar app +``` + +## Commands + +```bash +# Install deps (no separate install step — uv handles it) +uv sync --group dev + +# Run locally +uv run python -m jarvis --http 7070 + +# Run all tests +uv run --group dev pytest tests + +# Run a single test file or test +uv run --group dev pytest tests/unit/test_config.py +uv run --group dev pytest tests/unit/test_config.py::test_name -k test_name + +# Build standalone binary (macOS arm64) +bash scripts/build_jarvis_binary.sh + +# Build standalone binary (Linux x86_64) +bash scripts/build_jarvis_binary_linux.sh +``` + +## Testing quirks + +- **pytest-asyncio `auto` mode** is on (`asyncio_mode = "auto"` in pyproject.toml). Do not add `@pytest.mark.asyncio` to async tests. +- **`conftest.py` sets `JARVIS_DATA_DIR` at import time** before any jarvis module is imported. This isolates tests from `~/.jarvis`. If you add a new conftest or rearrange imports, preserve this ordering — the module-level `DATA_DIR` and `token_storage` in `config.py` bind once on first import. +- Use the `data_dir` fixture for per-test isolation. It monkeypatches `DATA_DIR`, `PRESETS_PATH`, and `token_storage` across `config`, `api`, and `probe` modules. +- Use the `servers_json` fixture when you need a pre-populated `servers.json` in the isolated data dir. + +## Architecture notes + +- `config.py` resolves `DATA_DIR` and creates `token_storage` (DiskStore) **at module level**. The env var `JARVIS_DATA_DIR` overrides the default `~/.jarvis` — this is the only mechanism for test isolation. +- `proxy.py` chooses `StatefulProxyClient` (persistent subprocess) for stdio servers and `ProxyClient` (fresh connection) for HTTP/SSE. The stateful clients are pinned to `mcp._stateful_clients` to avoid GC. +- The hatchling build uses `packages = ["src/jarvis"]` — the wheel package is `jarvis`, not `jarvis_mcp`. + +## macOS app + +The menu bar app (`macOs/Jarvis/`) is a Swift/Xcode project that embeds the PyInstaller binary. **Build order matters** — the binary must exist before Xcode can bundle it: + +```bash +# 1. Build the Python binary into the Xcode Resources dir +bash scripts/build_jarvis_binary.sh # → macOs/Jarvis/Jarvis/Resources/jarvis + +# 2. Build the app +xcodebuild -project macOs/Jarvis/Jarvis.xcodeproj -scheme Jarvis -configuration Debug build +``` +## CI + +- Every push/PR: pytest + binary builds (macOS arm64, Linux x86_64). +- No lint or typecheck step in CI. Ruff cache exists locally but there is no enforced config. +- Releases trigger on `v*` tags and produce binaries + a macOS `.dmg`. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 0000000..47dc3e3 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/README.md b/README.md index 6e6dad9..fc87be6 100644 --- a/README.md +++ b/README.md @@ -133,17 +133,13 @@ Lists all servers and their auth status. `l` to trigger the OAuth login flow for ## OAuth authentication -Servers with `"auth": "oauth"` require a one-time browser login. Use the `auth` TUI (above), or authenticate from the CLI: +Servers with `"auth": "oauth"` require a one-time browser login. Use the `auth` TUI to trigger the flow: ```bash -# All OAuth servers -jarvis --auth - -# Specific server -jarvis --auth atlassian +jarvis auth ``` -Tokens are stored in `~/.jarvis/` and reused automatically on subsequent runs. +Select the server and press `l` to open the browser login flow. Tokens are stored in `~/.jarvis/` and reused automatically on subsequent runs. ## Modes @@ -173,7 +169,6 @@ Commands: Options: --config PATH Use a specific config file --http PORT Run as an HTTP server on PORT (management UI) - --auth [SERVER] Authenticate with all servers or a specific one --code-mode Enable code mode transform --help, -h Show this message and exit diff --git a/macOs/Jarvis/Jarvis/Services/LogViewerView.swift b/macOs/Jarvis/Jarvis/Services/LogViewerView.swift index 4f57a7f..1f835d9 100644 --- a/macOs/Jarvis/Jarvis/Services/LogViewerView.swift +++ b/macOs/Jarvis/Jarvis/Services/LogViewerView.swift @@ -94,8 +94,7 @@ struct LogViewerView: View { if let content = try? String(contentsOf: logURL, encoding: .utf8) { // Get last 10000 lines to avoid memory issues let lines = content.split(separator: "\n", omittingEmptySubsequences: false) - let recentLines = lines.suffix(10000) - logContent = recentLines.joined(separator: "\n") + logContent = lines.suffix(10000).joined(separator: "\n") } else { logContent = "No logs found at \(logURL.path(percentEncoded: false))\n\nThe log file will be created when the server starts." } diff --git a/pyproject.toml b/pyproject.toml index d58ad87..5efc72f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,20 @@ dependencies = [ "textual>=0.83", ] +[dependency-groups] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "coverage[toml]>=7.0", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/jarvis"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/scripts/build_jarvis_binary.sh b/scripts/build_jarvis_binary.sh index b74bb7b..362685d 100755 --- a/scripts/build_jarvis_binary.sh +++ b/scripts/build_jarvis_binary.sh @@ -27,6 +27,8 @@ uv run --with 'pyinstaller==6.19.0' pyinstaller \ --copy-metadata starlette \ --copy-metadata uvicorn \ --copy-metadata textual \ + --copy-metadata pydantic-monty \ + --hidden-import pydantic_monty \ "$REPO_ROOT/src/jarvis/__main__.py" echo "==> Done. Binary at: $OUT_DIR/jarvis" diff --git a/scripts/build_jarvis_binary_linux.sh b/scripts/build_jarvis_binary_linux.sh index c614276..a4fc25c 100755 --- a/scripts/build_jarvis_binary_linux.sh +++ b/scripts/build_jarvis_binary_linux.sh @@ -27,6 +27,8 @@ uv run --with 'pyinstaller==6.19.0' pyinstaller \ --copy-metadata starlette \ --copy-metadata uvicorn \ --copy-metadata textual \ + --copy-metadata pydantic-monty \ + --hidden-import pydantic_monty \ "$REPO_ROOT/src/jarvis/__main__.py" echo "==> Done. Binary at: $OUT_DIR/jarvis" diff --git a/src/jarvis/__main__.py b/src/jarvis/__main__.py index d57f1b5..41aa474 100644 --- a/src/jarvis/__main__.py +++ b/src/jarvis/__main__.py @@ -3,7 +3,7 @@ from pathlib import Path from fastmcp.mcp_config import MCPConfig -from fastmcp.server import create_proxy +from jarvis.proxy import build_proxy from fastmcp.experimental.transforms.code_mode import CodeMode from fastmcp.server.transforms.search import BM25SearchTransform @@ -58,7 +58,6 @@ "Options:\n" " --config PATH Use a specific config file\n" " --http PORT Run as an HTTP server on PORT (management UI)\n" - " --auth [SERVER] Authenticate with all servers or a specific one\n" " --code-mode Enable code mode transform\n" " --help, -h Show this message and exit\n" "\n" @@ -82,57 +81,9 @@ disabled_tools = get_disabled_tools(config_path) code_mode = "--code-mode" in sys.argv -# Validate config before branching so all paths share the same MCPConfig object -# for server-name lookups (e.g. --auth target validation). config = MCPConfig.model_validate(mcp_dict) -if "--auth" in sys.argv: - # Resolve and validate the target server *before* creating the proxy so we - # can narrow the connection to only the requested server. - auth_idx = next( - (i for i, arg in enumerate(filtered_argv) if arg == "--auth"), None - ) - if auth_idx is not None: - target = next( - ( - filtered_argv[i] - for i in range(auth_idx + 1, len(filtered_argv)) - if not filtered_argv[i].startswith("-") - ), - None, - ) - else: - target = None - if target and target not in config.mcpServers: - print( - f"Unknown server '{target}'. Available: {', '.join(config.mcpServers)}" - ) - sys.exit(1) - - if target: - # Connect only to the requested server instead of all of them. - narrow_dict = {**mcp_dict, "mcpServers": {target: mcp_dict["mcpServers"][target]}} - auth_config = MCPConfig.model_validate(narrow_dict) - configure_servers(auth_config) - mcp = create_proxy(auth_config, name="jarvis") - else: - configure_servers(config) - mcp = create_proxy(config, name="jarvis") - - mcp.add_transform(BM25SearchTransform(max_results=5)) - - async def auth() -> None: - tools = await mcp.list_tools() - print(f"Authenticated. {len(tools)} tools available:") - for t in tools: - print(f" - {t.name}") - - try: - asyncio.run(auth()) - except KeyboardInterrupt: - print("\nAuth cancelled.") - -elif "--http" in sys.argv: +if "--http" in sys.argv: idx = sys.argv.index("--http") port_arg = sys.argv[idx + 1] if idx + 1 < len(sys.argv) else "" if not port_arg.isdigit(): @@ -148,12 +99,10 @@ async def auth() -> None: port = parsed_port configure_servers(config) - mcp = create_proxy(config, name="jarvis") + mcp = build_proxy(config, "jarvis") if disabled_tools: mcp.disable(names=disabled_tools) - mcp.add_transform( - CodeMode() if code_mode else BM25SearchTransform(max_results=5) - ) + mcp.add_transform(CodeMode() if code_mode else BM25SearchTransform(max_results=5)) async def _run_http() -> None: start_api_thread(port, port + 1) @@ -168,10 +117,8 @@ async def _run_http() -> None: else: configure_servers(config) - mcp = create_proxy(config, name="jarvis") + mcp = build_proxy(config, "jarvis") if disabled_tools: mcp.disable(names=disabled_tools) - mcp.add_transform( - CodeMode() if code_mode else BM25SearchTransform(max_results=5) - ) + mcp.add_transform(CodeMode() if code_mode else BM25SearchTransform(max_results=5)) mcp.run(show_banner=False) diff --git a/src/jarvis/config.py b/src/jarvis/config.py index 65b93a2..8460123 100644 --- a/src/jarvis/config.py +++ b/src/jarvis/config.py @@ -12,7 +12,9 @@ ENV_VAR_RE = re.compile(r"\$\{(\w+)\}") NON_STANDARD_KEYS = {"enabled", "disabledTools"} -DATA_DIR = Path.home() / ".jarvis" +# Data directory is overridable via ``JARVIS_DATA_DIR`` so tests (and alternate +# deployments) can isolate their state from the user's real ``~/.jarvis``. +DATA_DIR = Path(os.environ.get("JARVIS_DATA_DIR") or (Path.home() / ".jarvis")) PRESETS_PATH = DATA_DIR / "presets.json" token_storage = DiskStore(directory=str(DATA_DIR)) diff --git a/src/jarvis/proxy.py b/src/jarvis/proxy.py new file mode 100644 index 0000000..25197f2 --- /dev/null +++ b/src/jarvis/proxy.py @@ -0,0 +1,56 @@ +"""Proxy builder for Jarvis. + +Replaces ``fastmcp.server.create_proxy(MCPConfig)`` with a builder that uses +``StatefulProxyClient`` for stdio backends (persistent subprocess per frontend +session) and ``ProxyClient`` for HTTP/SSE backends (fresh connection per request). +""" + +from __future__ import annotations + +from fastmcp.mcp_config import MCPConfig, StdioMCPServer +from fastmcp.server import FastMCP +from fastmcp.server.providers.proxy import ( + ProxyClient, + ProxyProvider, + StatefulProxyClient, +) + + +def build_proxy(config: MCPConfig, name: str = "jarvis") -> FastMCP: + """Build a FastMCP proxy server from an MCPConfig. + + For each server in *config*: + - stdio servers get a ``StatefulProxyClient`` with ``new_stateful`` as the + client factory, so the subprocess lives for the duration of each frontend + session rather than being respawned on every tool call. + - HTTP/SSE servers get a ``ProxyClient`` with ``new`` as the factory, + giving a fresh connection per request (stateless, correct for HTTP). + + Args: + config: Validated MCPConfig with servers already configured + (OAuth injected, env vars expanded). + name: Name for the resulting FastMCP server. + + Returns: + A ``FastMCP`` server with one ``ProxyProvider`` per backend, namespaced + by server name. + """ + mcp: FastMCP = FastMCP(name=name) + # Keep strong references to StatefulProxyClient instances so they are not + # garbage-collected while the server is alive (new_stateful reads _caches). + mcp._stateful_clients: list = [] # type: ignore[attr-defined] + + for server_name, server in config.mcpServers.items(): + transport = server.to_transport() + + if isinstance(server, StdioMCPServer): + client = StatefulProxyClient(transport) + mcp._stateful_clients.append(client) + factory = client.new_stateful + else: + client = ProxyClient(transport) + factory = client.new + + mcp.add_provider(ProxyProvider(factory), namespace=server_name) + + return mcp diff --git a/src/jarvis/tui.py b/src/jarvis/tui.py index 199c17b..166d79b 100644 --- a/src/jarvis/tui.py +++ b/src/jarvis/tui.py @@ -29,7 +29,6 @@ def load_config(config_path: Path) -> tuple[dict[str, Any], str | None]: return {"mcpServers": {}}, f"Config parse error: {exc}" - # ── MCP Manager ─────────────────────────────────────────────────────────────── @@ -192,7 +191,10 @@ async def _probe_all(self) -> None: raw_servers = { d["name"]: servers_config[d["name"]] for node in tree.root.children - if (d := node.data) and d.get("type") == "server" and d.get("enabled", True) and d["name"] in servers_config + if (d := node.data) + and d.get("type") == "server" + and d.get("enabled", True) + and d["name"] in servers_config } total = len(raw_servers) @@ -285,8 +287,8 @@ class AuthManagerApp(App[None]): """Manage OAuth authentication for proxied MCP servers. Lists every configured server and its auth type. For OAuth servers the - user can trigger a login flow (opens the browser via the existing - ``jarvis --auth SERVER`` flow) or clear all cached tokens. + user can trigger a login flow (opens the browser) or clear all cached + tokens. """ TITLE = "Jarvis Auth Manager" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3f5059a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +"""Shared test fixtures. + +Isolates every test from the real ``~/.jarvis`` directory in two layers: + +1. *Before* any ``jarvis`` module is imported, ``JARVIS_DATA_DIR`` is set to + a session-scoped temp directory. This guarantees that the module-level + ``config.token_storage = DiskStore(directory=str(DATA_DIR))`` binding — + which happens exactly once on first import — never touches the user's + real ``~/.jarvis/cache.db``. +2. A per-test ``data_dir`` fixture rebinds ``DATA_DIR``/``PRESETS_PATH`` and + recreates ``token_storage`` against an isolated ``tmp_path`` subdir, so + tests that need per-test isolation still get it. +""" + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Step 1: isolate the module-level ``token_storage`` from the real home dir. +# This runs at conftest *import* time, which happens before any jarvis import +# below, so ``DiskStore`` ends up pointing at a scratch directory that we +# then throw away at the end of the session. +# --------------------------------------------------------------------------- + +_SESSION_DATA_DIR = Path(tempfile.mkdtemp(prefix="jarvis-test-")) +os.environ["JARVIS_DATA_DIR"] = str(_SESSION_DATA_DIR) + +# Safe to import jarvis now — DATA_DIR will resolve to _SESSION_DATA_DIR. +from jarvis import api as api_mod # noqa: E402 +from jarvis import config as config_mod # noqa: E402 +from jarvis import probe as probe_mod # noqa: E402 +from key_value.aio.stores.disk import DiskStore # noqa: E402 + + +def pytest_sessionfinish(session, exitstatus) -> None: # noqa: ARG001 + """Remove the session-scoped scratch directory after the run.""" + shutil.rmtree(_SESSION_DATA_DIR, ignore_errors=True) + + +@pytest.fixture +def data_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Redirect all jarvis data-dir reads/writes to an isolated temp dir. + + Rebinds ``DATA_DIR`` / ``PRESETS_PATH`` / ``token_storage`` on every + module that captured a reference at import time, so each test runs + against a fresh, empty data directory. + """ + fake_dir = tmp_path / "jarvis_data" + fake_dir.mkdir() + + monkeypatch.setattr(config_mod, "DATA_DIR", fake_dir) + monkeypatch.setattr(config_mod, "PRESETS_PATH", fake_dir / "presets.json") + # ``api`` and ``probe`` both imported ``DATA_DIR`` by name at module load. + monkeypatch.setattr(api_mod, "DATA_DIR", fake_dir) + monkeypatch.setattr(probe_mod, "DATA_DIR", fake_dir) + + # Recreate the token store against the fresh dir and rebind it on every + # module that holds a reference. This keeps OAuth tests from leaking + # across each other *and* makes it impossible to hit the session-scoped + # scratch dir either. + fresh_store = DiskStore(directory=str(fake_dir)) + monkeypatch.setattr(config_mod, "token_storage", fresh_store) + monkeypatch.setattr(probe_mod, "token_storage", fresh_store) + + return fake_dir + + +@pytest.fixture +def servers_json(data_dir: Path) -> Path: + """Write a default ``servers.json`` into the isolated data dir.""" + path = data_dir / "servers.json" + path.write_text( + json.dumps( + { + "mcpServers": { + "alpha": { + "url": "https://alpha.example.com/mcp", + "transport": "http", + }, + "beta": { + "command": "echo", + "args": ["hello"], + "disabledTools": ["noisy"], + }, + "gamma": { + "url": "https://gamma.example.com/mcp", + "transport": "http", + "enabled": False, + }, + } + }, + indent=2, + ) + ) + return path diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_api_endpoints.py b/tests/integration/test_api_endpoints.py new file mode 100644 index 0000000..50cd454 --- /dev/null +++ b/tests/integration/test_api_endpoints.py @@ -0,0 +1,349 @@ +"""Integration tests for the Jarvis REST API. + +Exercises every route on ``create_api_app`` end-to-end through Starlette's +``TestClient``. Network-touching code paths (``probe_all_servers``) are +stubbed so the tests never open sockets to real MCP servers. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +from starlette.testclient import TestClient + +from jarvis import api as api_mod +from jarvis import probe as probe_mod +from jarvis.api import create_api_app + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def client(data_dir: Path, servers_json: Path) -> TestClient: + app = create_api_app(mcp_port=7070) + with TestClient(app) as client: + yield client + + +@pytest.fixture +def stub_probe(monkeypatch: pytest.MonkeyPatch): + """Replace ``probe_all_servers`` with a deterministic in-memory stub.""" + + async def fake_probe_all(raw_servers: dict, timeout: float = 30): + return { + name: [ + {"name": f"{name}_tool1", "description": "first"}, + {"name": f"{name}_tool2", "description": "second"}, + ] + for name in raw_servers + } + + monkeypatch.setattr(api_mod, "probe_all_servers", fake_probe_all) + # also patch the original module so any lingering reference works + monkeypatch.setattr(probe_mod, "probe_all_servers", fake_probe_all) + return fake_probe_all + + +# ── /api/health ────────────────────────────────────────────────────────────── + + +class TestHealth: + def test_returns_status_and_ports(self, client: TestClient) -> None: + response = client.get("/api/health") + assert response.status_code == 200 + assert response.json() == { + "status": "ok", + "mcp_port": 7070, + "api_port": 7071, + } + + +# ── /api/tools ─────────────────────────────────────────────────────────────── + + +class TestGetTools: + def test_returns_probe_results_for_enabled_servers( + self, client: TestClient, stub_probe + ) -> None: + response = client.get("/api/tools") + assert response.status_code == 200 + body = response.json() + # gamma is disabled in the fixture and must not appear + assert set(body.keys()) == {"alpha", "beta"} + assert body["alpha"][0]["name"] == "alpha_tool1" + + def test_returns_error_on_probe_failure( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + async def boom(raw_servers, timeout: float = 30): + raise RuntimeError("probe exploded") + + monkeypatch.setattr(api_mod, "probe_all_servers", boom) + response = client.get("/api/tools") + assert response.status_code == 500 + assert "probe exploded" in response.json()["error"] + + +# ── /api/config ────────────────────────────────────────────────────────────── + + +class TestConfigEndpoint: + def test_get_returns_raw_servers_json( + self, client: TestClient, servers_json: Path + ) -> None: + response = client.get("/api/config") + assert response.status_code == 200 + assert response.json() == json.loads(servers_json.read_text()) + + def test_put_overwrites_config( + self, client: TestClient, servers_json: Path + ) -> None: + new = {"mcpServers": {"solo": {"url": "http://solo"}}} + response = client.put("/api/config", json=new) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + assert json.loads(servers_json.read_text()) == new + + def test_path_traversal_is_rejected( + self, client: TestClient, tmp_path: Path + ) -> None: + outside = tmp_path / "evil.json" + outside.write_text("{}") + response = client.get(f"/api/config?path={outside}") + assert response.status_code == 400 + + def test_explicit_path_inside_data_dir_is_accepted( + self, client: TestClient, data_dir: Path + ) -> None: + alt = data_dir / "alt.json" + alt.write_text(json.dumps({"mcpServers": {"x": {"url": "http://x"}}})) + response = client.get(f"/api/config?path={alt}") + assert response.status_code == 200 + assert response.json() == {"mcpServers": {"x": {"url": "http://x"}}} + + def test_non_json_file_is_rejected( + self, client: TestClient, data_dir: Path + ) -> None: + alt = data_dir / "alt.txt" + alt.write_text("not json") + response = client.get(f"/api/config?path={alt}") + assert response.status_code == 400 + + +# ── /api/servers/{name}/toggle ─────────────────────────────────────────────── + + +class TestToggleServer: + def test_disable_server_writes_enabled_false( + self, client: TestClient, servers_json: Path + ) -> None: + response = client.post( + "/api/servers/alpha/toggle", json={"enabled": False} + ) + assert response.status_code == 200 + data = json.loads(servers_json.read_text()) + assert data["mcpServers"]["alpha"]["enabled"] is False + + def test_enable_server_removes_enabled_key( + self, client: TestClient, servers_json: Path + ) -> None: + # ``gamma`` starts with enabled=false + response = client.post( + "/api/servers/gamma/toggle", json={"enabled": True} + ) + assert response.status_code == 200 + data = json.loads(servers_json.read_text()) + assert "enabled" not in data["mcpServers"]["gamma"] + + def test_unknown_server_404(self, client: TestClient) -> None: + response = client.post( + "/api/servers/ghost/toggle", json={"enabled": False} + ) + assert response.status_code == 404 + assert "not found" in response.json()["error"] + + +# ── /api/tools/toggle ──────────────────────────────────────────────────────── + + +class TestToggleTool: + def test_disable_tool_appends_to_disabled_list( + self, client: TestClient, servers_json: Path + ) -> None: + response = client.post( + "/api/tools/toggle", + json={"server": "alpha", "tool": "destructive", "enabled": False}, + ) + assert response.status_code == 200 + data = json.loads(servers_json.read_text()) + assert data["mcpServers"]["alpha"]["disabledTools"] == ["destructive"] + + def test_enable_tool_removes_from_disabled_list( + self, client: TestClient, servers_json: Path + ) -> None: + # beta starts with disabledTools=["noisy"] + response = client.post( + "/api/tools/toggle", + json={"server": "beta", "tool": "noisy", "enabled": True}, + ) + assert response.status_code == 200 + data = json.loads(servers_json.read_text()) + # key must be removed entirely when list becomes empty + assert "disabledTools" not in data["mcpServers"]["beta"] + + def test_enable_tool_not_in_list_is_noop( + self, client: TestClient, servers_json: Path + ) -> None: + before = json.loads(servers_json.read_text()) + response = client.post( + "/api/tools/toggle", + json={"server": "alpha", "tool": "never", "enabled": True}, + ) + assert response.status_code == 200 + assert json.loads(servers_json.read_text()) == before + + def test_disable_same_tool_twice_is_idempotent( + self, client: TestClient, servers_json: Path + ) -> None: + payload = {"server": "alpha", "tool": "dupe", "enabled": False} + client.post("/api/tools/toggle", json=payload) + client.post("/api/tools/toggle", json=payload) + data = json.loads(servers_json.read_text()) + assert data["mcpServers"]["alpha"]["disabledTools"] == ["dupe"] + + def test_unknown_server_404(self, client: TestClient) -> None: + response = client.post( + "/api/tools/toggle", + json={"server": "ghost", "tool": "t", "enabled": False}, + ) + assert response.status_code == 404 + + +# ── /api/presets ───────────────────────────────────────────────────────────── + + +class TestPresetsEndpoints: + def test_list_initially_empty(self, client: TestClient, data_dir: Path) -> None: + response = client.get("/api/presets") + assert response.status_code == 200 + body = response.json() + assert body["presets"] == [] + assert body["activePresetID"] is None + assert body["activeConfigPath"] == str(data_dir / "servers.json") + + def test_create_preset_returns_201( + self, client: TestClient, data_dir: Path + ) -> None: + preset_file = data_dir / "work.json" + preset_file.write_text('{"mcpServers": {}}') + response = client.post( + "/api/presets", + json={"name": "work", "filePath": str(preset_file)}, + ) + assert response.status_code == 201 + preset = response.json()["preset"] + assert preset["name"] == "work" + assert preset["id"] # non-empty uuid + # and appears in the listing + listing = client.get("/api/presets").json() + assert any(p["id"] == preset["id"] for p in listing["presets"]) + + def test_create_preset_missing_fields_returns_400( + self, client: TestClient + ) -> None: + response = client.post("/api/presets", json={"name": "only-name"}) + assert response.status_code == 400 + + def test_update_preset(self, client: TestClient, data_dir: Path) -> None: + f1 = data_dir / "a.json" + f1.write_text('{"mcpServers": {}}') + created = client.post( + "/api/presets", json={"name": "a", "filePath": str(f1)} + ).json()["preset"] + response = client.patch( + f"/api/presets/{created['id']}", json={"name": "renamed"} + ) + assert response.status_code == 200 + assert response.json()["preset"]["name"] == "renamed" + + def test_update_unknown_preset_404(self, client: TestClient) -> None: + response = client.patch("/api/presets/does-not-exist", json={"name": "x"}) + assert response.status_code == 404 + + def test_delete_preset(self, client: TestClient, data_dir: Path) -> None: + f = data_dir / "d.json" + f.write_text('{"mcpServers": {}}') + created = client.post( + "/api/presets", json={"name": "d", "filePath": str(f)} + ).json()["preset"] + response = client.delete(f"/api/presets/{created['id']}") + assert response.status_code == 200 + listing = client.get("/api/presets").json() + assert all(p["id"] != created["id"] for p in listing["presets"]) + + def test_delete_unknown_preset_404(self, client: TestClient) -> None: + response = client.delete("/api/presets/does-not-exist") + assert response.status_code == 404 + + def test_delete_active_preset_clears_active( + self, client: TestClient, data_dir: Path + ) -> None: + f = data_dir / "d.json" + f.write_text('{"mcpServers": {}}') + created = client.post( + "/api/presets", json={"name": "d", "filePath": str(f)} + ).json()["preset"] + client.post(f"/api/presets/{created['id']}/activate") + client.delete(f"/api/presets/{created['id']}") + listing = client.get("/api/presets").json() + assert listing["activePresetID"] is None + + def test_activate_preset(self, client: TestClient, data_dir: Path) -> None: + f = data_dir / "x.json" + f.write_text('{"mcpServers": {}}') + created = client.post( + "/api/presets", json={"name": "x", "filePath": str(f)} + ).json()["preset"] + response = client.post(f"/api/presets/{created['id']}/activate") + assert response.status_code == 200 + assert response.json()["activePresetID"] == created["id"] + assert client.get("/api/presets").json()["activePresetID"] == created["id"] + + def test_activate_unknown_preset_404(self, client: TestClient) -> None: + response = client.post("/api/presets/missing/activate") + assert response.status_code == 404 + + def test_activate_default_clears_active( + self, client: TestClient, data_dir: Path + ) -> None: + f = data_dir / "x.json" + f.write_text('{"mcpServers": {}}') + created = client.post( + "/api/presets", json={"name": "x", "filePath": str(f)} + ).json()["preset"] + client.post(f"/api/presets/{created['id']}/activate") + response = client.post("/api/presets/default/activate") + assert response.status_code == 200 + assert response.json()["activePresetID"] is None + + +# ── Config round-trip across endpoints ─────────────────────────────────────── + + +class TestConfigRoundTrip: + """Sanity check: server/tool toggles leave a valid readable config.""" + + def test_toggle_then_get_returns_updated_config( + self, client: TestClient + ) -> None: + client.post("/api/servers/alpha/toggle", json={"enabled": False}) + client.post( + "/api/tools/toggle", + json={"server": "alpha", "tool": "bad", "enabled": False}, + ) + body = client.get("/api/config").json() + assert body["mcpServers"]["alpha"]["enabled"] is False + assert body["mcpServers"]["alpha"]["disabledTools"] == ["bad"] diff --git a/tests/integration/test_api_errors.py b/tests/integration/test_api_errors.py new file mode 100644 index 0000000..ab51f0e --- /dev/null +++ b/tests/integration/test_api_errors.py @@ -0,0 +1,161 @@ +"""Error-path integration tests for the REST API. + +These specifically exercise the ``except Exception`` branches inside each +handler so the 500 responses don't rot. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from unittest.mock import patch + +import pytest +from starlette.testclient import TestClient + +from jarvis.api import atomic_write, create_api_app + + +@pytest.fixture +def client(data_dir: Path, servers_json: Path) -> TestClient: + app = create_api_app(mcp_port=7070) + with TestClient(app) as client: + yield client + + +class TestConfigGetErrors: + def test_500_when_file_unreadable( + self, client: TestClient, servers_json: Path + ) -> None: + # Write a corrupt JSON file — the GET handler catches the decode error + servers_json.write_text("{ not json") + response = client.get("/api/config") + assert response.status_code == 500 + assert "error" in response.json() + + +class TestConfigPutErrors: + def test_500_when_atomic_write_fails( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + import jarvis.api as api_mod + + def boom(path: Path, data: dict) -> None: + raise OSError("disk full") + + monkeypatch.setattr(api_mod, "atomic_write", boom) + response = client.put("/api/config", json={"mcpServers": {}}) + assert response.status_code == 500 + assert "disk full" in response.json()["error"] + + +class TestToggleServerErrors: + def test_500_when_config_unreadable( + self, client: TestClient, servers_json: Path + ) -> None: + # corrupt JSON → json.loads raises inside the handler + servers_json.write_text("{ not json") + response = client.post( + "/api/servers/alpha/toggle", json={"enabled": False} + ) + assert response.status_code == 500 + + def test_500_when_body_missing(self, client: TestClient) -> None: + # No JSON body at all → ``request.json()`` raises → caught, 500 + response = client.post("/api/servers/alpha/toggle") + assert response.status_code == 500 + + +class TestToggleToolErrors: + def test_500_when_body_missing_keys(self, client: TestClient) -> None: + # Missing "server"/"tool" → KeyError → caught, 500 + response = client.post("/api/tools/toggle", json={}) + assert response.status_code == 500 + + +class TestPresetErrors: + def test_update_preset_with_bad_body_400( + self, client: TestClient, data_dir: Path + ) -> None: + f = data_dir / "p.json" + f.write_text("{}") + created = client.post( + "/api/presets", json={"name": "p", "filePath": str(f)} + ).json()["preset"] + # send invalid JSON body → request.json() raises + response = client.patch( + f"/api/presets/{created['id']}", + content=b"{ not json", + headers={"content-type": "application/json"}, + ) + assert response.status_code == 400 + + +class TestResolveConfigPathTraversal: + def test_unresolvable_path_hits_except_branch( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + """The try/except around ``Path.resolve()`` must catch nastiness.""" + import jarvis.api as api_mod + + original = Path + + class ExplodingPath(type(Path())): + def resolve(self, *a, **kw): # type: ignore[override] + raise RuntimeError("boom") + + def fake_path(value): + if isinstance(value, str) and value == "EXPLODE": + return ExplodingPath(".") + return original(value) + + monkeypatch.setattr(api_mod, "Path", fake_path) + response = client.get("/api/config?path=EXPLODE") + assert response.status_code == 400 + + +# ── atomic_write direct error-path tests ───────────────────────────────────── + + +class TestAtomicWriteErrors: + def test_cleans_up_tmp_when_write_fails(self, tmp_path: Path) -> None: + """Force ``os.fdopen`` to fail; the ``.tmp`` file must be removed.""" + target = tmp_path / "out.json" + real_fdopen = os.fdopen + calls = {"n": 0} + + def failing_fdopen(fd, *args, **kwargs): + calls["n"] += 1 + f = real_fdopen(fd, *args, **kwargs) + f.close() # close the fd properly before raising + raise OSError("write failed") + + with patch("jarvis.api.os.fdopen", side_effect=failing_fdopen): + with pytest.raises(OSError, match="write failed"): + atomic_write(target, {"a": 1}) + + assert calls["n"] == 1 + assert not target.exists() + assert list(tmp_path.glob("*.tmp")) == [] + + def test_survives_unlink_failure_during_cleanup( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """If the tmp-file cleanup itself fails with OSError, the original + exception must still propagate.""" + target = tmp_path / "out.json" + + import jarvis.api as api_mod + + def failing_replace(src, dst): + raise RuntimeError("replace failed") + + def failing_unlink(path): + raise OSError("unlink failed") + + monkeypatch.setattr(api_mod.os, "replace", failing_replace) + monkeypatch.setattr(api_mod.os, "unlink", failing_unlink) + + with pytest.raises(RuntimeError, match="replace failed"): + atomic_write(target, {"a": 1}) diff --git a/tests/integration/test_tui_auth_manager.py b/tests/integration/test_tui_auth_manager.py new file mode 100644 index 0000000..84f8502 --- /dev/null +++ b/tests/integration/test_tui_auth_manager.py @@ -0,0 +1,276 @@ +"""Integration tests for ``AuthManagerApp`` (the ``jarvis auth`` TUI). + +Uses Textual's ``run_test`` harness with a fake in-memory token store so +that no files in ``~/.jarvis/cache.db`` are touched. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from jarvis import config as config_mod +from jarvis import probe as probe_mod +from jarvis.tui import AuthManagerApp + + +class FakeCache: + def __init__(self, keys: list[str] | None = None) -> None: + self._keys = list(keys or []) + self.cleared = False + + def iterkeys(self): + return iter(self._keys) + + def clear(self) -> None: + self.cleared = True + self._keys.clear() + + +class FakeTokenStorage: + def __init__(self, keys: list[str] | None = None) -> None: + self._cache = FakeCache(keys) + + +@pytest.fixture +def fake_store(monkeypatch: pytest.MonkeyPatch) -> FakeTokenStorage: + store = FakeTokenStorage( + keys=["https://atlassian.example.com/mcp|token|abc"] + ) + monkeypatch.setattr(config_mod, "token_storage", store) + return store + + +@pytest.fixture +def auth_config(data_dir: Path) -> Path: + path = data_dir / "servers.json" + path.write_text( + json.dumps( + { + "mcpServers": { + "atlassian": { + "url": "https://atlassian.example.com/mcp", + "transport": "http", + "auth": "oauth", + }, + "github": { + "url": "https://github.example.com/mcp", + "transport": "http", + "auth": "oauth", + }, + "local": {"command": "echo", "args": ["hi"]}, + } + }, + indent=2, + ) + ) + return path + + +class TestAuthManagerPopulate: + async def test_lists_all_servers( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + assert set(app._server_names) == {"atlassian", "github", "local"} + + async def test_shows_token_count_for_oauth_server( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + from textual.widgets import DataTable + + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + table = app.query_one(DataTable) + # collect all cell strings from each row + rows = [] + for row_key in table.rows: + rows.append( + [str(c) for c in table.get_row(row_key)] + ) + flat = [cell for row in rows for cell in row] + # atlassian's url is in the fake store, so 1 token cached + assert any("1 token" in c for c in flat) + # github has no matching token → "none cached" + assert any("none cached" in c for c in flat) + # local is non-oauth → "N/A" + assert any("N/A" in c for c in flat) + + async def test_auth_type_uppercased( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + from textual.widgets import DataTable + + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + table = app.query_one(DataTable) + rows = [ + [str(c) for c in table.get_row(row_key)] + for row_key in table.rows + ] + oauth_rows = [row for row in rows if "OAUTH" in row] + assert len(oauth_rows) == 2 + + +class TestAuthManagerLogout: + async def test_logout_clears_tokens_and_refreshes( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + """On success the store is cleared and the table is re-populated + from scratch — note that the ``✓ All OAuth tokens cleared`` status + is immediately overwritten by ``_populate_table``'s hint line, which + is a pre-existing quirk of the TUI.""" + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + app.action_logout() + await pilot.pause(0.05) + assert fake_store._cache.cleared + # server list was rebuilt by _populate_table → same names, + # fresh ordering + assert set(app._server_names) == {"atlassian", "github", "local"} + + async def test_logout_reports_error( + self, + auth_config: Path, + fake_store: FakeTokenStorage, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + def boom() -> None: + raise RuntimeError("kaboom") + + monkeypatch.setattr(config_mod, "clear_tokens", boom) + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + app.action_logout() + await pilot.pause(0.05) + status = str(app.query_one("#status").render()) # type: ignore[union-attr] + assert "failed" in status.lower() + assert "kaboom" in status + + +class TestAuthManagerLogin: + async def test_login_for_non_oauth_server_is_noop( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + # move cursor to the ``local`` row (non-oauth) + from textual.widgets import DataTable + + table = app.query_one(DataTable) + local_idx = app._server_names.index("local") + table.move_cursor(row=local_idx) + await pilot.pause(0.02) + + await app.action_login() + status = str(app.query_one("#status").render()) # type: ignore[union-attr] + assert "does not use oauth" in status.lower() + + async def test_login_success_refreshes_table( + self, + auth_config: Path, + fake_store: FakeTokenStorage, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """On success the worker calls ``probe_server`` with the selected + server's raw config and refreshes the table.""" + probe_calls: list[tuple[str, dict]] = [] + + async def fake_probe(name: str, raw: dict) -> list[dict]: + probe_calls.append((name, raw)) + return [{"name": "t1", "description": ""}, {"name": "t2", "description": ""}] + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + from textual.widgets import DataTable + + table = app.query_one(DataTable) + atl_idx = app._server_names.index("atlassian") + table.move_cursor(row=atl_idx) + await pilot.pause(0.02) + + await app.action_login() + # wait for the background worker to finish + for _ in range(40): + if probe_calls: + break + await pilot.pause(0.05) + # extra tick to let the post-probe table refresh complete + await pilot.pause(0.05) + + assert len(probe_calls) == 1 + assert probe_calls[0][0] == "atlassian" + assert probe_calls[0][1]["auth"] == "oauth" + # table was rebuilt after success + assert set(app._server_names) == {"atlassian", "github", "local"} + + async def test_login_failure_reports_error( + self, + auth_config: Path, + fake_store: FakeTokenStorage, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + async def fake_probe(name: str, raw: dict) -> list[dict]: + raise RuntimeError("auth denied") + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + from textual.widgets import DataTable + + table = app.query_one(DataTable) + atl_idx = app._server_names.index("atlassian") + table.move_cursor(row=atl_idx) + await pilot.pause(0.02) + + await app.action_login() + for _ in range(40): + status = str(app.query_one("#status").render()) # type: ignore[union-attr] + if "Auth failed" in status: + break + await pilot.pause(0.05) + status = str(app.query_one("#status").render()) # type: ignore[union-attr] + assert "Auth failed" in status + assert "auth denied" in status + + async def test_login_without_selection_is_noop( + self, + auth_config: Path, + fake_store: FakeTokenStorage, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + # empty config → no rows → ``_selected_server`` returns None + empty = auth_config.parent / "empty.json" + empty.write_text(json.dumps({"mcpServers": {}})) + app = AuthManagerApp(empty) + async with app.run_test() as pilot: + await pilot.pause(0.05) + # should just return, no crash + await app.action_login() + + +class TestAuthManagerQuit: + async def test_quit_exits_cleanly( + self, auth_config: Path, fake_store: FakeTokenStorage + ) -> None: + app = AuthManagerApp(auth_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + app.action_quit() + await pilot.pause(0.05) + assert not app.is_running diff --git a/tests/integration/test_tui_mcp_manager.py b/tests/integration/test_tui_mcp_manager.py new file mode 100644 index 0000000..9f00a67 --- /dev/null +++ b/tests/integration/test_tui_mcp_manager.py @@ -0,0 +1,443 @@ +"""Integration tests for ``MCPManagerApp`` (the ``jarvis mcp`` TUI). + +Drives the Textual app via ``run_test`` / ``Pilot``. ``probe_server`` is +stubbed so the app never opens sockets. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from jarvis import probe as probe_mod +from jarvis.tui import MCPManagerApp + + +async def _await_probe(pilot, attempts: int = 40) -> None: + """Wait until the background worker has replaced every ``probing…`` + placeholder with actual tool nodes. + + Raises ``AssertionError`` (with the list of servers that are still + probing) if the worker hasn't finished within ``attempts`` ticks, so + tests fail fast and loudly instead of silently passing with a half- + populated tree. + """ + app = pilot.app + tree = app.query_one("Tree") + still_probing: list[str] = [] + for _ in range(attempts): + still_probing = [] + for node in tree.root.children: + d = node.data + if d and d.get("type") == "server" and d.get("enabled", True): + labels = [str(c.label) for c in node.children] + if any("probing" in lbl for lbl in labels): + still_probing.append(d["name"]) + if not still_probing: + return + await pilot.pause(0.05) + raise AssertionError( + f"_await_probe timed out after {attempts} attempts " + f"(~{attempts * 0.05:.2f}s); still probing: {still_probing}" + ) + + +@pytest.fixture +def stub_probe(monkeypatch: pytest.MonkeyPatch): + """Return deterministic probe results for any server name.""" + + async def fake_probe(name: str, raw: dict) -> list[dict]: + return [ + {"name": "tool_a", "description": "first tool"}, + {"name": "tool_b", "description": "second tool"}, + ] + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + # The TUI imports ``probe_server`` lazily via + # ``from jarvis.probe import probe_server`` inside the worker, so + # patching the module attribute is sufficient. + return fake_probe + + +@pytest.fixture +def mcp_config(data_dir: Path) -> Path: + path = data_dir / "servers.json" + path.write_text( + json.dumps( + { + "mcpServers": { + "alpha": {"url": "http://alpha", "transport": "http"}, + "beta": { + "url": "http://beta", + "transport": "http", + "disabledTools": ["tool_b"], + }, + "gamma": { + "url": "http://gamma", + "transport": "http", + "enabled": False, + }, + } + }, + indent=2, + ) + ) + return path + + +class TestMCPManagerLoad: + async def test_populates_tree_from_config( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + tree = app.query_one("Tree") + server_names = { + n.data["name"] + for n in tree.root.children + if n.data and n.data.get("type") == "server" + } + assert server_names == {"alpha", "beta", "gamma"} + + async def test_initial_enabled_state_reflects_config( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + tree = app.query_one("Tree") + enabled_map = { + n.data["name"]: n.data["enabled"] + for n in tree.root.children + if n.data and n.data.get("type") == "server" + } + assert enabled_map == {"alpha": True, "beta": True, "gamma": False} + + async def test_probing_replaces_placeholder_with_tools( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + tree = app.query_one("Tree") + alpha_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "alpha" + ) + tool_labels = [str(c.label) for c in alpha_node.children] + assert any("tool_a" in lbl for lbl in tool_labels) + assert any("tool_b" in lbl for lbl in tool_labels) + + async def test_disabled_server_shows_hint( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + tree = app.query_one("Tree") + gamma_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "gamma" + ) + labels = [str(c.label) for c in gamma_node.children] + assert any("server disabled" in lbl for lbl in labels) + + +class TestMCPManagerToggleAndSave: + async def test_toggle_server_disables_and_saves( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + # move cursor to the alpha node (first child of root) + alpha_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "alpha" + ) + tree.select_node(alpha_node) + await pilot.pause(0.02) + + app.action_toggle_item() + await pilot.pause(0.02) + assert alpha_node.data["enabled"] is False + + app.action_quit_save() + await pilot.pause(0.05) + + saved = json.loads(mcp_config.read_text()) + assert saved["mcpServers"]["alpha"]["enabled"] is False + + async def test_quit_save_without_changes_preserves_config( + self, mcp_config: Path, stub_probe + ) -> None: + before = json.loads(mcp_config.read_text()) + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + app.action_quit_save() + await pilot.pause(0.05) + + after = json.loads(mcp_config.read_text()) + # enabled-true servers should still not carry an explicit "enabled" + assert after["mcpServers"]["alpha"] == before["mcpServers"]["alpha"] + + async def test_toggle_tool_updates_disabled_set( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + alpha_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "alpha" + ) + # first tool under alpha → tool_a (enabled) + tool_node = alpha_node.children[0] + tree.select_node(tool_node) + await pilot.pause(0.02) + + app.action_toggle_item() + await pilot.pause(0.02) + assert tool_node.data["enabled"] is False + assert "tool_a" in app._disabled_tools_cache["alpha"] + + app.action_quit_save() + await pilot.pause(0.05) + + saved = json.loads(mcp_config.read_text()) + assert saved["mcpServers"]["alpha"]["disabledTools"] == ["tool_a"] + + async def test_cannot_toggle_tool_under_disabled_server( + self, mcp_config: Path, stub_probe + ) -> None: + """beta starts enabled; disable it then try to toggle one of its tools.""" + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + beta_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "beta" + ) + # disable the server + tree.select_node(beta_node) + await pilot.pause(0.02) + app.action_toggle_item() # beta now disabled + await pilot.pause(0.02) + + # cursor still on beta; attempt to toggle a tool — nothing should + # happen (and the status bar should warn). The tools are still + # children from the earlier probe. + if beta_node.children: + tool_node = beta_node.children[0] + before = tool_node.data["enabled"] + tree.select_node(tool_node) + await pilot.pause(0.02) + app.action_toggle_item() + await pilot.pause(0.02) + # state must be unchanged + assert tool_node.data["enabled"] == before + + async def test_enabling_previously_disabled_server_removes_key( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await pilot.pause(0.05) + tree = app.query_one("Tree") + gamma_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "gamma" + ) + tree.select_node(gamma_node) + await pilot.pause(0.02) + app.action_toggle_item() + await pilot.pause(0.02) + app.action_quit_save() + await pilot.pause(0.05) + + saved = json.loads(mcp_config.read_text()) + assert "enabled" not in saved["mcpServers"]["gamma"] + + +class TestMCPManagerRefresh: + async def test_refresh_resets_probed_tools( + self, mcp_config: Path, stub_probe + ) -> None: + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + tree = app.query_one("Tree") + alpha_node = next( + n + for n in tree.root.children + if n.data and n.data.get("name") == "alpha" + ) + assert alpha_node.data["probed_tools"] + + app.action_refresh() + # action_refresh immediately wipes probed_tools and re-adds a + # "probing…" placeholder; the worker then re-populates. + assert alpha_node.data["probed_tools"] == [] + await _await_probe(pilot) + assert alpha_node.data["probed_tools"] + + +class TestMCPManagerEdgeCases: + async def test_toggle_with_no_cursor_is_noop( + self, data_dir: Path, stub_probe + ) -> None: + """An empty tree has no cursor node — action_toggle_item must early- + return instead of crashing.""" + empty = data_dir / "empty.json" + empty.write_text(json.dumps({"mcpServers": {}})) + app = MCPManagerApp(empty) + async with app.run_test() as pilot: + await pilot.pause(0.05) + # no cursor node at all + app.action_toggle_item() + + async def test_tool_toggle_blocked_when_parent_disabled_sets_status( + self, mcp_config: Path, stub_probe + ) -> None: + """Select a tool under a server, then disable the server, then select + the tool again and attempt to toggle — the status bar should warn.""" + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + alpha = next( + n for n in tree.root.children if n.data and n.data.get("name") == "alpha" + ) + tool = alpha.children[0] + # disable alpha directly via its node data — simulates the user + # having disabled it while the cursor sat on the tool + alpha.data["enabled"] = False + + tree.select_node(tool) + await pilot.pause(0.02) + + before = tool.data["enabled"] + app.action_toggle_item() + await pilot.pause(0.02) + + assert tool.data["enabled"] == before # unchanged + status = str(app.query_one("#status").render()) + assert "Enable the server first" in status + + async def test_toggle_tool_off_then_on_uses_discard_path( + self, mcp_config: Path, stub_probe + ) -> None: + """Hits line 254 — ``disabled.discard(tool_name)`` when re-enabling + a tool that *is* in the disabled set.""" + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + alpha = next( + n for n in tree.root.children if n.data and n.data.get("name") == "alpha" + ) + tool = alpha.children[0] + tree.select_node(tool) + await pilot.pause(0.02) + app.action_toggle_item() # disable + await pilot.pause(0.02) + assert tool.data["name"] in app._disabled_tools_cache["alpha"] + + app.action_toggle_item() # re-enable + await pilot.pause(0.02) + assert tool.data["name"] not in app._disabled_tools_cache["alpha"] + + async def test_save_skips_server_nodes_missing_from_config( + self, mcp_config: Path, stub_probe + ) -> None: + """Hits line 101 — when a tree node references a server name that's + been removed from ``raw_config`` under our feet, ``_save_config`` + must skip it without crashing.""" + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + # yank alpha out of the in-memory config + del app.raw_config["mcpServers"]["alpha"] + app.action_quit_save() + await pilot.pause(0.05) + + saved = json.loads(mcp_config.read_text()) + assert "alpha" not in saved["mcpServers"] + assert "beta" in saved["mcpServers"] + + async def test_probe_failure_yields_empty_tool_list( + self, + mcp_config: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Hits lines 209-214 — an exception in ``probe_server`` inside the + background worker must downgrade to an empty tool list for that + server without killing the app.""" + + async def angry_probe(name: str, raw: dict) -> list[dict]: + if name == "alpha": + raise RuntimeError("nope") + return [{"name": "only_beta_tool", "description": ""}] + + monkeypatch.setattr(probe_mod, "probe_server", angry_probe) + + app = MCPManagerApp(mcp_config) + async with app.run_test() as pilot: + await _await_probe(pilot) + + tree = app.query_one("Tree") + alpha = next( + n for n in tree.root.children if n.data and n.data.get("name") == "alpha" + ) + # no tool children on alpha — probe failed + assert list(alpha.children) == [] + # but beta got its tool + beta = next( + n for n in tree.root.children if n.data and n.data.get("name") == "beta" + ) + beta_labels = [str(c.label) for c in beta.children] + assert any("only_beta_tool" in lbl for lbl in beta_labels) + + +class TestLoadConfigInMCPManager: + async def test_parse_error_does_not_crash_app( + self, data_dir: Path, stub_probe + ) -> None: + """A malformed config must not crash ``MCPManagerApp`` on mount. + + The parse-error status set by ``on_mount`` is immediately overwritten + by ``_populate_tree``'s "No servers configured." and then by + ``_probe_all``'s "No enabled servers." (a known pre-existing quirk + of the TUI), so the final observable status is the latter. + """ + bad = data_dir / "bad.json" + bad.write_text("{ not json") + app = MCPManagerApp(bad) + async with app.run_test() as pilot: + await pilot.pause(0.1) + status = str(app.query_one("#status").render()) # type: ignore[union-attr] + # the final status must be exactly one of the two terminal + # messages emitted by _populate_tree / _probe_all + assert status in ( + "No servers configured.", + "No enabled servers.", + ), f"unexpected status: {status!r}" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_api_helpers.py b/tests/unit/test_api_helpers.py new file mode 100644 index 0000000..b226541 --- /dev/null +++ b/tests/unit/test_api_helpers.py @@ -0,0 +1,113 @@ +"""Unit tests for low-level helpers in ``jarvis.api``.""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path + +import pytest + +from jarvis.api import atomic_write, config_locks, get_lock + + +class TestAtomicWrite: + def test_writes_json_with_indent(self, tmp_path: Path) -> None: + target = tmp_path / "out.json" + atomic_write(target, {"a": 1, "b": [2, 3]}) + content = target.read_text() + assert json.loads(content) == {"a": 1, "b": [2, 3]} + # indent=2 ⇒ multi-line output + assert "\n" in content + + def test_overwrites_existing_file(self, tmp_path: Path) -> None: + target = tmp_path / "out.json" + target.write_text('{"old": true}') + atomic_write(target, {"new": True}) + assert json.loads(target.read_text()) == {"new": True} + + def test_cleans_up_temp_file_on_serialization_error( + self, tmp_path: Path + ) -> None: + target = tmp_path / "out.json" + # sets are not JSON-serialisable — json.dumps raises before os.replace + with pytest.raises(TypeError): + atomic_write(target, {"bad": {1, 2, 3}}) + # the original file should not exist … + assert not target.exists() + # … and no orphan ``.tmp`` files should be left behind + leftovers = list(tmp_path.glob("*.tmp")) + assert leftovers == [] + + +class TestGetLock: + def setup_method(self) -> None: + config_locks.clear() + + def test_returns_same_lock_for_same_path(self, tmp_path: Path) -> None: + p = tmp_path / "c.json" + assert get_lock(p) is get_lock(p) + + def test_different_paths_get_different_locks(self, tmp_path: Path) -> None: + a = tmp_path / "a.json" + b = tmp_path / "b.json" + assert get_lock(a) is not get_lock(b) + + def test_resolves_path_before_keying(self, tmp_path: Path) -> None: + # ``foo/./c.json`` resolves to the same path as ``foo/c.json`` + p1 = tmp_path / "c.json" + p2 = tmp_path / "." / "c.json" + assert get_lock(p1) is get_lock(p2) + + def test_lock_is_asyncio_lock(self, tmp_path: Path) -> None: + assert isinstance(get_lock(tmp_path / "x.json"), asyncio.Lock) + + +class TestStartApiThread: + def test_starts_uvicorn_in_daemon_thread( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """``start_api_thread`` must spawn uvicorn on the correct port as a + daemon thread and not block the caller. + """ + import sys + import threading + import types + + captured: dict = {} + started = threading.Event() + + def fake_run(**kwargs) -> None: + captured.update(kwargs) + # Record the thread uvicorn.run was actually invoked on — the + # point of ``start_api_thread`` is that this is *not* the main + # thread and that the thread is a daemon. + current = threading.current_thread() + captured["thread"] = current + captured["is_main_thread"] = current is threading.main_thread() + captured["is_daemon"] = current.daemon + started.set() + + fake_uvicorn = types.SimpleNamespace(run=fake_run) + monkeypatch.setitem(sys.modules, "uvicorn", fake_uvicorn) + + from jarvis.api import start_api_thread + + start_api_thread(mcp_port=7070, api_port=7071) + assert started.wait(timeout=2), "uvicorn.run never called" + + # kwargs passed to uvicorn.run + assert captured["host"] == "127.0.0.1" + assert captured["port"] == 7071 + assert captured["log_level"] == "error" + assert captured["app"] is not None + + # thread identity + daemon flag + assert captured["is_main_thread"] is False, ( + "uvicorn.run was invoked on the main thread — start_api_thread " + "should have spawned a background thread" + ) + assert captured["is_daemon"] is True, ( + "background thread must be a daemon so it doesn't outlive the " + "main process" + ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..6249f1e --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,277 @@ +"""Unit tests for ``jarvis.config``.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from jarvis import config as config_mod +from jarvis.config import ( + active_config_from_presets, + configure_servers, + expand_env_vars, + get_disabled_tools, + load_presets, + load_raw_config, + save_presets, +) + + +# ── expand_env_vars ─────────────────────────────────────────────────────────── + + +class TestExpandEnvVars: + def test_substitutes_known_variable(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MY_KEY", "secret") + assert expand_env_vars("token=${MY_KEY}") == "token=secret" + + def test_leaves_unknown_variable_untouched( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("DEFINITELY_NOT_SET", raising=False) + assert expand_env_vars("x=${DEFINITELY_NOT_SET}") == "x=${DEFINITELY_NOT_SET}" + + def test_substitutes_multiple_variables( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("A", "1") + monkeypatch.setenv("B", "2") + assert expand_env_vars("${A}-${B}") == "1-2" + + def test_no_placeholders_is_pass_through(self) -> None: + assert expand_env_vars("plain value") == "plain value" + + +# ── load_raw_config ─────────────────────────────────────────────────────────── + + +class TestLoadRawConfig: + def test_filters_disabled_servers(self, servers_json: Path) -> None: + mcp_dict, raw_servers = load_raw_config(servers_json) + # ``gamma`` has ``"enabled": false`` in the fixture + assert set(mcp_dict["mcpServers"].keys()) == {"alpha", "beta"} + assert set(raw_servers.keys()) == {"alpha", "beta"} + + def test_strips_non_standard_keys(self, servers_json: Path) -> None: + mcp_dict, _ = load_raw_config(servers_json) + for srv in mcp_dict["mcpServers"].values(): + assert "enabled" not in srv + assert "disabledTools" not in srv + + def test_preserves_standard_keys(self, servers_json: Path) -> None: + mcp_dict, _ = load_raw_config(servers_json) + assert mcp_dict["mcpServers"]["alpha"]["url"] == "https://alpha.example.com/mcp" + assert mcp_dict["mcpServers"]["beta"]["command"] == "echo" + assert mcp_dict["mcpServers"]["beta"]["args"] == ["hello"] + + def test_enabled_true_is_kept(self, data_dir: Path) -> None: + path = data_dir / "s.json" + path.write_text( + json.dumps( + {"mcpServers": {"x": {"url": "http://x", "enabled": True}}} + ) + ) + mcp_dict, _raw = load_raw_config(path) + assert "x" in mcp_dict["mcpServers"] + assert "enabled" not in mcp_dict["mcpServers"]["x"] + + +# ── get_disabled_tools ──────────────────────────────────────────────────────── + + +class TestGetDisabledTools: + def test_returns_prefixed_names(self, servers_json: Path) -> None: + assert get_disabled_tools(servers_json) == {"beta_noisy"} + + def test_skips_disabled_servers(self, data_dir: Path) -> None: + path = data_dir / "s.json" + path.write_text( + json.dumps( + { + "mcpServers": { + "off": { + "url": "http://off", + "enabled": False, + "disabledTools": ["a", "b"], + } + } + } + ) + ) + assert get_disabled_tools(path) == set() + + def test_empty_when_no_disabled_tools(self, data_dir: Path) -> None: + path = data_dir / "s.json" + path.write_text( + json.dumps({"mcpServers": {"x": {"url": "http://x"}}}) + ) + assert get_disabled_tools(path) == set() + + +# ── Preset management ──────────────────────────────────────────────────────── + + +class TestPresets: + def test_load_presets_returns_empty_when_missing(self, data_dir: Path) -> None: + assert load_presets() == {"presets": [], "activePresetID": None} + + def test_save_and_load_roundtrip(self, data_dir: Path) -> None: + payload = { + "presets": [ + { + "id": "1", + "name": "work", + "filePath": str(data_dir / "a.json"), + } + ], + "activePresetID": "1", + } + save_presets(payload) + assert load_presets() == payload + + def test_save_creates_data_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + nested = tmp_path / "nested" / "jarvis" + # do *not* pre-create the dir; save_presets must do it + monkeypatch.setattr(config_mod, "DATA_DIR", nested) + monkeypatch.setattr(config_mod, "PRESETS_PATH", nested / "presets.json") + save_presets({"presets": [], "activePresetID": None}) + assert (nested / "presets.json").exists() + + +# ── active_config_from_presets ─────────────────────────────────────────────── + + +class TestActiveConfigFromPresets: + def test_returns_default_when_no_presets(self, data_dir: Path) -> None: + path = active_config_from_presets() + assert path == data_dir / "servers.json" + # default file should have been created + assert path.exists() + assert json.loads(path.read_text()) == {"mcpServers": {}} + + def test_returns_preset_file_when_active(self, data_dir: Path) -> None: + preset_file = data_dir / "work.json" + preset_file.write_text('{"mcpServers": {}}') + save_presets( + { + "presets": [ + {"id": "p1", "name": "work", "filePath": str(preset_file)} + ], + "activePresetID": "p1", + } + ) + assert active_config_from_presets() == preset_file + + def test_falls_back_to_default_if_preset_file_missing( + self, data_dir: Path + ) -> None: + save_presets( + { + "presets": [ + { + "id": "p1", + "name": "work", + "filePath": str(data_dir / "missing.json"), + } + ], + "activePresetID": "p1", + } + ) + assert active_config_from_presets() == data_dir / "servers.json" + + def test_falls_back_to_default_if_id_not_found(self, data_dir: Path) -> None: + save_presets( + { + "presets": [{"id": "p1", "name": "w", "filePath": "/nope"}], + "activePresetID": "bogus", + } + ) + assert active_config_from_presets() == data_dir / "servers.json" + + +# ── configure_servers ──────────────────────────────────────────────────────── + + +class TestConfigureServers: + def test_expands_env_in_env_values( + self, + data_dir: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + from fastmcp.mcp_config import MCPConfig + + monkeypatch.setenv("MY_TOKEN", "abc123") + cfg = MCPConfig.model_validate( + { + "mcpServers": { + "s": { + "command": "echo", + "args": ["hi"], + "env": {"TOKEN": "${MY_TOKEN}", "LITERAL": "plain"}, + } + } + } + ) + configure_servers(cfg) + assert cfg.mcpServers["s"].env == {"TOKEN": "abc123", "LITERAL": "plain"} + + def test_no_env_is_noop(self) -> None: + from fastmcp.mcp_config import MCPConfig + + cfg = MCPConfig.model_validate( + {"mcpServers": {"s": {"command": "echo", "args": []}}} + ) + configure_servers(cfg) + # should not raise, env stays None/empty + assert not getattr(cfg.mcpServers["s"], "env", None) + + def test_oauth_server_gets_oauth_client( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from fastmcp.mcp_config import MCPConfig + + captured: dict = {} + + class FakeOAuth: + def __init__(self, **kwargs) -> None: + captured.update(kwargs) + + monkeypatch.setattr(config_mod, "OAuth", FakeOAuth) + + cfg = MCPConfig.model_validate( + { + "mcpServers": { + "o": { + "url": "https://o.example.com/mcp", + "transport": "http", + "auth": "oauth", + } + } + } + ) + configure_servers(cfg) + assert isinstance(cfg.mcpServers["o"].auth, FakeOAuth) + assert captured["callback_port"] == 9876 + assert captured["client_name"] == "Jarvis Proxy" + assert captured["token_storage"] is config_mod.token_storage + + +# ── clear_tokens ───────────────────────────────────────────────────────────── + + +class TestClearTokens: + def test_calls_cache_clear(self, monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + + class FakeCache: + def clear(self) -> None: + calls.append("clear") + + class FakeStore: + _cache = FakeCache() + + monkeypatch.setattr(config_mod, "token_storage", FakeStore()) + config_mod.clear_tokens() + assert calls == ["clear"] diff --git a/tests/unit/test_probe.py b/tests/unit/test_probe.py new file mode 100644 index 0000000..c36c956 --- /dev/null +++ b/tests/unit/test_probe.py @@ -0,0 +1,303 @@ +"""Unit tests for ``jarvis.probe`` (network-free). + +``probe_server`` itself actually connects to a backend, so we don't exercise it +directly here. Instead we verify the helpers (``free_port``, the warning +filter) and ``probe_all_servers`` with ``probe_server`` monkeypatched. +""" + +from __future__ import annotations + +import logging +import socket + +import pytest + +from jarvis import probe as probe_mod +from jarvis.probe import SuppressMcpSessionWarning, free_port, probe_all_servers + + +# ── free_port ──────────────────────────────────────────────────────────────── + + +class TestFreePort: + def test_returns_port_in_valid_range(self) -> None: + port = free_port() + assert 1 <= port <= 65535 + + def test_port_is_actually_bindable(self) -> None: + port = free_port() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", port)) + + def test_consecutive_calls_return_different_ports(self) -> None: + # Not strictly guaranteed by the OS but overwhelmingly likely; this + # protects against returning a hard-coded constant. + ports = {free_port() for _ in range(5)} + assert len(ports) > 1 + + +# ── SuppressMcpSessionWarning ──────────────────────────────────────────────── + + +class TestSuppressMcpSessionWarning: + def _make_record(self, exc: BaseException | None) -> logging.LogRecord: + record = logging.LogRecord( + name="fastmcp.client.transports.config", + level=logging.WARNING, + pathname=__file__, + lineno=1, + msg="Failed to connect", + args=(), + exc_info=(type(exc), exc, None) if exc else None, + ) + return record + + def test_demotes_mcp_error_warning_to_debug(self) -> None: + from mcp import McpError + from mcp.types import ErrorData + + err = McpError(ErrorData(code=-32000, message="boom")) + record = self._make_record(err) + assert SuppressMcpSessionWarning().filter(record) is True + assert record.levelno == logging.DEBUG + assert record.levelname == "DEBUG" + + def test_unrelated_warning_is_unchanged(self) -> None: + record = self._make_record(RuntimeError("unrelated")) + assert SuppressMcpSessionWarning().filter(record) is True + assert record.levelno == logging.WARNING + + def test_warning_without_exc_info_is_unchanged(self) -> None: + record = self._make_record(None) + assert SuppressMcpSessionWarning().filter(record) is True + assert record.levelno == logging.WARNING + + +# ── probe_all_servers ──────────────────────────────────────────────────────── + + +class TestProbeAllServers: + async def test_returns_results_per_server( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + async def fake_probe(name: str, raw: dict) -> list[dict]: + return [{"name": f"{name}_tool", "description": ""}] + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + + result = await probe_all_servers( + { + "a": {"url": "http://a"}, + "b": {"url": "http://b"}, + } + ) + assert result == { + "a": [{"name": "a_tool", "description": ""}], + "b": [{"name": "b_tool", "description": ""}], + } + + async def test_failed_probe_yields_empty_list( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + async def fake_probe(name: str, raw: dict) -> list[dict]: + if name == "bad": + raise RuntimeError("boom") + return [{"name": "t", "description": ""}] + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + + result = await probe_all_servers( + {"good": {"url": "http://g"}, "bad": {"url": "http://b"}} + ) + assert result["good"] == [{"name": "t", "description": ""}] + assert result["bad"] == [] + + async def test_timeout_yields_empty_list( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + import asyncio + + async def hanging_probe(name: str, raw: dict) -> list[dict]: + await asyncio.sleep(10) + return [] + + monkeypatch.setattr(probe_mod, "probe_server", hanging_probe) + + result = await probe_all_servers({"slow": {"url": "http://s"}}, timeout=0.05) + assert result == {"slow": []} + + async def test_empty_input(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def fake_probe(name: str, raw: dict) -> list[dict]: # pragma: no cover + raise AssertionError("should not be called") + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + assert await probe_all_servers({}) == {} + + async def test_base_exception_still_caught_as_empty( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # A BaseException subclass that is *not* KeyboardInterrupt/SystemExit/ + # GeneratorExit must still produce an empty list for that server. + class Weird(BaseException): + pass + + async def fake_probe(name: str, raw: dict) -> list[dict]: + raise Weird("weird") + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + result = await probe_all_servers({"x": {"url": "http://x"}}) + assert result == {"x": []} + + async def test_prints_probe_failure_to_stderr( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] + ) -> None: + async def fake_probe(name: str, raw: dict) -> list[dict]: + raise ValueError("bad things") + + monkeypatch.setattr(probe_mod, "probe_server", fake_probe) + await probe_all_servers({"myserver": {"url": "http://x"}}) + err = capsys.readouterr().err + assert "[myserver]" in err + assert "ValueError" in err + assert "bad things" in err + + +# ── silence() context manager ──────────────────────────────────────────────── + + +class TestSilence: + def test_redirects_stderr_and_restores_it( + self, data_dir, monkeypatch: pytest.MonkeyPatch + ) -> None: + # ``silence`` writes to DATA_DIR / "jarvis.log" — the ``data_dir`` + # fixture already points DATA_DIR at an isolated temp dir. + import sys + from jarvis.probe import silence + + original_stderr = sys.stderr + with silence(): + assert sys.stderr is not original_stderr + print("swallowed stderr", file=sys.stderr) + assert sys.stderr is original_stderr + + log_path = data_dir / "jarvis.log" + assert log_path.exists() + assert "swallowed stderr" in log_path.read_text() + + def test_removes_log_handler_on_exit(self, data_dir) -> None: + import logging + from jarvis.probe import silence + + root = logging.getLogger() + before = list(root.handlers) + with silence(): + assert len(root.handlers) == len(before) + 1 + assert list(root.handlers) == before + + def test_cleans_up_on_exception(self, data_dir) -> None: + import logging + import sys + from jarvis.probe import silence + + original_stderr = sys.stderr + root = logging.getLogger() + before = list(root.handlers) + + with pytest.raises(RuntimeError, match="boom"): + with silence(): + raise RuntimeError("boom") + + assert sys.stderr is original_stderr + assert list(root.handlers) == before + + +# ── probe_server() ─────────────────────────────────────────────────────────── + + +class TestProbeServer: + async def test_returns_tool_list_with_prefix_stripped( + self, monkeypatch: pytest.MonkeyPatch, data_dir + ) -> None: + """``probe_server`` should call ``create_proxy`` on a single-server + config, then list_tools() and strip the ``{name}_`` prefix.""" + from types import SimpleNamespace + from jarvis import probe as probe_mod_inner + + class FakeProxy: + async def list_tools(self) -> list: + return [ + SimpleNamespace(name="myserver_alpha", description="first"), + SimpleNamespace(name="myserver_beta", description=None), + SimpleNamespace(name="unprefixed", description="c"), + ] + + captured: dict = {} + + def fake_create_proxy(cfg, *, name: str): + captured["cfg"] = cfg + captured["name"] = name + return FakeProxy() + + monkeypatch.setattr(probe_mod_inner, "create_proxy", fake_create_proxy) + # avoid configure_servers running for real (not needed, but safer) + monkeypatch.setattr(probe_mod_inner, "configure_servers", lambda cfg: None) + + result = await probe_mod_inner.probe_server( + "myserver", {"url": "http://x", "transport": "http"} + ) + assert result == [ + {"name": "alpha", "description": "first"}, + {"name": "beta", "description": ""}, + {"name": "unprefixed", "description": "c"}, + ] + assert captured["name"] == "probe_myserver" + assert "myserver" in captured["cfg"].mcpServers + + async def test_oauth_server_uses_free_port( + self, monkeypatch: pytest.MonkeyPatch, data_dir + ) -> None: + """OAuth servers should receive an OAuth client with a *free* callback + port, not the hard-coded 9876 that the long-running server uses.""" + from types import SimpleNamespace + from jarvis import probe as probe_mod_inner + + captured_oauth: dict = {} + + class FakeOAuth: + def __init__(self, **kwargs) -> None: + captured_oauth.update(kwargs) + + class FakeProxy: + async def list_tools(self) -> list: + return [] + + monkeypatch.setattr(probe_mod_inner, "OAuth", FakeOAuth) + monkeypatch.setattr(probe_mod_inner, "create_proxy", lambda cfg, *, name: FakeProxy()) + monkeypatch.setattr(probe_mod_inner, "free_port", lambda: 55555) + + result = await probe_mod_inner.probe_server( + "oauth_server", + {"url": "http://o", "transport": "http", "auth": "oauth"}, + ) + assert result == [] + assert captured_oauth["callback_port"] == 55555 + assert captured_oauth["client_name"] == "Jarvis Proxy" + + async def test_system_exit_is_converted_to_oserror( + self, monkeypatch: pytest.MonkeyPatch, data_dir + ) -> None: + """``list_tools`` raising SystemExit (e.g. uvicorn bail-out) must + be surfaced as OSError, not propagated as SystemExit.""" + from jarvis import probe as probe_mod_inner + + class FakeProxy: + async def list_tools(self) -> list: + raise SystemExit(1) + + monkeypatch.setattr(probe_mod_inner, "create_proxy", lambda cfg, *, name: FakeProxy()) + monkeypatch.setattr(probe_mod_inner, "configure_servers", lambda cfg: None) + + with pytest.raises(OSError, match="uvicorn exited"): + await probe_mod_inner.probe_server( + "x", {"url": "http://x", "transport": "http"} + ) diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py new file mode 100644 index 0000000..c976bff --- /dev/null +++ b/tests/unit/test_proxy.py @@ -0,0 +1,107 @@ +"""Unit tests for jarvis.proxy.build_proxy.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from fastmcp.mcp_config import MCPConfig +from fastmcp.server import FastMCP + + +def _make_config() -> MCPConfig: + return MCPConfig.model_validate( + { + "mcpServers": { + "gl": {"command": "npx", "args": ["-y", "some-mcp"]}, + "remote": { + "url": "https://remote.example.com/mcp", + "transport": "http", + }, + } + } + ) + + +def test_build_proxy_returns_fastmcp(): + from jarvis.proxy import build_proxy + + with ( + patch("jarvis.proxy.StatefulProxyClient"), + patch("jarvis.proxy.ProxyClient"), + patch("jarvis.proxy.ProxyProvider"), + ): + result = build_proxy(_make_config(), name="test") + assert isinstance(result, FastMCP) + + +def test_build_proxy_uses_stateful_for_stdio(): + from jarvis.proxy import build_proxy + + with ( + patch("jarvis.proxy.StatefulProxyClient") as mock_stateful, + patch("jarvis.proxy.ProxyClient") as mock_proxy, + patch("jarvis.proxy.ProxyProvider"), + ): + build_proxy(_make_config(), name="test") + + # stdio server "gl" → StatefulProxyClient called once + assert mock_stateful.call_count == 1 + # http server "remote" → ProxyClient called once + assert mock_proxy.call_count == 1 + + +def test_build_proxy_uses_new_stateful_as_factory_for_stdio(): + from jarvis.proxy import build_proxy + from fastmcp.server.providers.proxy import ProxyProvider + + captured_factories = [] + real_init = ProxyProvider.__init__ + + def capturing_init(self, client_factory, **kwargs): + captured_factories.append(client_factory) + real_init(self, client_factory, **kwargs) + + with ( + patch.object(ProxyProvider, "__init__", capturing_init), + patch("jarvis.proxy.StatefulProxyClient") as mock_stateful, + patch("jarvis.proxy.ProxyClient") as mock_proxy, + ): + mock_stateful_instance = MagicMock() + mock_stateful.return_value = mock_stateful_instance + mock_proxy_instance = MagicMock() + mock_proxy.return_value = mock_proxy_instance + build_proxy(_make_config(), name="test") + + # Two providers added: one per server + assert len(captured_factories) == 2 + # Relies on dict insertion order (Python 3.7+): "gl" first, "remote" second. + # Factory for stdio server must be new_stateful bound method + assert captured_factories[0] == mock_stateful_instance.new_stateful + # Factory for http server must be new bound method + assert captured_factories[1] == mock_proxy_instance.new + + +def test_build_proxy_adds_provider_per_server(): + from jarvis.proxy import build_proxy + from fastmcp.server import FastMCP + + added = [] + real_add = FastMCP.add_provider + + def capturing_add(self, provider, *, namespace=""): + added.append(namespace) + real_add(self, provider, namespace=namespace) + + with ( + patch.object(FastMCP, "add_provider", capturing_add), + patch("jarvis.proxy.StatefulProxyClient"), + patch("jarvis.proxy.ProxyClient"), + patch("jarvis.proxy.ProxyProvider"), + ): + build_proxy(_make_config(), name="test") + + # FastMCP.__init__ calls add_provider once internally (namespace="") for the + # local provider, so we expect 3 total: 1 from init + 2 from build_proxy. + named = [ns for ns in added if ns] + assert len(named) == 2 + assert set(named) == {"gl", "remote"} diff --git a/tests/unit/test_tui.py b/tests/unit/test_tui.py new file mode 100644 index 0000000..702a96c --- /dev/null +++ b/tests/unit/test_tui.py @@ -0,0 +1,36 @@ +"""Unit tests for pure helpers in ``jarvis.tui``. + +We avoid instantiating the Textual ``App`` classes — those require a running +event loop and a terminal. Instead we exercise the module-level ``load_config`` +helper which is the only piece of non-UI logic. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +from jarvis.tui import load_config + + +class TestLoadConfig: + def test_reads_existing_config(self, tmp_path: Path) -> None: + path = tmp_path / "c.json" + payload = {"mcpServers": {"x": {"url": "http://x"}}} + path.write_text(json.dumps(payload)) + raw, err = load_config(path) + assert raw == payload + assert err is None + + def test_missing_file_returns_empty_no_error(self, tmp_path: Path) -> None: + raw, err = load_config(tmp_path / "nope.json") + assert raw == {"mcpServers": {}} + assert err is None + + def test_invalid_json_returns_error(self, tmp_path: Path) -> None: + path = tmp_path / "bad.json" + path.write_text("{ not json") + raw, err = load_config(path) + assert raw == {"mcpServers": {}} + assert err is not None + assert "parse error" in err.lower()