diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c3ea5a..858d0aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,6 +30,9 @@ jobs: - name: Install Vaara (editable, no deps) run: pip install -e . --no-deps + - name: Install server extra (fastapi, uvicorn, httpx) for HTTP transport tests + run: pip install 'fastapi>=0.110' 'uvicorn>=0.27' 'httpx>=0.27' + - name: Lint (ruff) run: ruff check . diff --git a/CHANGELOG.md b/CHANGELOG.md index fd148a3..f13d179 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,97 @@ and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.ht ## [Unreleased] +## [0.40.0] - 2026-05-28 + +**Theme: deployment shape. One Vaara process now serves a fleet of +upstream MCP servers, with multi-tenant policy, audit, and attestation +on the same substrate.** + +The v0.39 sidecar shape ran one Vaara process per upstream. v0.40 +turns that into a single process that speaks Streamable HTTP, holds +N upstream MCP-server connections, picks the upstream per request +from a header, scopes every score, audit record, and OVERT envelope +to a tenant, and reloads per-tenant policy in place. + +### Added +- `vaara-mcp-proxy --transport http --http-host H --http-port P`: + Streamable HTTP transport at `POST /mcp`, backed by FastAPI / + uvicorn (the `vaara[server]` extra already shipped in v0.39 for + `vaara serve`). The endpoint reads `X-Vaara-Tenant` and + `X-Vaara-Upstream` per request, pushes them into ContextVars, and + dispatches into the existing `_handle_request` path so the policy, + perimeter, OVERT, and progress-notification handling all light up + unchanged. Notifications (no JSON-RPC `id`) return 202 Accepted. + Bodies above 1 MiB return 413. +- `vaara-mcp-proxy --upstream NAME=CMD` (repeatable) for fan-out. + One Vaara process holds N `UpstreamMCPClient` instances in a name + -> client map. 
Bare `--upstream CMD` keeps the v0.39 single- + upstream contract; it lands in the "default" slot. Commands that + themselves contain `=` (e.g. `python -m foo --bar=baz`) stay + intact because the name-side regex only matches short alphanumeric + slugs. When more than one upstream is configured, a request with + no `X-Vaara-Upstream` header returns 400 with the list of valid + slots; silent fallback to whichever slot won the sort would be a + failure mode that surfaces only in production. Single-upstream + deployments keep the silent-default contract. +- `tenant_id` is first-class through the request, decision, audit, + and attestation layers: + - `ScoreRequest`, `AuditEventRequest`, and `PolicyReloadRequest` + accept a `tenant_id` body field, with `X-Vaara-Tenant` as the + HTTP-header alternative. Body wins over header. + - `AuditRecord` gains a `tenant_id` field, excluded from + `compute_hash()` so pre-v0.40 chains still re-verify on load. + - `AuditTrail` keeps an `action_id -> tenant_id` map seeded by + `record_action_requested`, so every follow-up record + (`risk_scored`, `decision`, `execution`, `escalation`, + `outcome`, `policy_override`) inherits the same scope without + every caller threading `tenant_id` through every signature. + The map is soft-capped (50k entries, 12.5% eviction on + pressure) so long-running deployments cannot leak memory. + - `SQLiteAuditBackend.write_record` prefers the per-record + `tenant_id` when set, with the instance-scoped `tenant_id` + (legacy CLI tooling path) as fallback. A single backend + instance can now serve a multi-tenant runtime. + - OVERT envelopes carry `tenant_id` as a `non_content_metadata` + claim when present. +- `vaara.policy.registry.PolicyRegistry`: one `PolicyController` per + tenant, with the empty string slot reserved as the default + fallback for unmatched lookups. +- `vaara serve --policy-dir DIR`: loads one YAML/JSON policy per + file. 
Filename stem = `tenant_id`; `default.yaml` lands in the + fallback slot. Mutually exclusive with `--policy`. +- `POST /v1/policy/reload` accepts a `tenant_id` body field (or + `X-Vaara-Tenant` header) and routes to the right registry slot; + creates the slot on first reload. + +### Changed +- `Pipeline.intercept` takes a `tenant_id` keyword that flows onto + the `ActionRequest` and into the audit trail. Default `""` keeps + the v0.39 single-tenant contract. +- `AdaptiveScorer.evaluate` dispatches allow / deny thresholds per + tenant at call time. A new `policy_lookup` constructor arg (and + `set_policy_lookup` setter for late binding from `ServerState`) + takes a `Callable[[str], Optional[Policy]]`; on every evaluate, the + scorer asks the registry for the calling tenant's policy and uses + its thresholds. An unknown or unmapped tenant falls back to the + scorer-bound defaults that the default-slot listener keeps fresh on + reload. The backend decision dict surfaces the applied + `threshold_allow` and `threshold_deny` so operators can confirm + which tenant's policy ran. MWU expert state, the conformal + calibrator, agent profiles, and sequence patterns stay shared + across tenants; only threshold application is per-tenant in v0.40. + +### Scope notes +- The HTTP transport on the proxy is POST-only. GET-SSE for + server-initiated notifications (sampling, server-pushed progress) + is v0.41. The audit + OVERT emission path for upstream-originated + notifications still works unchanged on stdio. +- Classifier bundle and conformal-calibrator hot-reload remain a + restart operation in v0.40. Per-tenant policy reload IS hot; that + is the configuration plane that needed to be live across tenants. + Classifier reload waits on resolving the shared-singleton lifecycle + and per-tenant scoping questions (v0.41 candidate). + ## [0.39.2] - 2026-05-27 **Theme: SEP-2787 envelope v2 shape, full wire round-trip, versioned @@ -1713,7 +1804,8 @@ and backward-compatible.
Together they reposition Vaara from a Python library to a runtime kernel that control planes, audit consumers, and orchestration frameworks reference. The HTTP contract at `docs/openapi.yaml` is versioned `/v1/` independently of the project -version, following the OPA pattern. +version, so the wire surface can stabilise without locking the +library cadence. ### Added - **HTTP API reference server (`vaara[server]` extra).** Exposes the @@ -1789,10 +1881,9 @@ it governs. action class declared, matched sequences known). - **`vaara.policy.test_cases_io` module.** `load_test_cases(path)` reads a YAML or JSON cases document and returns a list of - `PolicyTestCase`. Document shape mirrors typical OPA / Conftest - test files: a top-level `cases:` list with `action_class`, - `risk_score`, optional `matched_sequences`, and an `expect:` block - carrying `verdict` and optional `route`. + `PolicyTestCase`. Document shape: a top-level `cases:` list with + `action_class`, `risk_score`, optional `matched_sequences`, and an + `expect:` block carrying `verdict` and optional `route`. - **`vaara policy validate POLICY_PATH [--json]`** and **`vaara policy test POLICY_PATH --cases CASES_PATH [--json]`** CLI subcommands. Both honour standard CI exit codes: validate returns diff --git a/README.md b/README.md index 6961b74..1ec79c6 100644 --- a/README.md +++ b/README.md @@ -168,13 +168,32 @@ if (r.decision === "deny") throw new Error("blocked"); `vaara.integrations.mcp_proxy.VaaraMCPProxy` sits between an MCP client (Claude Code, Cursor, any MCP-capable host) and an upstream MCP server. Every `tools/call` from the client routes through Vaara's interception pipeline before reaching the upstream. Allowed calls forward transparently and report the upstream outcome back to the scorer. Blocked calls return an MCP `isError: true` response with the block reason. The initialization handshake and `notifications/*` forward unchanged. 
`tools/list`, `resources/list`, `resources/read`, `prompts/list`, and `prompts/get` route through the operator perimeter before reaching the client or upstream. ```bash -python -m vaara.integrations.mcp_proxy \ +vaara-mcp-proxy \ --upstream npx --upstream-arg -y --upstream-arg @sap/mdk-mcp-server \ --db ./mcp_audit.db ``` Point your MCP client at the proxy instead of the upstream. The audit chain captures every tool call without changing client or upstream behavior. Distinct from `mcp_server`, which exposes Vaara itself as an MCP server for agents that consult Vaara as a tool. +
+Fleet shape (v0.40): one proxy, many upstreams, multi-tenant policy + +`vaara-mcp-proxy` also runs over Streamable HTTP with fan-out, so one process can serve a fleet of upstream MCP servers: + +```bash +vaara-mcp-proxy \ + --transport http \ + --http-host 127.0.0.1 \ + --http-port 8765 \ + --upstream 'github=npx -y @github/mcp-server' \ + --upstream 'sap=npx -y @sap/mdk-mcp-server' +``` + +Each `POST /mcp` reads two headers. `X-Vaara-Upstream` picks the upstream slot. `X-Vaara-Tenant` scopes the policy, audit chain, and OVERT envelope for that call. Single-upstream deployments keep the v0.39 silent-default contract. Multi-upstream deployments require `X-Vaara-Upstream` per call and return 400 with the available slot list when the header is missing. + +The reference HTTP API server (`vaara serve --policy-dir DIR`) loads one YAML or JSON policy per file in the directory (filename stem becomes the `tenant_id`, `default.yaml` lands in the fallback slot) and hot-reloads per tenant via `POST /v1/policy/reload` with a `tenant_id` body field or `X-Vaara-Tenant` header. The scorer dispatches allow and deny thresholds per call against the calling tenant's policy at `evaluate()` time. +
+
Operator perimeter: tool, resource, prompt filtering @@ -194,7 +213,7 @@ vaara keygen --dev --out signing.pem head -c 32 /dev/urandom > op.key # 3. Run the proxy with OVERT emission turned on. -python -m vaara.integrations.mcp_proxy \ +vaara-mcp-proxy \ --upstream npx --upstream-arg -y --upstream-arg @sap/mdk-mcp-server \ --overt-signing-key signing.pem \ --overt-operator-key op.key \ diff --git a/clients/ts/package.json b/clients/ts/package.json index 4c02587..11d047c 100644 --- a/clients/ts/package.json +++ b/clients/ts/package.json @@ -1,6 +1,6 @@ { "name": "@vaara/client", - "version": "0.39.2", + "version": "0.40.0", "description": "TypeScript client for the Vaara HTTP API. Conformal risk scoring, hash-chained audit, policy reload, named detectors.", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/pyproject.toml b/pyproject.toml index bb399c7..522ba27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vaara" -version = "0.39.2" +version = "0.40.0" description = "Adaptive AI Agent Execution Layer for risk scoring, audit trails, and regulatory compliance" requires-python = ">=3.10" license = "Apache-2.0" @@ -58,6 +58,7 @@ rebuff = ["rebuff>=0.1"] [project.scripts] vaara = "vaara.cli:main" vaara-audit = "vaara.audit_cli:main" +vaara-mcp-proxy = "vaara.integrations.mcp_proxy:main" [tool.setuptools.packages.find] where = ["src"] diff --git a/src/vaara/__init__.py b/src/vaara/__init__.py index da5ae1b..f7071ee 100644 --- a/src/vaara/__init__.py +++ b/src/vaara/__init__.py @@ -6,7 +6,7 @@ oversight. 
""" -__version__ = "0.39.2" +__version__ = "0.40.0" from vaara.pipeline import InterceptionPipeline, InterceptionResult diff --git a/src/vaara/audit/sqlite_backend.py b/src/vaara/audit/sqlite_backend.py index 13ec5fe..7f9df84 100644 --- a/src/vaara/audit/sqlite_backend.py +++ b/src/vaara/audit/sqlite_backend.py @@ -357,7 +357,11 @@ def write_record(self, record: AuditRecord) -> None: _strict_json_dumps(record.regulatory_articles), record.previous_hash, record.record_hash, - self._tenant_id, + # Per-record tenant_id wins so a single backend instance + # can serve a multi-tenant runtime (v0.40+). Empty record + # tenant_id falls back to instance scope for the legacy + # single-tenant init path. + record.tenant_id or self._tenant_id, record.system_operation, record.data_usage, record.decision_making, @@ -674,6 +678,7 @@ def _row_to_record(self, row: tuple) -> AuditRecord: agent_id = self._redaction_cache[agent_id] # Defensive indexing: rows from older queries may not include # the v3 columns. Use a guard so loading old DBs still works. + tenant_id = row[11] if len(row) > 11 else "" sys_op = row[12] if len(row) > 12 else None data_use = row[13] if len(row) > 13 else None dec_mk = row[14] if len(row) > 14 else None @@ -689,6 +694,7 @@ def _row_to_record(self, row: tuple) -> AuditRecord: regulatory_articles=json.loads(row[7]), previous_hash=row[8], record_hash=row[9], + tenant_id=tenant_id or "", system_operation=sys_op, data_usage=data_use, decision_making=dec_mk, diff --git a/src/vaara/audit/trail.py b/src/vaara/audit/trail.py index 6d08712..4f7cb13 100644 --- a/src/vaara/audit/trail.py +++ b/src/vaara/audit/trail.py @@ -258,6 +258,9 @@ class AuditRecord: data_usage: Optional[str] = None decision_making: Optional[str] = None limitations: Optional[str] = None + # v0.40: multi-tenant scoping. Empty string = single-tenant deployment. + # Excluded from compute_hash() to preserve pre-v0.40 chain re-verification. 
+ tenant_id: str = "" def __post_init__(self) -> None: # Loaded-from-DB records carry a non-empty record_hash. Skip @@ -492,6 +495,12 @@ def __init__( self._by_action: dict[str, list[AuditRecord]] = defaultdict(list) self._last_hash = "" self._on_record = on_record + # v0.40 multi-tenant: action_id -> tenant_id, seeded by + # record_action_requested. Subsequent record_* calls (decision, + # execution, escalation) look up the action_id so every record in + # the lifecycle carries the same tenant scope without forcing + # every caller to thread tenant_id through every method signature. + self._tenant_for_action: dict[str, str] = {} # Counts on_record callback failures so callers can detect # persistence divergence at runtime (e.g., DB gone, disk full). # Without this, a silent logger.error is the only signal and the @@ -554,9 +563,33 @@ def verify_chain(self) -> Optional[str]: # ── Recording events ────────────────────────────────────────── + # Defense-in-depth cap for direct-trail callers that bypass the pipeline's + # length cap on tenant_id. The HTTP boundary already caps at 256 via the + # Pydantic schema, but the AuditTrail public API is reachable from + # embedders that construct ActionRequest directly. A 50MB tenant_id would + # otherwise balloon every record on the hash chain and the in-memory + # action -> tenant map. + _MAX_TENANT_ID_LEN = 256 + # Soft cap on the action -> tenant map. Long-running multi-tenant + # deployments would otherwise leak memory at one entry per action, + # because OUTCOME_RECORDED arrives well after ACTION_REQUESTED and the + # map cannot be cleared at decision time. When the cap is reached the + # oldest 1/8 of the map is evicted; subsequent lookups for evicted + # actions fall back to "" tenant, which is the legacy single-tenant + # contract — correct fail-soft behaviour. + _MAX_ACTION_TENANT_MAP = 50_000 + def record_action_requested(self, request: ActionRequest) -> str: """Record that an agent requested an action. 
Returns the action_id.""" action_id = str(uuid.uuid4()) + tenant_id = getattr(request, "tenant_id", "") or "" + if tenant_id: + tenant_id = self._cap_record_str(tenant_id, self._MAX_TENANT_ID_LEN) + if len(self._tenant_for_action) >= self._MAX_ACTION_TENANT_MAP: + evict = max(1, self._MAX_ACTION_TENANT_MAP // 8) + for stale in list(self._tenant_for_action)[:evict]: + self._tenant_for_action.pop(stale, None) + self._tenant_for_action[action_id] = tenant_id articles = self._get_regulatory_articles( EventType.ACTION_REQUESTED, @@ -600,10 +633,20 @@ def record_action_requested(self, request: ActionRequest) -> str: "sequence_position": request.sequence_position, }, regulatory_articles=articles, + tenant_id=tenant_id, )) return action_id + def _tenant_for(self, action_id: str) -> str: + """Resolve the tenant scope for an existing action lifecycle. + + Returns the tenant_id captured at record_action_requested time so + every follow-up record (risk_scored, decision, execution, + escalation, outcome) carries the same scope automatically. 
+ """ + return self._tenant_for_action.get(action_id, "") + def record_risk_scored( self, action_id: str, @@ -631,6 +674,7 @@ def record_risk_scored( tool_name=self._cap_record_str(tool_name, self._MAX_TOOL_NAME_LEN), data=safe_assessment, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) def record_decision( @@ -666,6 +710,7 @@ def record_decision( "risk_score": risk_score, }, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) def record_execution( @@ -702,6 +747,7 @@ def record_execution( data={"result_summary": self._cap_record_dict_bytes( safe_result, self._MAX_EXECUTION_RESULT_JSON_BYTES )}, + tenant_id=self._tenant_for(action_id), )) def record_escalation( @@ -731,6 +777,7 @@ def record_escalation( "risk_score": risk_score, }, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) def record_escalation_resolved( @@ -760,6 +807,7 @@ def record_escalation_resolved( "justification": self._cap_record_str(justification, self._MAX_JUSTIFICATION_LEN), }, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) def record_outcome( @@ -794,6 +842,7 @@ def record_outcome( ), }, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) # Length caps for caller-controlled free-text fields on this direct @@ -923,6 +972,7 @@ def record_policy_override( "new_decision": new_decision, }, regulatory_articles=articles, + tenant_id=self._tenant_for(action_id), )) # ── Querying ────────────────────────────────────────────────── diff --git a/src/vaara/cli.py b/src/vaara/cli.py index 19f0ec2..9241e6e 100644 --- a/src/vaara/cli.py +++ b/src/vaara/cli.py @@ -1015,9 +1015,33 @@ def _cmd_serve(args: argparse.Namespace) -> int: from vaara.server import create_app - controller = None policy_path = getattr(args, "policy", None) - if policy_path: + policy_dir = getattr(args, "policy_dir", None) + if policy_path and policy_dir: + print( + "vaara serve: pass either --policy or --policy-dir, not 
both.", + file=sys.stderr, + ) + return 2 + + controller = None + registry = None + if policy_dir: + from vaara.policy.registry import PolicyRegistry + from vaara.policy.schema import PolicyError + + registry = PolicyRegistry() + try: + tenants = registry.load_directory(Path(policy_dir).expanduser()) + except PolicyError as exc: + print(f"vaara serve: --policy-dir failed to load: {exc}", file=sys.stderr) + return 2 + print( + f"vaara serve: loaded {len(tenants)} tenant policies " + f"(tenants={tenants!r})", + file=sys.stderr, + ) + elif policy_path: from vaara.policy.controller import PolicyController from vaara.policy.validate import validate_source @@ -1032,7 +1056,7 @@ def _cmd_serve(args: argparse.Namespace) -> int: return 2 controller = PolicyController(policy_obj) - app = create_app(policy_controller=controller) + app = create_app(policy_controller=controller, policy_registry=registry) uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) return 0 @@ -1394,6 +1418,15 @@ def build_parser() -> argparse.ArgumentParser: "sequence patterns are applied to the scorer at startup." ), ) + pserve.add_argument( + "--policy-dir", + default=None, + help=( + "Directory of per-tenant policy files. Each *.yaml/*.yml/*.json " + "file is loaded; filename stem becomes the tenant_id " + "(default.yaml -> fallback). Mutually exclusive with --policy." 
+ ), + ) pserve.add_argument( "--log-level", default="info", diff --git a/src/vaara/integrations/mcp_proxy.py b/src/vaara/integrations/mcp_proxy.py index db88754..193441a 100644 --- a/src/vaara/integrations/mcp_proxy.py +++ b/src/vaara/integrations/mcp_proxy.py @@ -25,9 +25,11 @@ from __future__ import annotations import argparse +import contextvars import json import logging import os +import re import sys import threading from pathlib import Path @@ -48,9 +50,48 @@ from vaara.pipeline import InterceptionPipeline from vaara.taxonomy.actions import ActionRequest +# Optional dependency: only the streamable-HTTP transport needs FastAPI / +# Starlette. Guard the import so the stdio path stays installable with +# the base extras only. +try: # pragma: no cover - import guard + from starlette.requests import Request as _StarletteRequest +except ImportError: # pragma: no cover + _StarletteRequest = None # type: ignore[assignment] + logger = logging.getLogger(__name__) +# v0.40 per-request scope. HTTP transport sets these per inbound +# request so _handle_request and friends can look up the right upstream and +# tag the audit/interception trail with the request's tenant_id without +# threading the values through every helper signature. +_REQUEST_UPSTREAM: contextvars.ContextVar[str] = contextvars.ContextVar( + "vaara_mcp_upstream", default="default", +) +_REQUEST_TENANT: contextvars.ContextVar[str] = contextvars.ContextVar( + "vaara_mcp_tenant", default="", +) + +def _safe_log(value: Any, max_len: int = 200) -> str: + """Sanitise a user-supplied string for safe logging. + + Replaces CR/LF and other non-printable characters with "?" so an + attacker who controls a tool / resource / prompt name can't inject + fake log lines, and caps length so a multi-megabyte name doesn't blow + the log up. + """ + if not isinstance(value, str): + value = str(value) + cleaned = "".join(c if c.isprintable() and c not in ("\r", "\n") else "?"
for c in value) + return cleaned[:max_len] + + +# Largest single MCP JSON-RPC message accepted on the /mcp HTTP endpoint. +# Real tool calls and responses fit comfortably; the cap stops a malicious +# client from exhausting memory at parse time. v0.41 can promote this to +# a CLI flag if a real workload needs more headroom. +_MCP_HTTP_MAX_BODY_BYTES = 1 * 1024 * 1024 + + class VaaraMCPProxy: """Transparent MCP proxy with Vaara interception on tool calls.""" @@ -58,7 +99,7 @@ class VaaraMCPProxy: def __init__( self, - upstream_command: list[str], + upstream_command: Optional[list[str]] = None, pipeline: Optional[InterceptionPipeline] = None, db_path: Optional[Path] = None, agent_id_default: str = "mcp-proxy-client", @@ -69,6 +110,7 @@ def __init__( prompt_allowlist: Optional[set[str]] = None, prompt_denylist: Optional[set[str]] = None, overt_emitter: Optional[OVERTReceiptEmitter] = None, + upstreams: Optional[dict[str, list[str]]] = None, ) -> None: if pipeline is not None: self._pipeline = pipeline @@ -98,17 +140,86 @@ def __init__( ) self._stdout_lock = threading.Lock() self._overt = overt_emitter - # progressToken -> (action_id, agent_id, tool_name). Populated when a - # tools/call enters interception with params._meta.progressToken set, - # consulted by the upstream-notification handler so progress events + # progressToken -> (action_id, agent_id, tool_name, tenant_id). Populated + # when a tools/call enters interception with params._meta.progressToken + # set, consulted by the upstream-notification handler so progress events # arriving mid-call carry the originating action_id into the audit - # record and into the OVERT envelope's non_content_metadata. - self._inflight_progress: dict[Any, tuple[str, str, str]] = {} + # record and into the OVERT envelope's non_content_metadata. The tenant + # is captured at tools/call time because the upstream reader thread that + # delivers later notifications does not inherit the request ContextVars. 
+ self._inflight_progress: dict[Any, tuple[str, str, str, str]] = {} self._inflight_lock = threading.Lock() - self._upstream = UpstreamMCPClient( - command=upstream_command, - on_notification=self._on_upstream_notification, - ) + # v0.40 fan-out: hold N upstream MCP servers in a name -> client map. + # The single-upstream legacy entry point (positional ``upstream_command``) + # lands under the "default" name. ``--upstream NAME=CMD`` via CLI or + # ``upstreams={"NAME": [cmd, ...]}`` populates the map directly. The + # HTTP transport reads X-Vaara-Upstream per inbound request to pick; + # stdio transport stays on "default". + if upstreams and upstream_command is not None: + raise ValueError( + "Pass either upstream_command (single-upstream legacy) or " + "upstreams (multi-upstream fan-out), not both.", + ) + upstream_map: dict[str, list[str]] = {} + if upstreams: + upstream_map = {name: list(cmd) for name, cmd in upstreams.items()} + elif upstream_command is not None: + upstream_map = {"default": list(upstream_command)} + else: + raise ValueError( + "VaaraMCPProxy requires upstream_command or upstreams.", + ) + default_alias_target: Optional[str] = None + if "default" not in upstream_map: + # Pick a stable fallback so requests without X-Vaara-Upstream + # still resolve. Lexicographic first keeps multi-tenant fleets + # deterministic across restarts. Alias the slot rather than + # cloning the command so we never spawn a duplicate subprocess. + default_alias_target = sorted(upstream_map)[0] + self._upstreams: dict[str, UpstreamMCPClient] = { + name: UpstreamMCPClient( + command=command, + on_notification=self._on_upstream_notification, + ) + for name, command in upstream_map.items() + } + if default_alias_target is not None: + self._upstreams["default"] = self._upstreams[default_alias_target] + # ``self._upstream`` resolves to the per-request upstream via the + # ``_REQUEST_UPSTREAM`` ContextVar. 
stdio transport leaves the + # default ("default") in place, so existing single-upstream callers + # see exactly one client. HTTP transport sets the ctxvar per + # request to dispatch into the right fleet member. + + @property + def _upstream(self) -> UpstreamMCPClient: + """Resolve the upstream MCP client for the current request scope. + + HTTP transport sets ``_REQUEST_UPSTREAM`` per inbound request. stdio + transport never sets it; the default value "default" routes to the + legacy single-upstream slot. Unknown names raise ``ProxyError`` so + a client asking for a fleet member that does not exist gets a + loud failure instead of a silent reroute. + """ + name = _REQUEST_UPSTREAM.get() + client = self._upstreams.get(name) + if client is None: + raise ProxyError( + f"No upstream named {name!r}; configured names: " + f"{sorted(self._upstreams)!r}", + ) + return client + + @_upstream.setter + def _upstream(self, client: UpstreamMCPClient) -> None: + """Replace the default-slot upstream client. + + Test fixtures and embedders that previously assigned + ``proxy._upstream = MagicMock()`` keep working under the v0.40 + fan-out shape — the assignment lands in the "default" slot and the + property reads it back. + """ + self._upstreams["default"] = client @staticmethod def _is_filtered(name: object, allowlist: Optional[set[str]], denylist: set[str]) -> bool: @@ -150,6 +261,163 @@ def run(self) -> None: continue self._write_to_client(self._handle_request(request)) + def run_http(self, host: str, port: int, log_level: str = "info") -> None: + """Run the proxy on Streamable HTTP (MCP 2026 transport). + + POST /mcp accepts one JSON-RPC message and returns one JSON + response. Notifications (no ``id``) return 202 Accepted and are + forwarded to the upstream without a reply. Multi-tenant / + fan-out scope is read per request from the + ``X-Vaara-Tenant`` and ``X-Vaara-Upstream`` headers and pushed + into the per-request ContextVars before dispatch. 
+ """ + try: + import uvicorn + from fastapi import FastAPI, Header, HTTPException, Response + from fastapi.responses import JSONResponse + except ImportError as exc: + raise RuntimeError( + "vaara-mcp-proxy --transport http requires the 'server' " + "extra. Install with: pip install 'vaara[server]'" + ) from exc + if _StarletteRequest is None: + raise RuntimeError( + "starlette is required for the streamable-HTTP transport. " + "Install with: pip install 'vaara[server]'" + ) + + proxy = self + + app = FastAPI( + title="Vaara MCP Proxy", + version=_VAARA_VERSION, + description=( + "Streamable HTTP transport in front of one or more upstream " + "MCP servers, with Vaara runtime governance applied to every " + "tools/call." + ), + ) + + @app.get("/health") + async def health() -> dict: + return { + "status": "ok", + "proxy": proxy.PROXY_NAME, + "upstreams": sorted(proxy._upstreams.keys()), + } + + @app.post("/mcp") + async def mcp_endpoint( + request: _StarletteRequest, + x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"), + x_vaara_upstream: Optional[str] = Header(default=None, alias="X-Vaara-Upstream"), + ) -> Response: + # 1 MiB cap on a single MCP JSON-RPC message. Real tool calls and + # responses fit comfortably; anything larger is either a misuse or + # a DoS attempt against the proxy's JSON parser. The cap is the + # same order as the MCP reference servers' limit. Hard-cap here + # before json.loads runs so a malicious payload cannot exhaust + # memory at parse time. 
+ content_length = request.headers.get("content-length") + if content_length is not None: + try: + if int(content_length) > _MCP_HTTP_MAX_BODY_BYTES: + return JSONResponse( + status_code=413, + content={"error": { + "code": "payload_too_large", + "message": ( + f"MCP message exceeds " + f"{_MCP_HTTP_MAX_BODY_BYTES} bytes" + ), + }}, + ) + except ValueError: + pass # bogus content-length, fall through to actual read + raw = await request.body() + if len(raw) > _MCP_HTTP_MAX_BODY_BYTES: + return JSONResponse( + status_code=413, + content={"error": { + "code": "payload_too_large", + "message": ( + f"MCP message exceeds " + f"{_MCP_HTTP_MAX_BODY_BYTES} bytes" + ), + }}, + ) + try: + payload = json.loads(raw.decode("utf-8") or "{}") + except (UnicodeDecodeError, json.JSONDecodeError): # invalid UTF-8 is a parse error too + return JSONResponse( + status_code=400, + content=proxy._error_response(None, -32700, "Parse error"), + ) + + header_name = (x_vaara_upstream or "").strip() + # Real upstream slots are everything except the "default" alias. + # When the operator configured exactly one real slot (single- + # upstream deployment), silent fallback preserves the v0.39 + # contract. When the operator configured a fleet, ambiguity is + # an error: missing X-Vaara-Upstream returns 400 with the list + # so the client knows which slots are available, instead of + # silently routing to whichever slot won the sort. + real_slots = sorted(n for n in proxy._upstreams if n != "default") + ambiguous_fanout = len(real_slots) > 1 + if not header_name: + if ambiguous_fanout: + raise HTTPException( + status_code=400, + detail={ + "error": { + "code": "upstream_required", + "message": ( + "X-Vaara-Upstream header is required " + "when the proxy serves more than one " + "upstream.
Available upstreams: " + f"{real_slots!r}" + ), + } + }, + ) + upstream_name = "default" + else: + upstream_name = header_name + if upstream_name not in proxy._upstreams: + raise HTTPException( + status_code=404, + detail={ + "error": { + "code": "unknown_upstream", + "message": ( + f"No upstream named {upstream_name!r}; " + f"configured: {sorted(proxy._upstreams)!r}" + ), + } + }, + ) + + upstream_token = _REQUEST_UPSTREAM.set(upstream_name) + tenant_token = _REQUEST_TENANT.set((x_vaara_tenant or "").strip()) + try: + if isinstance(payload, dict) and "id" not in payload: + try: + proxy._upstream.notify(payload) + except ProxyError: + logger.exception("Failed to forward HTTP notification") + return Response(status_code=202) + response = proxy._handle_request(payload) + return JSONResponse(content=response) + finally: + _REQUEST_UPSTREAM.reset(upstream_token) + _REQUEST_TENANT.reset(tenant_token) + + logger.info( + "Vaara MCP proxy starting on http://%s:%d (%s, upstreams=%s)", + host, port, self.PROXY_NAME, sorted(self._upstreams.keys()), + ) + uvicorn.run(app, host=host, port=port, log_level=log_level) + def _handle_request(self, request: Any) -> dict: if not isinstance(request, dict): return self._error_response(None, -32600, "Invalid Request: not a JSON object") @@ -257,7 +525,8 @@ def _handle_tools_call(self, request: dict) -> dict: arguments = {} if self._tool_filtered(tool_name): logger.warning( - "tools/call rejected at perimeter (operator filter): %s", tool_name, + "tools/call rejected at perimeter (operator filter): %s", + _safe_log(tool_name), ) block_payload = { "vaara_blocked": True, @@ -290,6 +559,7 @@ def _handle_tools_call(self, request: dict) -> dict: # registry (fail-closed). Correct default for runtime governance. 
result = self._pipeline.intercept( agent_id=agent_id, tool_name=tool_name, parameters=arguments, + tenant_id=_REQUEST_TENANT.get(), ) progress_token = self._progress_token(params) if not result.allowed: @@ -326,6 +596,7 @@ def _handle_tools_call(self, request: dict) -> dict: str(getattr(result, "action_id", None) or ""), agent_id, tool_name, + _REQUEST_TENANT.get(), ) try: upstream_response = self._upstream.request(request) @@ -364,7 +635,8 @@ def _handle_resources_read(self, request: dict) -> dict: uri = "" if self._resource_filtered(uri): logger.warning( - "resources/read rejected at perimeter (operator filter): %s", uri, + "resources/read rejected at perimeter (operator filter): %s", + _safe_log(uri), ) self._overt_emit( surface="mcp.resource.read", @@ -415,7 +687,8 @@ def _handle_prompts_get(self, request: dict) -> dict: arguments = {} if self._prompt_filtered(name): logger.warning( - "prompts/get rejected at perimeter (operator filter): %s", name, + "prompts/get rejected at perimeter (operator filter): %s", + _safe_log(name), ) self._overt_emit( surface="mcp.prompt.get", @@ -478,6 +751,7 @@ def _record_perimeter_audit( parameters: dict, decision: str, reason: str, + tenant_id: Optional[str] = None, ) -> None: """Write a request+decision audit pair for a read-oriented MCP access. @@ -488,9 +762,14 @@ def _record_perimeter_audit( policy while still producing the two records that anchor every access to the hash chain. Failures here are logged and swallowed: a perimeter audit failure must not block legitimate - upstream traffic. + upstream traffic. ``tenant_id`` is taken from the request + ContextVar by default; async callers that run outside the + originating request (upstream notification reader thread) pass + the captured tenant explicitly. 
""" import time as _time + if tenant_id is None: + tenant_id = _REQUEST_TENANT.get() try: registry = self._pipeline.registry action_type = registry.classify(tool_name, parameters) @@ -500,6 +779,7 @@ def _record_perimeter_audit( action_type=action_type, parameters=parameters or {}, timestamp_utc=_time.strftime("%Y-%m-%dT%H:%M:%SZ", _time.gmtime()), + tenant_id=tenant_id, ) action_id = self._pipeline.trail.record_action_requested(req) self._pipeline.trail.record_decision( @@ -524,12 +804,16 @@ def _overt_emit( decision: str, reason: str, extra: Optional[dict] = None, + tenant_id: Optional[str] = None, ) -> None: """Emit one OVERT Base Envelope for an MCP interaction. No-op when no emitter is configured. Failures are logged and swallowed: an attestation-side failure must not block legitimate - upstream traffic, mirroring the perimeter-audit rule. + upstream traffic, mirroring the perimeter-audit rule. ``tenant_id`` + is taken from the request ContextVar by default; async callers + that run outside the originating request pass the captured tenant + explicitly so the OVERT claim attributes to the right tenant. """ if self._overt is None: return @@ -540,6 +824,10 @@ def _overt_emit( "decision": decision, "reason": reason, } + if tenant_id is None: + tenant_id = _REQUEST_TENANT.get() + if tenant_id: + non_content_metadata["tenant_id"] = tenant_id if extra: non_content_metadata.update(extra) request_payload = strict_json_dumps( @@ -604,11 +892,17 @@ def _audit_progress_notification(self, message: dict) -> None: parent_action_id = "" agent_id = self._agent_id_default parent_tool = "" + # Notifications arrive on the upstream reader thread, which does + # not inherit the request ContextVars. Pull the tenant captured + # at tools/call time out of the inflight map instead of reading + # _REQUEST_TENANT here, otherwise the audit + OVERT claim land + # under empty tenant scope. 
+ captured_tenant = "" if token is not None: with self._inflight_lock: entry = self._inflight_progress.get(token) if entry is not None: - parent_action_id, agent_id, parent_tool = entry + parent_action_id, agent_id, parent_tool, captured_tenant = entry self._record_perimeter_audit( agent_id=agent_id, tool_name="mcp.notification.progress", @@ -619,6 +913,7 @@ def _audit_progress_notification(self, message: dict) -> None: }, decision="observed", reason="upstream progress notification observed", + tenant_id=captured_tenant, ) self._overt_emit( surface="mcp.notification.progress", @@ -632,6 +927,7 @@ def _audit_progress_notification(self, message: dict) -> None: "parent_action_id": parent_action_id, "parent_tool": parent_tool, }, + tenant_id=captured_tenant, ) def _audit_message_notification(self, message: dict) -> None: @@ -644,12 +940,17 @@ def _audit_message_notification(self, message: dict) -> None: log_logger = params.get("logger", "") if not isinstance(log_logger, str): log_logger = "" + # Log notifications carry no progressToken, so there is no way to + # recover the originating request's tenant from the reader thread. + # Pass tenant_id="" explicitly to make the fail-soft scope visible + # rather than reading _REQUEST_TENANT in a thread that never set it. 
self._record_perimeter_audit( agent_id=self._agent_id_default, tool_name="mcp.notification.message", parameters={"level": level, "logger": log_logger}, decision="observed", reason="upstream log notification observed", + tenant_id="", ) self._overt_emit( surface="mcp.notification.message", @@ -659,6 +960,7 @@ def _audit_message_notification(self, message: dict) -> None: decision="observed", reason="upstream log notification observed", extra={"agent_id": self._agent_id_default}, + tenant_id="", ) def _write_to_client(self, payload: dict) -> None: @@ -672,7 +974,11 @@ def _error_response(req_id: Any, code: int, message: str) -> dict: return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} def close(self) -> None: - self._upstream.close() + for client in self._upstreams.values(): + try: + client.close() + except Exception: + logger.exception("Failed to close upstream MCP client") if self._backend is not None: self._backend.close() @@ -680,11 +986,39 @@ def close(self) -> None: def main(argv: Optional[list[str]] = None) -> None: parser = argparse.ArgumentParser( prog="vaara-mcp-proxy", - description="Vaara runtime governance proxy in front of an upstream MCP server.", + description="Vaara runtime governance proxy in front of one or more upstream MCP servers.", + ) + parser.add_argument( + "--upstream", action="append", default=[], dest="upstreams", + help=( + "Upstream MCP server command. Repeatable for v0.40 fan-out: " + "`--upstream NAME=CMD` registers under a named slot; bare " + "`--upstream CMD` lands under 'default'. The first slot (or " + "'default' when supplied) is the stdio fallback." 
+ ), ) - parser.add_argument("--upstream", required=True, help="Upstream MCP server command") parser.add_argument("--upstream-arg", action="append", default=[], dest="upstream_args", - help="Argument to pass to the upstream command (repeatable)") + help="Argument to pass to the (first) upstream command (repeatable)") + parser.add_argument( + "--transport", + choices=["stdio", "http"], + default="stdio", + help=( + "stdio (default) reads JSON-RPC from stdin/stdout, suitable for " + "in-process MCP clients. http exposes Streamable HTTP " + "(POST /mcp) for fleet / multi-tenant deployments and " + "requires the [server] extra." + ), + ) + parser.add_argument("--http-host", default="127.0.0.1", + help="Bind host when --transport http (default 127.0.0.1)") + parser.add_argument("--http-port", type=int, default=8765, + help="Bind port when --transport http (default 8765)") + parser.add_argument( + "--http-log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + ) parser.add_argument("--db", type=Path, default=None, help="Audit database path (default: $VAARA_DB or ./vaara_audit.db)") parser.add_argument("--agent-id", default="mcp-proxy-client", @@ -743,8 +1077,19 @@ def main(argv: Optional[list[str]] = None) -> None: ), ) + upstreams = _parse_upstream_specs(args.upstreams, args.upstream_args) + if not upstreams: + parser.error( + "at least one --upstream is required (e.g. 
`--upstream " + "github=github-mcp-server` or `--upstream npx`).", + ) + + legacy_single = ( + list(next(iter(upstreams.values()))) if len(upstreams) == 1 else None + ) proxy = VaaraMCPProxy( - upstream_command=[args.upstream, *args.upstream_args], + upstream_command=legacy_single, + upstreams=upstreams if legacy_single is None else None, db_path=args.db, agent_id_default=args.agent_id, allowlist=tool_allow, denylist=tool_deny if tool_deny else None, @@ -755,11 +1100,61 @@ def main(argv: Optional[list[str]] = None) -> None: overt_emitter=overt_emitter, ) try: - proxy.run() + if args.transport == "http": + proxy.run_http( + host=args.http_host, + port=args.http_port, + log_level=args.http_log_level, + ) + else: + proxy.run() finally: proxy.close() +# A fan-out slot name is a short alphanumeric slug. The narrow pattern +# stops _parse_upstream_specs from confusing a command that itself +# contains '=' (e.g. ``python -m foo --bar=baz``) with a NAME=CMD prefix. +_UPSTREAM_NAME_RE = re.compile(r"^[a-zA-Z0-9_\-]{1,64}$") + + +def _parse_upstream_specs( + upstream_specs: list[str], legacy_args: list[str], +) -> dict[str, list[str]]: + """Turn ``--upstream`` / ``--upstream-arg`` CLI input into a fan-out map. + + Each ``--upstream`` is either ``NAME=CMD`` (named — NAME is a short + alphanumeric slug) or ``CMD`` (lands under "default"). Commands that + contain ``=`` (e.g. ``python -m foo --bar=baz``) stay intact because + the NAME-prefix check rejects anything whose left-of-``=`` half isn't + a valid slug. Legacy ``--upstream-arg`` values append to the first + named slot for back-compat with single-upstream callers. 
+ """ + upstreams: dict[str, list[str]] = {} + first_name: Optional[str] = None + for spec in upstream_specs: + if "=" in spec: + candidate_name, _, candidate_cmd = spec.partition("=") + candidate_name = candidate_name.strip() + candidate_cmd = candidate_cmd.strip() + if _UPSTREAM_NAME_RE.match(candidate_name) and candidate_cmd: + name, command = candidate_name, candidate_cmd + else: + name, command = "default", spec + else: + name, command = "default", spec + if not name or not command: + raise SystemExit( + f"invalid --upstream value {spec!r}; expected NAME=CMD or CMD", + ) + upstreams[name] = [command] + if first_name is None: + first_name = name + if legacy_args and first_name is not None: + upstreams[first_name].extend(legacy_args) + return upstreams + + def _build_overt_emitter_from_args( args: argparse.Namespace, *, policy_hash: bytes, ) -> Optional[OVERTReceiptEmitter]: diff --git a/src/vaara/pipeline.py b/src/vaara/pipeline.py index 3876987..30afc13 100644 --- a/src/vaara/pipeline.py +++ b/src/vaara/pipeline.py @@ -266,6 +266,7 @@ def intercept( session_id: str = "", parent_action_id: Optional[str] = None, sequence_position: int = 0, + tenant_id: str = "", ) -> InterceptionResult: """Intercept an agent action request. @@ -314,6 +315,7 @@ def intercept( parent_action_id=parent_action_id, sequence_position=sequence_position, timestamp_utc=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + tenant_id=tenant_id, ) # 3. Record the request in audit trail diff --git a/src/vaara/policy/controller.py b/src/vaara/policy/controller.py index 9908d18..b41efba 100644 --- a/src/vaara/policy/controller.py +++ b/src/vaara/policy/controller.py @@ -69,7 +69,7 @@ def add_listener(self, listener: PolicyListener) -> None: listener(self._policy) def reload( - self, source: Union[str, Path, dict], *, format: Optional[str] = None + self, source: Union[str, Path, dict, Policy], *, format: Optional[str] = None ) -> ReloadResult: """Parse, validate, and apply a new policy. 
@@ -77,10 +77,14 @@ def reload( When omitted, ``.yaml``/``.yml`` paths use the YAML loader, dicts bypass parsing, and everything else goes through JSON. + Already-validated ``Policy`` instances may be passed directly; the + registry path (``vaara.policy.registry.PolicyRegistry``) uses this + to swap a per-tenant policy that was parsed in bulk. + Raises ``PolicyError`` if the source is malformed; in that case the previously loaded policy remains live. """ - new_policy = _load(source, format) + new_policy = source if isinstance(source, Policy) else _load(source, format) with self._lock: self._policy = new_policy self._version += 1 diff --git a/src/vaara/policy/registry.py b/src/vaara/policy/registry.py new file mode 100644 index 0000000..e48601c --- /dev/null +++ b/src/vaara/policy/registry.py @@ -0,0 +1,149 @@ +"""Per-tenant policy registry. + +A ``PolicyRegistry`` owns one ``PolicyController`` per tenant_id, with the +empty string ("") reserved for the default / fallback policy used when a +request carries no tenant scope or no tenant-specific policy is loaded. + +Filename convention for ``load_directory``: + +* ``default.yaml`` / ``default.json`` → tenant_id "" +* ``TENANT.yaml`` / ``TENANT.json`` → tenant_id "TENANT" + +This is the v0.40 multi-tenant policy plane. Single-tenant deployments +keep using ``vaara serve --policy PATH``, which lands in the "" slot. 
+""" + +from __future__ import annotations + +import threading +from pathlib import Path +from typing import Optional, Union + +from vaara.policy.controller import PolicyController, ReloadResult +from vaara.policy.loader import from_json, from_yaml +from vaara.policy.schema import Policy, PolicyError + + +_DEFAULT_TENANT = "" +_POLICY_SUFFIXES = (".yaml", ".yml", ".json") + + +def _filename_to_tenant(stem: str) -> str: + return "" if stem.lower() == "default" else stem + + +def _load_path(path: Path) -> Policy: + if path.suffix in (".yaml", ".yml"): + return from_yaml(path) + return from_json(path) + + +class PolicyRegistry: + """Holds one PolicyController per tenant. Thread-safe.""" + + def __init__(self) -> None: + self._controllers: dict[str, PolicyController] = {} + self._lock = threading.RLock() + + def __contains__(self, tenant_id: str) -> bool: + with self._lock: + return tenant_id in self._controllers + + def tenants(self) -> list[str]: + with self._lock: + return sorted(self._controllers.keys()) + + def get(self, tenant_id: str) -> Optional[PolicyController]: + """Return the controller for ``tenant_id``, falling back to the + default ("") slot. Returns None if neither is registered. + """ + with self._lock: + if tenant_id in self._controllers: + return self._controllers[tenant_id] + return self._controllers.get(_DEFAULT_TENANT) + + def get_exact(self, tenant_id: str) -> Optional[PolicyController]: + """Return only an exact-match controller, no default fallback.""" + with self._lock: + return self._controllers.get(tenant_id) + + def register(self, tenant_id: str, controller: PolicyController) -> None: + with self._lock: + self._controllers[tenant_id] = controller + + def reload( + self, + tenant_id: str, + source: Union[str, Path, dict], + *, + format: Optional[str] = None, + ) -> ReloadResult: + """Reload one tenant's policy. 
Creates the slot if missing.""" + with self._lock: + controller = self._controllers.get(tenant_id) + if controller is None: + policy = _materialise(source, format) + controller = PolicyController(policy) + self._controllers[tenant_id] = controller + return ReloadResult( + version=controller.version, + thresholds_default_escalate=policy.thresholds_default.escalate, + thresholds_default_deny=policy.thresholds_default.deny, + sequence_count=len(policy.sequences), + action_class_count=len(policy.action_classes), + escalation_route_count=len(policy.escalation_routes), + ) + return controller.reload(source, format=format) + + def load_directory(self, directory: Union[str, Path]) -> list[str]: + """Load every ``*.yaml``/``*.yml``/``*.json`` file in ``directory`` as + one tenant's policy. Returns the list of tenant_ids loaded. + + Raises ``PolicyError`` if any file fails to parse — partial loads + are not allowed, the registry stays untouched. + """ + directory = Path(directory) + if not directory.is_dir(): + raise PolicyError(f"policy directory does not exist: {directory}") + + candidates: list[tuple[str, Path]] = [] + for entry in sorted(directory.iterdir()): + if entry.suffix not in _POLICY_SUFFIXES or not entry.is_file(): + continue + candidates.append((_filename_to_tenant(entry.stem), entry)) + + if not candidates: + raise PolicyError(f"policy directory holds no policy files: {directory}") + + parsed: list[tuple[str, Policy]] = [ + (tenant_id, _load_path(path)) for tenant_id, path in candidates + ] + with self._lock: + for tenant_id, policy in parsed: + existing = self._controllers.get(tenant_id) + if existing is None: + self._controllers[tenant_id] = PolicyController(policy) + else: + existing.reload(policy) + return [tenant_id for tenant_id, _ in parsed] + + +def _materialise(source: Union[str, Path, dict], fmt: Optional[str]) -> Policy: + """Mirror PolicyController._load for the new-tenant fast path.""" + from vaara.policy.loader import from_dict + if 
isinstance(source, dict): + return from_dict(source) + if fmt == "yaml": + return from_yaml(source) + if fmt == "json": + return from_json(source) + if isinstance(source, Path): + return _load_path(source) + if isinstance(source, str): + if source.lstrip().startswith("{"): + return from_json(source) + return _load_path(Path(source)) + raise PolicyError(f"unsupported policy source type: {type(source).__name__}") + + +__all__ = ["PolicyRegistry"] diff --git a/src/vaara/scorer/adaptive.py b/src/vaara/scorer/adaptive.py index 929bc29..a2eed81 100644 --- a/src/vaara/scorer/adaptive.py +++ b/src/vaara/scorer/adaptive.py @@ -31,7 +31,7 @@ from collections import OrderedDict, deque from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional if TYPE_CHECKING: from vaara.policy.schema import Policy @@ -137,6 +137,8 @@ def to_backend_decision(self) -> dict: "action": self.decision.value, "reason": self.explanation, "backend": "vaara_adaptive", + "threshold_allow": self.threshold_allow, + "threshold_deny": self.threshold_deny, "raw_result": { "point_estimate": self.point_estimate, "conformal_interval": [self.conformal_lower, self.conformal_upper], @@ -609,6 +611,7 @@ def __init__( max_tracked_agents: int = 10_000, pre_seed_calibration: bool = True, mondrian_categories: bool = False, + policy_lookup: Optional[Callable[[str], Any]] = None, ) -> None: """ Args: @@ -638,6 +641,17 @@ def __init__( pre-Mondrian marginal behaviour. The default seed prior always lands in the default bucket regardless of this flag — synthetic benign pairs have no real category. + policy_lookup: Optional callable taking a tenant_id string and + returning that tenant's Policy (or None if the tenant has + no policy of its own). 
When set and the evaluate-time + context carries a non-empty tenant_id, the scorer uses + that tenant's allow/deny thresholds for THIS call instead + of the scorer-bound defaults. Tenants that resolve to + None fall back to the scorer-bound defaults; an empty + tenant_id skips the lookup entirely. The MWU expert + state, conformal calibrator, agent profiles, and sequence + detector remain shared across tenants — only threshold + application is per-tenant in v0.40. """ if threshold_allow >= threshold_deny: raise ValueError( @@ -647,6 +661,7 @@ def __init__( ) self._threshold_allow = threshold_allow self._threshold_deny = threshold_deny + self._policy_lookup = policy_lookup self._burst_window = burst_window_seconds self._burst_threshold = burst_threshold self._max_tracked_agents = max_tracked_agents @@ -735,6 +750,48 @@ def _seed_conformal_prior(self) -> None: actual = 0.05 self._conformal.add_calibration_point(predicted, actual) + def set_policy_lookup( + self, policy_lookup: Optional[Callable[[str], Any]] + ) -> None: + """Late-binding setter for the per-tenant policy lookup. + + ``ServerState`` constructs the scorer before the ``PolicyRegistry`` + is wired up in some code paths, so the lookup needs to attach + after construction without going through __init__. Runs under + the scorer's RLock so an in-flight evaluate either sees the old + lookup (or None) or the new one — never a torn assignment. + """ + with self._lock: + self._policy_lookup = policy_lookup + + def _thresholds_for(self, tenant_id: str) -> tuple[float, float]: + """Resolve (allow, deny) thresholds for this call. + + Empty tenant_id or no policy_lookup configured returns the + scorer-bound defaults rebound by the most recent apply_policy + on the default ("") slot. A configured lookup that returns + None (tenant has no policy of its own) also falls back to the + defaults. 
_thresholds_for is called from _evaluate_locked while + the scorer lock is held; the lookup acquires the registry lock + on top, so lock ordering is consistently scorer -> registry + across this codebase (never the reverse). + """ + if not tenant_id or self._policy_lookup is None: + return (self._threshold_allow, self._threshold_deny) + try: + tenant_policy = self._policy_lookup(tenant_id) + except Exception: + return (self._threshold_allow, self._threshold_deny) + if tenant_policy is None: + return (self._threshold_allow, self._threshold_deny) + from vaara.policy.schema import Policy # local import to avoid cycles + if not isinstance(tenant_policy, Policy): + return (self._threshold_allow, self._threshold_deny) + return ( + tenant_policy.thresholds_default.escalate, + tenant_policy.thresholds_default.deny, + ) + def _calib_category(self, tool_name: str) -> Optional[str]: """Category to route through to the calibrator for this action. @@ -771,11 +828,17 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any: # Extract fields from context dict tool_name = context.get("tool_name", "unknown") agent_id = context.get("agent_id", "anonymous") + tenant_id = context.get("tenant_id", "") or "" base_risk = _coerce_unit_float(context.get("base_risk_score", 0.5), 0.5) agent_confidence = _coerce_optional_unit_float(context.get("agent_confidence")) reversibility = context.get("reversibility", "partially_reversible") blast_radius = context.get("blast_radius", "local") + # Per-tenant threshold dispatch. Empty tenant_id or no lookup + # configured returns the scorer-bound defaults bound at apply_policy + # time on the default ("") slot. + threshold_allow, threshold_deny = self._thresholds_for(tenant_id) + # Build risk signals from each expert signals = self._compute_signals( tool_name=tool_name, @@ -801,9 +864,9 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any: # If the worst-case (within 1-alpha confidence) is safe, allow it. 
# If the best-case is dangerous, deny it. decision_score = upper - if decision_score < self._threshold_allow: + if decision_score < threshold_allow: decision = Decision.ALLOW - elif decision_score > self._threshold_deny: + elif decision_score > threshold_deny: decision = Decision.DENY else: decision = Decision.ESCALATE @@ -825,7 +888,7 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any: explanation = ( f"{decision.value}: risk={point_estimate:.3f} " f"[{lower:.3f}, {upper:.3f}] " - f"(threshold allow<{self._threshold_allow} deny>{self._threshold_deny})" + f"(threshold allow<{threshold_allow} deny>{threshold_deny})" ) assessment = RiskAssessment( @@ -837,8 +900,8 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any: decision=decision, signals=signals, mwu_weights=self._mwu.weights, - threshold_allow=self._threshold_allow, - threshold_deny=self._threshold_deny, + threshold_allow=threshold_allow, + threshold_deny=threshold_deny, sequence_risk=signals.get("sequence_pattern", 0.0), calibration_size=self._conformal.calibration_size, effective_alpha=eff_alpha, @@ -869,11 +932,14 @@ def dry_run_evaluate(self, context: dict[str, Any]) -> Any: def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any: tool_name = context.get("tool_name", "unknown") agent_id = context.get("agent_id", "anonymous") + tenant_id = context.get("tenant_id", "") or "" base_risk = _coerce_unit_float(context.get("base_risk_score", 0.5), 0.5) agent_confidence = _coerce_optional_unit_float(context.get("agent_confidence")) reversibility = context.get("reversibility", "partially_reversible") blast_radius = context.get("blast_radius", "local") + threshold_allow, threshold_deny = self._thresholds_for(tenant_id) + # _compute_signals is read-only for everything except the # sequence detector's warning log — silence that temporarily. 
seq_logger = logging.getLogger( @@ -898,9 +964,9 @@ def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any: lower, upper = self._conformal.predict_interval( point_estimate, category=bucket, ) - if upper < self._threshold_allow: + if upper < threshold_allow: decision = Decision.ALLOW - elif upper > self._threshold_deny: + elif upper > threshold_deny: decision = Decision.DENY else: decision = Decision.ESCALATE @@ -913,6 +979,8 @@ def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any: "calibration_size": self._conformal.calibration_size, "effective_alpha": self._conformal.effective_alpha_for(bucket), "bucket_category": bucket, + "threshold_allow": threshold_allow, + "threshold_deny": threshold_deny, }, } diff --git a/src/vaara/server/app.py b/src/vaara/server/app.py index 17d24d8..047d834 100644 --- a/src/vaara/server/app.py +++ b/src/vaara/server/app.py @@ -17,6 +17,7 @@ from vaara.audit.trail import AuditTrail from vaara.policy.controller import PolicyController +from vaara.policy.registry import PolicyRegistry from vaara.scorer.adaptive import AdaptiveScorer from vaara.server.routes import register from vaara.server.state import ServerState @@ -26,6 +27,7 @@ def create_app( scorer: Optional[AdaptiveScorer] = None, audit: Optional[AuditTrail] = None, policy_controller: Optional[PolicyController] = None, + policy_registry: Optional[PolicyRegistry] = None, ) -> FastAPI: """Build the FastAPI application. @@ -34,11 +36,19 @@ def create_app( audit: Pre-configured audit trail, or None for default in-memory. policy_controller: Pre-loaded ``PolicyController``. When supplied, the scorer is registered as a listener and ``POST - /v1/policy/reload`` becomes available. When omitted, the - reload endpoint returns ``409 policy_not_configured``. + /v1/policy/reload`` becomes available. When omitted (and no + ``policy_registry``), the reload endpoint returns + ``409 policy_not_configured``. 
+ policy_registry: Pre-loaded ``PolicyRegistry`` for multi-tenant + deployments. Mutually exclusive with ``policy_controller`` — + single-controller callers are wrapped into a registry's "" + slot automatically by ``ServerState``. """ state = ServerState( - scorer=scorer, audit=audit, policy_controller=policy_controller + scorer=scorer, + audit=audit, + policy_controller=policy_controller, + policy_registry=policy_registry, ) app = FastAPI( title="Vaara HTTP API", diff --git a/src/vaara/server/routes.py b/src/vaara/server/routes.py index c3fb5b8..e04a8ee 100644 --- a/src/vaara/server/routes.py +++ b/src/vaara/server/routes.py @@ -7,7 +7,7 @@ from datetime import datetime, timezone from typing import Optional -from fastapi import FastAPI, HTTPException, status +from fastapi import FastAPI, Header, HTTPException, status from fastapi.responses import JSONResponse from vaara import __version__ as _vaara_version @@ -31,6 +31,12 @@ def _iso(ts: float) -> str: return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat() +def _resolve_tenant(body_value: str, header_value: Optional[str]) -> str: + body = (body_value or "").strip() + header = (header_value or "").strip() + return body or header + + def register(app: FastAPI, state: ServerState) -> None: @app.exception_handler(HTTPException) @@ -65,8 +71,13 @@ async def server_info(): ) @app.post("/v1/score", response_model=S.ScoreResponse) - async def score(req: S.ScoreRequest): + async def score( + req: S.ScoreRequest, + x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"), + ): + tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant) ctx = req.model_dump(exclude_none=True) + ctx["tenant_id"] = tenant_id try: decision_dict = state.scorer.evaluate(ctx) except Exception as exc: @@ -84,6 +95,7 @@ async def score(req: S.ScoreRequest): tool_name=req.tool_name, predicted_risk=float(raw.get("point_estimate", 0.5) or 0.5), signals=signals, + tenant_id=tenant_id, ) return S.ScoreResponse( @@ -99,8 
+111,12 @@ async def score(req: S.ScoreRequest): signals=signals, mwu_weights={k: float(v) for k, v in state.scorer._mwu.weights.items()}, thresholds=S.Thresholds( - allow=state.scorer._threshold_allow, - deny=state.scorer._threshold_deny, + allow=float( + decision_dict.get("threshold_allow", state.scorer._threshold_allow) + ), + deny=float( + decision_dict.get("threshold_deny", state.scorer._threshold_deny) + ), ), sequence_risk=float(raw.get("sequence_risk", 0.0) or 0.0), calibration_size=int(raw.get("calibration_size", 0) or 0), @@ -130,7 +146,10 @@ async def score_outcome(req: S.OutcomeRequest): response_model=S.AuditEventResponse, status_code=201, ) - async def append_audit_event(req: S.AuditEventRequest): + async def append_audit_event( + req: S.AuditEventRequest, + x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"), + ): try: event_type = EventType(req.event_type) except ValueError: @@ -139,6 +158,12 @@ async def append_audit_event(req: S.AuditEventRequest): status.HTTP_400_BAD_REQUEST, ) + tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant) + if not tenant_id: + info = state.lookup_action(req.action_id) + if info is not None: + tenant_id = info.tenant_id + record = AuditRecord( record_id=str(uuid.uuid4()), action_id=req.action_id, @@ -148,6 +173,7 @@ async def append_audit_event(req: S.AuditEventRequest): tool_name=req.tool_name or "", data=req.payload or {}, regulatory_articles=[], + tenant_id=tenant_id, ) state.audit._append(record) return S.AuditEventResponse( @@ -217,16 +243,26 @@ async def detect_pii_endpoint(req: S.DetectPIIRequest): return S.DetectPIIResponse(**result.to_dict()) @app.post("/v1/policy/reload", response_model=S.PolicyReloadResponse) - async def reload_policy(req: S.PolicyReloadRequest): + async def reload_policy( + req: S.PolicyReloadRequest, + x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"), + ): from vaara.policy.schema import PolicyError - controller = 
state.policy_controller - if controller is None: + tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant) + registry = state.policy_registry + controller = ( + registry.get_exact(tenant_id) if registry is not None else None + ) + if controller is None and not tenant_id: + controller = state.policy_controller + if registry is None and controller is None: raise _error( code="policy_not_configured", message=( - "Server has no PolicyController; start with " - "`vaara serve --policy PATH` to enable reload." + "Server has no policy plane; start with " + "`vaara serve --policy PATH` or `--policy-dir DIR` to " + "enable reload." ), http_status=status.HTTP_409_CONFLICT, ) @@ -240,7 +276,10 @@ async def reload_policy(req: S.PolicyReloadRequest): source = req.body if req.body is not None else req.path try: - result = controller.reload(source, format=req.format) + if registry is not None: + result = registry.reload(tenant_id, source, format=req.format) + else: + result = controller.reload(source, format=req.format) except PolicyError as exc: raise _error( code="policy_invalid", @@ -257,4 +296,5 @@ async def reload_policy(req: S.PolicyReloadRequest): sequence_count=result.sequence_count, action_class_count=result.action_class_count, escalation_route_count=result.escalation_route_count, + tenant_id=tenant_id, ) diff --git a/src/vaara/server/schemas.py b/src/vaara/server/schemas.py index 74671ca..1fe92c6 100644 --- a/src/vaara/server/schemas.py +++ b/src/vaara/server/schemas.py @@ -41,6 +41,7 @@ class ScoreRequest(BaseModel): blast_radius: Optional[_BlastRadius] = None session_id: Optional[str] = Field(default=None, max_length=256) parent_action_id: Optional[str] = Field(default=None, max_length=128) + tenant_id: str = Field(default="", max_length=256) context: dict[str, Any] = Field(default_factory=dict) @@ -85,6 +86,7 @@ class AuditEventRequest(BaseModel): action_id: str agent_id: Optional[str] = None tool_name: Optional[str] = None + tenant_id: str = Field(default="", 
max_length=256) payload: dict[str, Any] = Field(default_factory=dict) @@ -203,6 +205,7 @@ class PolicyReloadRequest(BaseModel): path: Optional[str] = Field(default=None, max_length=4096) body: Optional[dict[str, Any]] = None format: Optional[Literal["json", "yaml"]] = None + tenant_id: str = Field(default="", max_length=256) class PolicyReloadResponse(BaseModel): @@ -211,3 +214,4 @@ class PolicyReloadResponse(BaseModel): sequence_count: int action_class_count: int escalation_route_count: int + tenant_id: str = "" diff --git a/src/vaara/server/state.py b/src/vaara/server/state.py index eca3baf..e34ad76 100644 --- a/src/vaara/server/state.py +++ b/src/vaara/server/state.py @@ -1,4 +1,4 @@ -"""Server state container — scorer + audit trail + policy controller singletons.""" +"""Server state container — scorer + audit trail + policy registry singletons.""" from __future__ import annotations @@ -8,6 +8,7 @@ from vaara.audit.trail import AuditTrail from vaara.policy.controller import PolicyController +from vaara.policy.registry import PolicyRegistry from vaara.scorer.adaptive import AdaptiveScorer @@ -17,6 +18,7 @@ class _ActionInfo: tool_name: str predicted_risk: float signals: dict[str, float] = field(default_factory=dict) + tenant_id: str = "" class ServerState: @@ -27,12 +29,43 @@ def __init__( scorer: Optional[AdaptiveScorer] = None, audit: Optional[AuditTrail] = None, policy_controller: Optional[PolicyController] = None, + policy_registry: Optional[PolicyRegistry] = None, ) -> None: + if policy_controller is not None and policy_registry is not None: + raise ValueError( + "Pass either policy_controller (single-tenant legacy) or " + "policy_registry (multi-tenant v0.40), not both. Mixing " + "the two silently splits threshold sources between the " + "default slot and per-tenant overrides.", + ) self.scorer = scorer or AdaptiveScorer() self.audit = audit or AuditTrail() + # v0.40: a single PolicyRegistry holds all tenant policies. 
The + # single-tenant entry point (`policy_controller=...`) lands in the + # empty-string "" slot for back-compat with v0.39 callers. + if policy_registry is None and policy_controller is not None: + policy_registry = PolicyRegistry() + policy_registry.register("", policy_controller) + self.policy_registry = policy_registry self.policy_controller = policy_controller if policy_controller is not None: policy_controller.add_listener(self.scorer.apply_policy) + elif policy_registry is not None: + default = policy_registry.get("") + if default is not None: + default.add_listener(self.scorer.apply_policy) + self.policy_controller = default + # v0.40 per-tenant threshold dispatch: the scorer asks the + # registry for the calling tenant's policy on every evaluate. + # An exact-match miss falls back to the scorer-bound defaults + # (which the default-slot listener keeps fresh on reload). We + # use get_exact rather than get so the scorer's own default + # path stays the single fallback channel. + if policy_registry is not None: + def _lookup(tid: str): + ctrl = policy_registry.get_exact(tid) + return ctrl.policy if ctrl is not None else None + self.scorer.set_policy_lookup(_lookup) self._lock = threading.Lock() # action_id → info captured at score time so outcome reports can # feed the MWU update without the client having to resend context. 
@@ -45,6 +78,7 @@ def remember_action( tool_name: str, predicted_risk: float, signals: dict[str, float], + tenant_id: str = "", ) -> None: with self._lock: self._actions[action_id] = _ActionInfo( @@ -52,6 +86,7 @@ def remember_action( tool_name=tool_name, predicted_risk=predicted_risk, signals=signals, + tenant_id=tenant_id, ) def lookup_action(self, action_id: str) -> Optional[_ActionInfo]: diff --git a/src/vaara/taxonomy/actions.py b/src/vaara/taxonomy/actions.py index e4b9c7e..d086284 100644 --- a/src/vaara/taxonomy/actions.py +++ b/src/vaara/taxonomy/actions.py @@ -136,6 +136,7 @@ class ActionRequest: parent_action_id: Optional[str] = None # For action chains sequence_position: int = 0 # Position in current action sequence timestamp_utc: str = "" + tenant_id: str = "" # v0.40 multi-tenant scope; "" = single-tenant def to_policy_context(self) -> dict: """Convert to a plain dict of policy-evaluation fields. @@ -159,6 +160,7 @@ def to_policy_context(self) -> dict: "parent_action_id": self.parent_action_id, "sequence_position": self.sequence_position, "parameters": self.parameters, + "tenant_id": self.tenant_id, } diff --git a/tests/test_integrations_mcp_proxy.py b/tests/test_integrations_mcp_proxy.py index b43e94a..aaea107 100644 --- a/tests/test_integrations_mcp_proxy.py +++ b/tests/test_integrations_mcp_proxy.py @@ -491,7 +491,7 @@ def upstream_request(req): "_meta": {"progressToken": "tok-stream"}, }, }) - assert progress_seen == [{"tok-stream": ("parent-act-9", "mcp-proxy-client", "long_tool")}] + assert progress_seen == [{"tok-stream": ("parent-act-9", "mcp-proxy-client", "long_tool", "")}] # Map is cleaned up after the call returns. assert p._inflight_progress == {} # Audit was called for the progress event with the parent correlation. 
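Review note on the `ServerState` wiring in `src/vaara/server/state.py` above: the comment there says an exact-match miss in the registry falls back to the scorer-bound defaults. A minimal sketch of that selection rule follows. `SketchScorer`, `Thresholds`, and `thresholds_for` are hypothetical names for illustration, not the real `AdaptiveScorer` internals; the empty-tenant shortcut and the exception and type guards are assumptions about how the scorer keeps the default path as the single fallback.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Thresholds:
    """Hypothetical stand-in for the threshold pair a policy carries."""
    allow: float
    deny: float


class SketchScorer:
    """Illustrates threshold selection only; not the real AdaptiveScorer."""

    def __init__(self, defaults: Thresholds) -> None:
        self._defaults = defaults
        self._lookup: Optional[Callable[[str], Optional[Thresholds]]] = None

    def set_policy_lookup(
        self, lookup: Callable[[str], Optional[Thresholds]]
    ) -> None:
        self._lookup = lookup

    def thresholds_for(self, tenant_id: str) -> Thresholds:
        # Assumption: the empty tenant never consults the lookup.
        if not tenant_id or self._lookup is None:
            return self._defaults
        try:
            found = self._lookup(tenant_id)
        except Exception:
            # Assumption: a failing lookup degrades to the defaults.
            return self._defaults
        # A miss (None) or an unexpected type also degrades to the defaults.
        return found if isinstance(found, Thresholds) else self._defaults


# Exact-match lookup: dict.get returns None for unknown tenants.
policies = {"tenant-a": Thresholds(allow=0.05, deny=0.10)}
scorer = SketchScorer(Thresholds(allow=0.4, deny=0.7))
scorer.set_policy_lookup(policies.get)
```

Because the lookup is exact-match only, the defaults stay the one fallback channel, which is the reason the wiring above gives for preferring `get_exact` over `get`.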
diff --git a/tests/test_v040_mcp_http_transport.py b/tests/test_v040_mcp_http_transport.py new file mode 100644 index 0000000..5a67941 --- /dev/null +++ b/tests/test_v040_mcp_http_transport.py @@ -0,0 +1,262 @@ +"""v0.40 MCP proxy streamable-HTTP transport + fan-out routing.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +try: + from fastapi.testclient import TestClient +except ImportError: + pytest.skip( + "server extra not installed (pip install 'vaara[server]')", + allow_module_level=True, + ) + +from vaara.integrations import mcp_proxy +from vaara.integrations.mcp_proxy import VaaraMCPProxy, _parse_upstream_specs + + +# ── _parse_upstream_specs ────────────────────────────────────────────────── + +def test_parse_upstream_specs_bare_command_lands_in_default(): + result = _parse_upstream_specs(["echo"], []) + assert result == {"default": ["echo"]} + + +def test_parse_upstream_specs_named_slot(): + result = _parse_upstream_specs(["github=gh-mcp-server"], []) + assert result == {"github": ["gh-mcp-server"]} + + +def test_parse_upstream_specs_command_with_equals_stays_intact(): + """A command like `python -m foo --bar=baz` must NOT be split at '='.""" + result = _parse_upstream_specs(["python -m foo --bar=baz"], []) + assert result == {"default": ["python -m foo --bar=baz"]} + + +def test_parse_upstream_specs_legacy_args_join_first_slot(): + result = _parse_upstream_specs(["echo"], ["hello", "world"]) + assert result == {"default": ["echo", "hello", "world"]} + + +def test_parse_upstream_specs_multiple_named_fanout(): + result = _parse_upstream_specs(["a=cmd-a", "b=cmd-b"], []) + assert result == {"a": ["cmd-a"], "b": ["cmd-b"]} + + +def test_parse_upstream_specs_unknown_name_pattern_falls_to_default(): + """Names that aren't simple slugs aren't treated as NAME= prefix.""" + result = _parse_upstream_specs(["python -m srv=foo"], []) + # The left side ("python -m srv") fails the slug regex, so the whole + # spec is 
treated as a bare command under "default". + assert result == {"default": ["python -m srv=foo"]} + + +# ── VaaraMCPProxy multi-upstream constructor ─────────────────────────────── + +@pytest.fixture +def http_proxy(monkeypatch): + """A VaaraMCPProxy with mocked upstreams and pipeline.""" + monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock()) + pipeline = MagicMock() + proxy = VaaraMCPProxy( + upstreams={"alpha": ["cmd-alpha"], "beta": ["cmd-beta"]}, + pipeline=pipeline, + ) + return proxy + + +def test_constructor_rejects_both_upstream_and_upstreams(): + with pytest.raises(ValueError, match="Pass either"): + VaaraMCPProxy( + upstream_command=["echo"], + upstreams={"a": ["echo"]}, + pipeline=MagicMock(), + ) + + +def test_constructor_requires_at_least_one_upstream(): + with pytest.raises(ValueError, match="upstream"): + VaaraMCPProxy(pipeline=MagicMock()) + + +def test_multi_upstream_populates_default_slot(http_proxy): + assert "default" in http_proxy._upstreams + # When no explicit default is provided, the sorted-first name acts as fallback. + assert http_proxy._upstreams["default"] is not None + + +def test_default_slot_aliases_first_upstream_not_duplicate(monkeypatch): + """Regression guard: default fallback must alias an existing upstream + client instead of spawning a duplicate subprocess. An earlier shape + cloned the command into upstream_map["default"] before constructing + the client map, which built a second UpstreamMCPClient for the same + command.""" + constructed: list[MagicMock] = [] + + def make_client(*_args, **_kw): + instance = MagicMock() + constructed.append(instance) + return instance + + monkeypatch.setattr( + mcp_proxy, "UpstreamMCPClient", MagicMock(side_effect=make_client), + ) + proxy = VaaraMCPProxy( + upstreams={"alpha": ["cmd-alpha"], "beta": ["cmd-beta"]}, + pipeline=MagicMock(), + ) + # Exactly two clients constructed (alpha, beta), not three. + assert len(constructed) == 2 + # "default" aliases the sorted-first real slot. 
+ assert proxy._upstreams["default"] is proxy._upstreams["alpha"] + assert proxy._upstreams["default"] is not proxy._upstreams["beta"] + + +def test_single_upstream_lands_under_default_slot(monkeypatch): + monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock()) + proxy = VaaraMCPProxy( + upstream_command=["echo"], pipeline=MagicMock(), + ) + assert list(proxy._upstreams) == ["default"] + + +# ── HTTP transport endpoints ─────────────────────────────────────────────── + +def _build_http_app(proxy): + """Construct the FastAPI app run_http() builds, without uvicorn.run().""" + import unittest.mock as um + + with um.patch("uvicorn.run") as run_mock: + # Capture the app by intercepting uvicorn.run() — we call the same + # method the production CLI uses, but never block on the event loop. + captured: dict = {} + + def fake_run(app, **kwargs): + captured["app"] = app + + run_mock.side_effect = fake_run + proxy.run_http(host="127.0.0.1", port=0) + return captured["app"] + + +def test_http_health_lists_upstreams(http_proxy): + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.get("/health") + assert resp.status_code == 200 + assert set(resp.json()["upstreams"]) >= {"alpha", "beta", "default"} + + +def test_http_mcp_post_routes_to_named_upstream(http_proxy): + http_proxy._upstreams["alpha"].request.return_value = { + "jsonrpc": "2.0", "id": 1, "result": {"tools": []}, + } + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers={"X-Vaara-Upstream": "alpha"}, + ) + assert resp.status_code == 200 + http_proxy._upstreams["alpha"].request.assert_called_once() + + +def test_http_mcp_fanout_without_header_returns_400(http_proxy): + """Multi-upstream deployment must NOT silently default-route.""" + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": 
"tools/list"}, + ) + assert resp.status_code == 400 + body = resp.json() + assert body["detail"]["error"]["code"] == "upstream_required" + # Operator gets the list of valid slots in the error so a client UI + # can recover without an additional health probe round-trip. + assert "alpha" in body["detail"]["error"]["message"] + assert "beta" in body["detail"]["error"]["message"] + + +def test_http_mcp_single_upstream_silent_default(monkeypatch): + """Single-upstream deployment keeps the v0.39 silent-default contract.""" + monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock()) + proxy = VaaraMCPProxy(upstream_command=["echo"], pipeline=MagicMock()) + proxy._upstreams["default"].request.return_value = { + "jsonrpc": "2.0", "id": 1, "result": {"tools": []}, + } + app = _build_http_app(proxy) + client = TestClient(app) + resp = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + ) + assert resp.status_code == 200 + proxy._upstreams["default"].request.assert_called_once() + + +def test_http_mcp_unknown_upstream_returns_404(http_proxy): + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers={"X-Vaara-Upstream": "no-such-thing"}, + ) + assert resp.status_code == 404 + assert resp.json()["detail"]["error"]["code"] == "unknown_upstream" + + +def test_http_mcp_notification_returns_202(http_proxy): + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers={"X-Vaara-Upstream": "alpha"}, + ) + assert resp.status_code == 202 + + +def test_http_mcp_oversized_body_returns_413(http_proxy, monkeypatch): + monkeypatch.setattr(mcp_proxy, "_MCP_HTTP_MAX_BODY_BYTES", 64) + app = _build_http_app(http_proxy) + client = TestClient(app) + payload = { + "jsonrpc": "2.0", "id": 1, "method": "tools/list", + "params": {"x": "a" * 
10_000}, + } + resp = client.post("/mcp", json=payload) + assert resp.status_code == 413 + assert resp.json()["error"]["code"] == "payload_too_large" + + +def test_http_mcp_bad_json_returns_parse_error(http_proxy): + app = _build_http_app(http_proxy) + client = TestClient(app) + resp = client.post("/mcp", content=b"not json") + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == -32700 + + +def test_http_mcp_tenant_header_threads_into_overt(http_proxy): + """X-Vaara-Tenant becomes a non_content_metadata claim on OVERT envelope.""" + http_proxy._overt = MagicMock() + http_proxy._upstreams["alpha"].request.return_value = { + "jsonrpc": "2.0", "id": 1, "result": {"tools": []}, + } + app = _build_http_app(http_proxy) + client = TestClient(app) + client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers={"X-Vaara-Tenant": "tenant-q", "X-Vaara-Upstream": "alpha"}, + ) + # tools/list does not emit an OVERT envelope in the current proxy + # implementation, so this test exercises only that the request + # dispatched cleanly with the header set. + http_proxy._upstreams["alpha"].request.assert_called_once() diff --git a/tests/test_v040_per_tenant_threshold.py b/tests/test_v040_per_tenant_threshold.py new file mode 100644 index 0000000..2b2dcf2 --- /dev/null +++ b/tests/test_v040_per_tenant_threshold.py @@ -0,0 +1,212 @@ +"""v0.40 per-tenant threshold dispatch at evaluate() time. + +The scorer holds defaults bound from the default ("") slot via the +standard apply_policy listener path, and on every evaluate it looks up +the calling tenant's policy from the PolicyRegistry. A tenant with its +own thresholds gets those thresholds applied to THIS call; a tenant +with no policy of its own falls back to the scorer-bound defaults. 
+""" + +from __future__ import annotations + +from vaara.policy.controller import PolicyController +from vaara.policy.loader import from_dict +from vaara.policy.registry import PolicyRegistry +from vaara.scorer.adaptive import AdaptiveScorer + + +def _policy_dict(escalate: float, deny: float) -> dict: + return { + "version": "0.1", + "thresholds": {"default": {"escalate": escalate, "deny": deny}}, + } + + +def _ctx(tenant_id: str = "", base_risk: float = 0.5) -> dict: + return { + "tool_name": "tool.test", + "agent_id": "agent-test", + "tenant_id": tenant_id, + "base_risk_score": base_risk, + "reversibility": "partially_reversible", + "blast_radius": "local", + } + + +def _registry_lookup(registry: PolicyRegistry): + def _lookup(tid: str): + ctrl = registry.get_exact(tid) + return ctrl.policy if ctrl is not None else None + return _lookup + + +def test_tenant_thresholds_override_default_for_that_call(): + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + registry.register( + "tenant-a", PolicyController(from_dict(_policy_dict(0.05, 0.10))) + ) + + scorer = AdaptiveScorer() + scorer.apply_policy(registry.get_exact("").policy) + scorer.set_policy_lookup(_registry_lookup(registry)) + + default_result = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5)) + assert default_result["threshold_allow"] == 0.4 + assert default_result["threshold_deny"] == 0.7 + + strict_result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert strict_result["threshold_allow"] == 0.05 + assert strict_result["threshold_deny"] == 0.10 + + again = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5)) + assert again["threshold_allow"] == 0.4 + assert again["threshold_deny"] == 0.7 + + +def test_unknown_tenant_falls_back_to_scorer_defaults(): + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + + scorer = AdaptiveScorer() + scorer.apply_policy(registry.get_exact("").policy) 
+ scorer.set_policy_lookup(_registry_lookup(registry)) + + result = scorer.evaluate(_ctx(tenant_id="ghost", base_risk=0.5)) + assert result["threshold_allow"] == 0.4 + assert result["threshold_deny"] == 0.7 + + +def test_empty_tenant_id_skips_lookup_and_uses_defaults(): + calls = [] + + def _lookup(tid: str): + calls.append(tid) + return None + + scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7) + scorer.set_policy_lookup(_lookup) + + result = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5)) + assert result["threshold_allow"] == 0.4 + assert result["threshold_deny"] == 0.7 + assert calls == [] + + +def test_lookup_exception_falls_back_to_defaults(): + def _broken_lookup(tid: str): + raise RuntimeError("registry unreachable") + + scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7) + scorer.set_policy_lookup(_broken_lookup) + + result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert result["threshold_allow"] == 0.4 + assert result["threshold_deny"] == 0.7 + + +def test_tenant_reload_visible_to_next_evaluate_without_listener(): + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + registry.register( + "tenant-a", PolicyController(from_dict(_policy_dict(0.30, 0.50))) + ) + + scorer = AdaptiveScorer() + scorer.apply_policy(registry.get_exact("").policy) + scorer.set_policy_lookup(_registry_lookup(registry)) + + first = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert first["threshold_allow"] == 0.30 + + registry.reload("tenant-a", _policy_dict(0.05, 0.10)) + after = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert after["threshold_allow"] == 0.05 + assert after["threshold_deny"] == 0.10 + + +def test_dry_run_evaluate_uses_per_tenant_thresholds(): + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + registry.register( + "tenant-a", 
PolicyController(from_dict(_policy_dict(0.05, 0.10))) + ) + + scorer = AdaptiveScorer() + scorer.apply_policy(registry.get_exact("").policy) + scorer.set_policy_lookup(_registry_lookup(registry)) + + default_dry = scorer.dry_run_evaluate(_ctx(tenant_id="", base_risk=0.5)) + strict_dry = scorer.dry_run_evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + + assert default_dry["raw_result"]["threshold_allow"] == 0.4 + assert strict_dry["raw_result"]["threshold_allow"] == 0.05 + assert strict_dry["raw_result"]["threshold_deny"] == 0.10 + + +def test_server_state_wires_lookup_from_registry(): + from vaara.server.state import ServerState + + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + registry.register( + "tenant-a", PolicyController(from_dict(_policy_dict(0.05, 0.10))) + ) + + state = ServerState(policy_registry=registry) + result = state.scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert result["threshold_allow"] == 0.05 + assert result["threshold_deny"] == 0.10 + + +def test_score_response_surfaces_per_tenant_thresholds_over_http(): + """The /v1/score response's `thresholds` block reflects the + per-call values, not the scorer's bound defaults. Regression test: + smoke-test caught the response surfacing 0.4/0.7 for every tenant + even though the decision itself was already dispatching correctly. 
+ """ + import pytest as _pytest + try: + from fastapi.testclient import TestClient + from vaara.server import create_app + except ImportError: + _pytest.skip("server extra not installed (pip install 'vaara[server]')") + return # pytest.skip raises; the return keeps static analysers happy + + registry = PolicyRegistry() + registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7)))) + registry.register( + "tenant-strict", + PolicyController(from_dict(_policy_dict(0.05, 0.10))), + ) + app = create_app(policy_registry=registry) + client = TestClient(app) + req = {"tool_name": "tool.read", "agent_id": "ag-1", "action_type": "data_read"} + + default_resp = client.post("/v1/score", json=req).json() + assert default_resp["thresholds"]["allow"] == 0.4 + assert default_resp["thresholds"]["deny"] == 0.7 + + strict_resp = client.post( + "/v1/score", json=req, headers={"X-Vaara-Tenant": "tenant-strict"} + ).json() + assert strict_resp["thresholds"]["allow"] == 0.05 + assert strict_resp["thresholds"]["deny"] == 0.10 + + unknown_resp = client.post( + "/v1/score", json=req, headers={"X-Vaara-Tenant": "ghost"} + ).json() + assert unknown_resp["thresholds"]["allow"] == 0.4 + assert unknown_resp["thresholds"]["deny"] == 0.7 + + +def test_non_policy_lookup_return_falls_back(): + def _bad_typed_lookup(tid: str): + return {"not": "a Policy"} + + scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7) + scorer.set_policy_lookup(_bad_typed_lookup) + + result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5)) + assert result["threshold_allow"] == 0.4 + assert result["threshold_deny"] == 0.7 diff --git a/tests/test_v040_policy_registry.py b/tests/test_v040_policy_registry.py new file mode 100644 index 0000000..e32a3e7 --- /dev/null +++ b/tests/test_v040_policy_registry.py @@ -0,0 +1,121 @@ +"""v0.40 PolicyRegistry per-tenant policy plane + /v1/policy/reload routing.""" + +from __future__ import annotations + +import pytest + +try: + from fastapi.testclient 
import TestClient + + from vaara.server import create_app +except ImportError: + pytest.skip( + "server extra not installed (pip install 'vaara[server]')", + allow_module_level=True, + ) + +from vaara.policy.controller import PolicyController +from vaara.policy.loader import from_dict +from vaara.policy.registry import PolicyRegistry +from vaara.policy.schema import PolicyError + + +def _minimal_policy() -> dict: + return { + "version": "0.1", + "thresholds_default": {"escalate": 0.5, "deny": 0.9}, + } + + +def test_policy_registry_reload_creates_new_tenant_slot(): + registry = PolicyRegistry() + result = registry.reload("tenant-a", _minimal_policy()) + assert result.version == 1 + assert "tenant-a" in registry + + +def test_policy_registry_get_falls_back_to_default_slot(): + registry = PolicyRegistry() + registry.reload("", _minimal_policy()) + assert registry.get("nonexistent") is not None + assert registry.get_exact("nonexistent") is None + + +def test_policy_registry_load_directory(tmp_path): + (tmp_path / "default.json").write_text( + '{"version": "0.1", "thresholds_default": {"escalate": 0.4, "deny": 0.8}}' + ) + (tmp_path / "tenant-a.json").write_text( + '{"version": "0.1", "thresholds_default": {"escalate": 0.3, "deny": 0.7}}' + ) + registry = PolicyRegistry() + tenants = registry.load_directory(tmp_path) + assert sorted(tenants) == ["", "tenant-a"] + assert "" in registry + assert "tenant-a" in registry + + +def test_policy_registry_load_directory_rejects_empty(tmp_path): + registry = PolicyRegistry() + with pytest.raises(PolicyError): + registry.load_directory(tmp_path) + + +def test_policy_registry_accepts_policy_instance_on_reload(): + """Bulk-load path passes Policy directly to avoid re-parsing.""" + registry = PolicyRegistry() + registry.reload("", _minimal_policy()) + policy_obj = from_dict(_minimal_policy()) + result = registry.reload("", policy_obj) + assert result.version == 2 + + +def test_policy_reload_per_tenant_via_body(): + registry = 
PolicyRegistry() + registry.reload("", _minimal_policy()) + app = create_app(policy_registry=registry) + client = TestClient(app) + resp = client.post( + "/v1/policy/reload", + json={ + "body": { + "version": "0.1", + "thresholds_default": {"escalate": 0.2, "deny": 0.6}, + }, + "tenant_id": "tenant-x", + }, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["tenant_id"] == "tenant-x" + assert "tenant-x" in registry + + +def test_policy_reload_per_tenant_via_header(): + registry = PolicyRegistry() + registry.reload("", _minimal_policy()) + app = create_app(policy_registry=registry) + client = TestClient(app) + resp = client.post( + "/v1/policy/reload", + json={"body": _minimal_policy()}, + headers={"X-Vaara-Tenant": "tenant-y"}, + ) + assert resp.status_code == 200 + assert resp.json()["tenant_id"] == "tenant-y" + + +def test_policy_reload_back_compat_single_controller(): + controller = PolicyController(from_dict(_minimal_policy())) + app = create_app(policy_controller=controller) + client = TestClient(app) + resp = client.post("/v1/policy/reload", json={"body": _minimal_policy()}) + assert resp.status_code == 200 + assert resp.json()["tenant_id"] == "" + + +def test_policy_reload_unconfigured_returns_409(): + app = create_app() + client = TestClient(app) + resp = client.post("/v1/policy/reload", json={"body": _minimal_policy()}) + assert resp.status_code == 409 + assert resp.json()["error"]["code"] == "policy_not_configured" diff --git a/tests/test_v040_tenant.py b/tests/test_v040_tenant.py new file mode 100644 index 0000000..1685ba1 --- /dev/null +++ b/tests/test_v040_tenant.py @@ -0,0 +1,165 @@ +"""v0.40 tenant_id end-to-end propagation tests.""" + +from __future__ import annotations + +import pytest + +try: + from fastapi.testclient import TestClient + + from vaara.server import create_app +except ImportError: + pytest.skip( + "server extra not installed (pip install 'vaara[server]')", + allow_module_level=True, + ) + +from vaara.audit.trail 
import AuditRecord, AuditTrail, EventType +from vaara.taxonomy.actions import ( + ActionCategory, + ActionRequest, + ActionType, + BlastRadius, + Reversibility, +) + + +def _action_type(name: str = "t") -> ActionType: + return ActionType( + name=name, + category=ActionCategory.DATA, + reversibility=Reversibility.FULLY, + blast_radius=BlastRadius.LOCAL, + ) + + +def _minimal_policy() -> dict: + return { + "version": 1, + "thresholds_default": {"escalate": 0.5, "deny": 0.9}, + } + + +def test_score_body_tenant_id_lands_in_action_info(): + app = create_app() + client = TestClient(app) + resp = client.post( + "/v1/score", + json={"tool_name": "search", "agent_id": "a", "tenant_id": "tenant-a"}, + ) + assert resp.status_code == 200 + info = app.state.vaara.lookup_action(resp.json()["action_id"]) + assert info is not None + assert info.tenant_id == "tenant-a" + + +def test_score_header_tenant_id_used_when_body_empty(): + app = create_app() + client = TestClient(app) + resp = client.post( + "/v1/score", + json={"tool_name": "search", "agent_id": "a"}, + headers={"X-Vaara-Tenant": "tenant-b"}, + ) + assert resp.status_code == 200 + info = app.state.vaara.lookup_action(resp.json()["action_id"]) + assert info.tenant_id == "tenant-b" + + +def test_score_body_tenant_wins_over_header(): + app = create_app() + client = TestClient(app) + resp = client.post( + "/v1/score", + json={"tool_name": "search", "agent_id": "a", "tenant_id": "body-tenant"}, + headers={"X-Vaara-Tenant": "header-tenant"}, + ) + info = app.state.vaara.lookup_action(resp.json()["action_id"]) + assert info.tenant_id == "body-tenant" + + +def test_audit_event_writes_with_tenant_from_action_lookup(): + app = create_app() + client = TestClient(app) + score = client.post( + "/v1/score", + json={"tool_name": "search", "agent_id": "a", "tenant_id": "tenant-c"}, + ) + action_id = score.json()["action_id"] + resp = client.post( + "/v1/audit/events", + json={ + "event_type": "action_executed", + "action_id": action_id, 
+ "agent_id": "a", + "tool_name": "search", + "payload": {"result": "ok"}, + }, + ) + assert resp.status_code == 201 + records = [ + r for r in app.state.vaara.audit._records if r.action_id == action_id + ] + assert any(r.tenant_id == "tenant-c" for r in records) + + +def test_audit_trail_tenant_propagates_to_followup_records(): + trail = AuditTrail() + action_type = _action_type() + req = ActionRequest( + agent_id="a", tool_name="t", action_type=action_type, + parameters={}, tenant_id="tenant-d", + ) + action_id = trail.record_action_requested(req) + trail.record_decision( + action_id=action_id, agent_id="a", tool_name="t", + decision="allow", reason="ok", risk_score=0.1, + ) + trail.record_execution( + action_id=action_id, agent_id="a", tool_name="t", result={"ok": True}, + ) + records = trail.get_action_trail(action_id) + assert len(records) == 3 + assert all(r.tenant_id == "tenant-d" for r in records) + + +def test_audit_trail_tenant_map_evicts_under_pressure(): + trail = AuditTrail() + trail._MAX_ACTION_TENANT_MAP = 100 + action_type = _action_type() + last_ids: list[str] = [] + for i in range(150): + req = ActionRequest( + agent_id="a", tool_name="t", action_type=action_type, + tenant_id=f"t{i}", + ) + last_ids.append(trail.record_action_requested(req)) + assert len(trail._tenant_for_action) <= 100 + assert trail._tenant_for_action.get(last_ids[-1]) == "t149" + assert last_ids[0] not in trail._tenant_for_action + + +def test_audit_trail_caps_tenant_id_length(): + trail = AuditTrail() + action_type = _action_type() + huge = "x" * 10_000 + req = ActionRequest( + agent_id="a", tool_name="t", action_type=action_type, + tenant_id=huge, + ) + action_id = trail.record_action_requested(req) + record = trail.get_action_trail(action_id)[0] + assert len(record.tenant_id) <= trail._MAX_TENANT_ID_LEN + + +def test_audit_record_hash_excludes_tenant_id(): + """tenant_id is NOT part of compute_hash so pre-v0.40 chains re-verify.""" + rec_no = AuditRecord( + record_id="r1", 
action_id="a1", event_type=EventType.ACTION_REQUESTED, + timestamp=1.0, agent_id="a", tool_name="t", + ) + rec_with = AuditRecord( + record_id="r1", action_id="a1", event_type=EventType.ACTION_REQUESTED, + timestamp=1.0, agent_id="a", tool_name="t", tenant_id="tenant-x", + ) + assert rec_no.compute_hash() == rec_with.compute_hash()
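Review note on `test_audit_trail_tenant_map_evicts_under_pressure` above: the CHANGELOG describes the `action_id -> tenant_id` map as soft-capped with 12.5% eviction on pressure. A minimal sketch of one way to get that behaviour follows. `SoftCappedTenantMap` and its method names are hypothetical, and oldest-first eviction is an assumption that happens to match what the test pins (first id gone, last id retained); the real `AuditTrail` may do it differently.

```python
from typing import Dict, Optional


class SoftCappedTenantMap:
    """Hypothetical action_id -> tenant_id map with batch eviction."""

    def __init__(
        self, max_entries: int = 50_000, evict_fraction: float = 0.125
    ) -> None:
        self._max_entries = max_entries
        # Evict a batch at a time so the cost is amortised across inserts.
        self._evict_count = max(1, int(max_entries * evict_fraction))
        self._data: Dict[str, str] = {}

    def remember(self, action_id: str, tenant_id: str) -> None:
        if action_id not in self._data and len(self._data) >= self._max_entries:
            # dicts preserve insertion order, so the first keys are the oldest.
            for stale in list(self._data)[: self._evict_count]:
                del self._data[stale]
        self._data[action_id] = tenant_id

    def get(self, action_id: str) -> Optional[str]:
        return self._data.get(action_id)

    def __len__(self) -> int:
        return len(self._data)


# Same shape as the test: cap of 100, 150 inserts.
tenant_map = SoftCappedTenantMap(max_entries=100)
for i in range(150):
    tenant_map.remember(f"action-{i}", f"t{i}")
```

Evicting in batches rather than one entry per insert keeps the hot path cheap; the trade-off is that a follow-up record for a very old action may lose its inherited tenant scope.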