diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8c3ea5a..858d0aa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -30,6 +30,9 @@ jobs:
- name: Install Vaara (editable, no deps)
run: pip install -e . --no-deps
+ - name: Install server extra (fastapi, uvicorn, httpx) for HTTP transport tests
+ run: pip install 'fastapi>=0.110' 'uvicorn>=0.27' 'httpx>=0.27'
+
- name: Lint (ruff)
run: ruff check .
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd148a3..f13d179 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,97 @@ and this project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.ht
## [Unreleased]
+## [0.40.0] - 2026-05-28
+
+**Theme: deployment shape. One Vaara process now serves a fleet of
+upstream MCP servers, with multi-tenant policy, audit, and attestation
+on the same substrate.**
+
+The v0.39 sidecar shape ran one Vaara process per upstream. v0.40
+turns that into a single process that speaks Streamable HTTP, holds
+N upstream MCP-server connections, picks the upstream per request
+from a header, scopes every score, audit record, and OVERT envelope
+to a tenant, and reloads per-tenant policy in place.
+
+### Added
+- `vaara-mcp-proxy --transport http --http-host H --http-port P`:
+ Streamable HTTP transport at `POST /mcp`, backed by FastAPI /
+ uvicorn (the `vaara[server]` extra already shipped in v0.39 for
+ `vaara serve`). The endpoint reads `X-Vaara-Tenant` and
+ `X-Vaara-Upstream` per request, pushes them into ContextVars, and
+ dispatches into the existing `_handle_request` path so the policy,
+ perimeter, OVERT, and progress-notification handling all light up
+ unchanged. Notifications (no JSON-RPC `id`) return 202 Accepted.
+ Bodies above 1 MiB return 413.
+- `vaara-mcp-proxy --upstream NAME=CMD` (repeatable) for fan-out.
+ One Vaara process holds N `UpstreamMCPClient` instances in a name
+ -> client map. Bare `--upstream CMD` keeps the v0.39 single-
+ upstream contract; it lands in the "default" slot. Commands that
+ themselves contain `=` (e.g. `python -m foo --bar=baz`) stay
+ intact because the name-side regex only matches short alphanumeric
+ slugs. When more than one upstream is configured, a request with
+ no `X-Vaara-Upstream` header returns 400 with the list of valid
+ slots; silent fallback to whichever slot won the sort would be a
+ failure mode that surfaces only in production. Single-upstream
+ deployments keep the silent-default contract.
+- `tenant_id` is first-class through the request, decision, audit,
+ and attestation layers:
+ - `ScoreRequest`, `AuditEventRequest`, and `PolicyReloadRequest`
+ accept a `tenant_id` body field, with `X-Vaara-Tenant` as the
+ HTTP-header alternative. Body wins over header.
+ - `AuditRecord` gains a `tenant_id` field, excluded from
+ `compute_hash()` so pre-v0.40 chains still re-verify on load.
+ - `AuditTrail` keeps an `action_id -> tenant_id` map seeded by
+ `record_action_requested`, so every follow-up record
+ (`risk_scored`, `decision`, `execution`, `escalation`,
+ `outcome`, `policy_override`) inherits the same scope without
+ every caller threading `tenant_id` through every signature.
+ The map is soft-capped (50k entries, 12.5% eviction on
+ pressure) so long-running deployments cannot leak memory.
+ - `SQLiteAuditBackend.write_record` prefers the per-record
+ `tenant_id` when set, with the instance-scoped `tenant_id`
+ (legacy CLI tooling path) as fallback. A single backend
+ instance can now serve a multi-tenant runtime.
+ - OVERT envelopes carry `tenant_id` as a `non_content_metadata`
+ claim when present.
+- `vaara.policy.registry.PolicyRegistry`: one `PolicyController` per
+ tenant, with the empty string slot reserved as the default
+ fallback for unmatched lookups.
+- `vaara serve --policy-dir DIR`: loads one YAML/JSON policy per
+ file. Filename stem = `tenant_id`; `default.yaml` lands in the
+ fallback slot. Mutually exclusive with `--policy`.
+- `POST /v1/policy/reload` accepts a `tenant_id` body field (or
+ `X-Vaara-Tenant` header) and routes to the right registry slot;
+ creates the slot on first reload.
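
The `NAME=CMD` split described in the `--upstream` entry above can be sketched as follows. This is a minimal illustration, not the shipped parser: the pattern `_NAME_RE`, its 32-character bound, and the helper `parse_upstream_spec` are assumptions made for the example.

```python
import re
import shlex

# Hypothetical slug pattern (assumption): a short alphanumeric name,
# optionally containing "-" or "_", followed directly by "=".
_NAME_RE = re.compile(r"^([A-Za-z0-9][A-Za-z0-9_-]{0,31})=(.+)$")


def parse_upstream_spec(spec: str) -> tuple[str, list[str]]:
    """Split 'NAME=CMD' into (name, argv); a bare 'CMD' maps to 'default'."""
    match = _NAME_RE.match(spec)
    if match:
        return match.group(1), shlex.split(match.group(2))
    # No slug prefix: keep the whole string as the command, so an "="
    # that appears later in the arguments is left untouched.
    return "default", shlex.split(spec)


print(parse_upstream_spec("github=npx -y @github/mcp-server"))
print(parse_upstream_spec("python -m foo --bar=baz"))
```

Because the name pattern cannot span whitespace, `python -m foo --bar=baz` never matches and falls through to the `default` slot. A bare command whose first token is itself `slug=value` would still be read as a named upstream under this sketch.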
+
+### Changed
+- `Pipeline.intercept` takes a `tenant_id` keyword that flows onto
+ the `ActionRequest` and into the audit trail. Default `""` keeps
+ the v0.39 single-tenant contract.
+- `AdaptiveScorer.evaluate` dispatches allow / deny thresholds per
+ tenant at call time. A new `policy_lookup` constructor arg (and
+ `set_policy_lookup` setter for late binding from `ServerState`)
+ takes a `Callable[[str], Optional[Policy]]`; on every evaluate, the
+ scorer asks the registry for the calling tenant's policy and uses
+ its thresholds. An unknown or unmapped tenant falls back to the
+ scorer-bound defaults that the default-slot listener keeps fresh on
+ reload. The backend decision dict surfaces the applied
+ `threshold_allow` and `threshold_deny` so operators can confirm
+ which tenant's policy ran. MWU expert state, the conformal
+ calibrator, agent profiles, and sequence patterns stay shared
+ across tenants; only threshold application is per-tenant in v0.40.
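
The per-tenant threshold dispatch in the `AdaptiveScorer.evaluate` entry above can be sketched roughly like this. The `Policy` dataclass, the `decide` helper, the comparison rules, and the `escalate` verdict are assumptions for illustration; only the `Callable[[str], Optional[Policy]]` lookup shape and the surfaced `threshold_allow` / `threshold_deny` keys come from the entry itself.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Policy:
    """Hypothetical stand-in that carries only the two thresholds."""
    threshold_allow: float
    threshold_deny: float


def decide(
    risk_score: float,
    tenant_id: str,
    policy_lookup: Callable[[str], Optional[Policy]],
    default: Policy,
) -> dict:
    """Apply the calling tenant's thresholds, or the defaults if unmapped."""
    found = policy_lookup(tenant_id)
    policy = found if found is not None else default
    if risk_score <= policy.threshold_allow:
        verdict = "allow"
    elif risk_score >= policy.threshold_deny:
        verdict = "deny"
    else:
        verdict = "escalate"  # assumed middle band
    # Surface the thresholds that were applied so an operator can
    # confirm which tenant's policy ran.
    return {
        "decision": verdict,
        "threshold_allow": policy.threshold_allow,
        "threshold_deny": policy.threshold_deny,
    }


policies = {"acme": Policy(threshold_allow=0.2, threshold_deny=0.6)}
fallback = Policy(threshold_allow=0.3, threshold_deny=0.8)

print(decide(0.7, "acme", policies.get, fallback))
print(decide(0.7, "unknown", policies.get, fallback))
```

The same score lands on different verdicts depending on which tenant's thresholds apply, which is why echoing the applied thresholds in the decision dict is useful for debugging.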
+
+### Scope notes
+- The HTTP transport on the proxy is POST-only. GET-SSE for
+ server-initiated notifications (sampling, server-pushed progress)
+ is deferred to v0.41. The audit + OVERT emission path for
+ upstream-originated notifications still works unchanged on stdio.
+- Classifier bundle and conformal-calibrator reloads still require a
+ restart in v0.40. Per-tenant policy reload is hot; that is the
+ configuration plane that needed to be live across tenants.
+ Classifier hot-reload waits on a shared-singleton lifecycle and an
+ open per-tenant scoping question (v0.41 candidate).
+
## [0.39.2] - 2026-05-27
**Theme: SEP-2787 envelope v2 shape, full wire round-trip, versioned
@@ -1713,7 +1804,8 @@ and backward-compatible. Together they reposition Vaara from a Python
library to a runtime kernel that control planes, audit consumers, and
orchestration frameworks reference. The HTTP contract at
`docs/openapi.yaml` is versioned `/v1/` independently of the project
-version, following the OPA pattern.
+version, so the wire surface can stabilise without locking the
+library cadence.
### Added
- **HTTP API reference server (`vaara[server]` extra).** Exposes the
@@ -1789,10 +1881,9 @@ it governs.
action class declared, matched sequences known).
- **`vaara.policy.test_cases_io` module.** `load_test_cases(path)`
reads a YAML or JSON cases document and returns a list of
- `PolicyTestCase`. Document shape mirrors typical OPA / Conftest
- test files: a top-level `cases:` list with `action_class`,
- `risk_score`, optional `matched_sequences`, and an `expect:` block
- carrying `verdict` and optional `route`.
+ `PolicyTestCase`. Document shape: a top-level `cases:` list with
+ `action_class`, `risk_score`, optional `matched_sequences`, and an
+ `expect:` block carrying `verdict` and optional `route`.
- **`vaara policy validate POLICY_PATH [--json]`** and **`vaara
policy test POLICY_PATH --cases CASES_PATH [--json]`** CLI
subcommands. Both honour standard CI exit codes: validate returns
diff --git a/README.md b/README.md
index 6961b74..1ec79c6 100644
--- a/README.md
+++ b/README.md
@@ -168,13 +168,32 @@ if (r.decision === "deny") throw new Error("blocked");
`vaara.integrations.mcp_proxy.VaaraMCPProxy` sits between an MCP client (Claude Code, Cursor, any MCP-capable host) and an upstream MCP server. Every `tools/call` from the client routes through Vaara's interception pipeline before reaching the upstream. Allowed calls forward transparently and report the upstream outcome back to the scorer. Blocked calls return an MCP `isError: true` response with the block reason. The initialization handshake and `notifications/*` forward unchanged. `tools/list`, `resources/list`, `resources/read`, `prompts/list`, and `prompts/get` route through the operator perimeter before reaching the client or upstream.
```bash
-python -m vaara.integrations.mcp_proxy \
+vaara-mcp-proxy \
--upstream npx --upstream-arg -y --upstream-arg @sap/mdk-mcp-server \
--db ./mcp_audit.db
```
Point your MCP client at the proxy instead of the upstream. The audit chain captures every tool call without changing client or upstream behavior. Distinct from `mcp_server`, which exposes Vaara itself as an MCP server for agents that consult Vaara as a tool.
+
+Fleet shape (v0.40): one proxy, many upstreams, multi-tenant policy
+
+`vaara-mcp-proxy` also runs over Streamable HTTP with fan-out, so one process can serve a fleet of upstream MCP servers:
+
+```bash
+vaara-mcp-proxy \
+ --transport http \
+ --http-host 127.0.0.1 \
+ --http-port 8765 \
+ --upstream 'github=npx -y @github/mcp-server' \
+ --upstream 'sap=npx -y @sap/mdk-mcp-server'
+```
+
+Each `POST /mcp` reads two headers. `X-Vaara-Upstream` picks the upstream slot. `X-Vaara-Tenant` scopes the policy, audit chain, and OVERT envelope for that call. Single-upstream deployments keep the v0.39 silent-default contract. Multi-upstream deployments require `X-Vaara-Upstream` per call and return 400 with the available slot list when the header is missing.
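
A client call against the fleet endpoint might be built as in the sketch below. It uses only the standard library; the helper `build_mcp_request` and the tenant name `acme` are hypothetical, and the host and port simply mirror the example command above.

```python
import json
import urllib.request
from typing import Optional


def build_mcp_request(
    base_url: str,
    message: dict,
    upstream: Optional[str] = None,
    tenant: Optional[str] = None,
) -> urllib.request.Request:
    """Build a POST /mcp request carrying the two routing headers."""
    headers = {"Content-Type": "application/json"}
    if upstream:
        headers["X-Vaara-Upstream"] = upstream  # picks the upstream slot
    if tenant:
        headers["X-Vaara-Tenant"] = tenant  # scopes policy and audit
    return urllib.request.Request(
        url=f"{base_url}/mcp",
        data=json.dumps(message).encode("utf-8"),
        headers=headers,
        method="POST",
    )


req = build_mcp_request(
    "http://127.0.0.1:8765",
    {"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
    upstream="github",
    tenant="acme",
)

# Sending the request needs a running proxy, so it is left commented out:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```

Omitting `upstream` is only safe against a single-upstream deployment; with a fleet configured, the proxy answers 400 and lists the available slots.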
+
+The reference HTTP API server (`vaara serve --policy-dir DIR`) loads one YAML or JSON policy per file in the directory (filename stem becomes the `tenant_id`, `default.yaml` lands in the fallback slot) and hot-reloads per tenant via `POST /v1/policy/reload` with a `tenant_id` body field or `X-Vaara-Tenant` header. The scorer dispatches allow and deny thresholds per call against the calling tenant's policy at `evaluate()` time.
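
The filename-to-tenant mapping can be pictured with the sketch below. It is an illustration under stated assumptions, not the `PolicyRegistry` loader: the helper `tenant_slots` and its duplicate-stem check are hypothetical, while the accepted suffixes and the `default` fallback rule follow the description above.

```python
from pathlib import Path

_POLICY_SUFFIXES = {".yaml", ".yml", ".json"}


def tenant_slots(policy_dir: Path) -> dict[str, Path]:
    """Map tenant_id -> policy file; the 'default' stem maps to ''."""
    slots: dict[str, Path] = {}
    for path in sorted(policy_dir.iterdir()):
        if not path.is_file() or path.suffix.lower() not in _POLICY_SUFFIXES:
            continue
        # The empty string is the fallback slot for unmatched tenants.
        tenant_id = "" if path.stem == "default" else path.stem
        if tenant_id in slots:
            # Assumption: two files with the same stem are a config error.
            raise ValueError(
                f"duplicate policy for tenant {tenant_id!r}: {path.name}"
            )
        slots[tenant_id] = path
    return slots
```

Given a directory holding `acme.yaml`, `globex.json`, and `default.yaml`, this returns slots for `acme`, `globex`, and the empty-string fallback.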
+
+
Operator perimeter: tool, resource, prompt filtering
@@ -194,7 +213,7 @@ vaara keygen --dev --out signing.pem
head -c 32 /dev/urandom > op.key
# 3. Run the proxy with OVERT emission turned on.
-python -m vaara.integrations.mcp_proxy \
+vaara-mcp-proxy \
--upstream npx --upstream-arg -y --upstream-arg @sap/mdk-mcp-server \
--overt-signing-key signing.pem \
--overt-operator-key op.key \
diff --git a/clients/ts/package.json b/clients/ts/package.json
index 4c02587..11d047c 100644
--- a/clients/ts/package.json
+++ b/clients/ts/package.json
@@ -1,6 +1,6 @@
{
"name": "@vaara/client",
- "version": "0.39.2",
+ "version": "0.40.0",
"description": "TypeScript client for the Vaara HTTP API. Conformal risk scoring, hash-chained audit, policy reload, named detectors.",
"main": "dist/index.js",
"types": "dist/index.d.ts",
diff --git a/pyproject.toml b/pyproject.toml
index bb399c7..522ba27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vaara"
-version = "0.39.2"
+version = "0.40.0"
description = "Adaptive AI Agent Execution Layer for risk scoring, audit trails, and regulatory compliance"
requires-python = ">=3.10"
license = "Apache-2.0"
@@ -58,6 +58,7 @@ rebuff = ["rebuff>=0.1"]
[project.scripts]
vaara = "vaara.cli:main"
vaara-audit = "vaara.audit_cli:main"
+vaara-mcp-proxy = "vaara.integrations.mcp_proxy:main"
[tool.setuptools.packages.find]
where = ["src"]
diff --git a/src/vaara/__init__.py b/src/vaara/__init__.py
index da5ae1b..f7071ee 100644
--- a/src/vaara/__init__.py
+++ b/src/vaara/__init__.py
@@ -6,7 +6,7 @@
oversight.
"""
-__version__ = "0.39.2"
+__version__ = "0.40.0"
from vaara.pipeline import InterceptionPipeline, InterceptionResult
diff --git a/src/vaara/audit/sqlite_backend.py b/src/vaara/audit/sqlite_backend.py
index 13ec5fe..7f9df84 100644
--- a/src/vaara/audit/sqlite_backend.py
+++ b/src/vaara/audit/sqlite_backend.py
@@ -357,7 +357,11 @@ def write_record(self, record: AuditRecord) -> None:
_strict_json_dumps(record.regulatory_articles),
record.previous_hash,
record.record_hash,
- self._tenant_id,
+ # Per-record tenant_id wins so a single backend instance
+ # can serve a multi-tenant runtime (v0.40+). Empty record
+ # tenant_id falls back to instance scope for the legacy
+ # single-tenant init path.
+ record.tenant_id or self._tenant_id,
record.system_operation,
record.data_usage,
record.decision_making,
@@ -674,6 +678,7 @@ def _row_to_record(self, row: tuple) -> AuditRecord:
agent_id = self._redaction_cache[agent_id]
# Defensive indexing: rows from older queries may not include
# the v3 columns. Use a guard so loading old DBs still works.
+ tenant_id = row[11] if len(row) > 11 else ""
sys_op = row[12] if len(row) > 12 else None
data_use = row[13] if len(row) > 13 else None
dec_mk = row[14] if len(row) > 14 else None
@@ -689,6 +694,7 @@ def _row_to_record(self, row: tuple) -> AuditRecord:
regulatory_articles=json.loads(row[7]),
previous_hash=row[8],
record_hash=row[9],
+ tenant_id=tenant_id or "",
system_operation=sys_op,
data_usage=data_use,
decision_making=dec_mk,
diff --git a/src/vaara/audit/trail.py b/src/vaara/audit/trail.py
index 6d08712..4f7cb13 100644
--- a/src/vaara/audit/trail.py
+++ b/src/vaara/audit/trail.py
@@ -258,6 +258,9 @@ class AuditRecord:
data_usage: Optional[str] = None
decision_making: Optional[str] = None
limitations: Optional[str] = None
+ # v0.40: multi-tenant scoping. Empty string = single-tenant deployment.
+ # Excluded from compute_hash() to preserve pre-v0.40 chain re-verification.
+ tenant_id: str = ""
def __post_init__(self) -> None:
# Loaded-from-DB records carry a non-empty record_hash. Skip
@@ -492,6 +495,12 @@ def __init__(
self._by_action: dict[str, list[AuditRecord]] = defaultdict(list)
self._last_hash = ""
self._on_record = on_record
+ # v0.40 multi-tenant: action_id -> tenant_id, seeded by
+ # record_action_requested. Subsequent record_* calls (decision,
+ # execution, escalation) look up the action_id so every record in
+ # the lifecycle carries the same tenant scope without forcing
+ # every caller to thread tenant_id through every method signature.
+ self._tenant_for_action: dict[str, str] = {}
# Counts on_record callback failures so callers can detect
# persistence divergence at runtime (e.g., DB gone, disk full).
# Without this, a silent logger.error is the only signal and the
@@ -554,9 +563,33 @@ def verify_chain(self) -> Optional[str]:
# ── Recording events ──────────────────────────────────────────
+ # Defense-in-depth cap for direct-trail callers that bypass the pipeline's
+ # length cap on tenant_id. The HTTP boundary already caps at 256 via the
+ # Pydantic schema, but the AuditTrail public API is reachable from
+ # embedders that construct ActionRequest directly. A 50MB tenant_id would
+ # otherwise balloon every record on the hash chain and the in-memory
+ # action -> tenant map.
+ _MAX_TENANT_ID_LEN = 256
+ # Soft cap on the action -> tenant map. Long-running multi-tenant
+ # deployments would otherwise leak memory at one entry per action,
+ # because OUTCOME_RECORDED arrives well after ACTION_REQUESTED and the
+ # map cannot be cleared at decision time. When the cap is reached the
+ # oldest 1/8 of the map is evicted; subsequent lookups for evicted
+ # actions fall back to "" tenant, which is the legacy single-tenant
+ # contract — correct fail-soft behaviour.
+ _MAX_ACTION_TENANT_MAP = 50_000
+
def record_action_requested(self, request: ActionRequest) -> str:
"""Record that an agent requested an action. Returns the action_id."""
action_id = str(uuid.uuid4())
+ tenant_id = getattr(request, "tenant_id", "") or ""
+ if tenant_id:
+ tenant_id = self._cap_record_str(tenant_id, self._MAX_TENANT_ID_LEN)
+ if len(self._tenant_for_action) >= self._MAX_ACTION_TENANT_MAP:
+ evict = max(1, self._MAX_ACTION_TENANT_MAP // 8)
+ for stale in list(self._tenant_for_action)[:evict]:
+ self._tenant_for_action.pop(stale, None)
+ self._tenant_for_action[action_id] = tenant_id
articles = self._get_regulatory_articles(
EventType.ACTION_REQUESTED,
@@ -600,10 +633,20 @@ def record_action_requested(self, request: ActionRequest) -> str:
"sequence_position": request.sequence_position,
},
regulatory_articles=articles,
+ tenant_id=tenant_id,
))
return action_id
+ def _tenant_for(self, action_id: str) -> str:
+ """Resolve the tenant scope for an existing action lifecycle.
+
+ Returns the tenant_id captured at record_action_requested time so
+ every follow-up record (risk_scored, decision, execution,
+ escalation, outcome) carries the same scope automatically.
+ """
+ return self._tenant_for_action.get(action_id, "")
+
def record_risk_scored(
self,
action_id: str,
@@ -631,6 +674,7 @@ def record_risk_scored(
tool_name=self._cap_record_str(tool_name, self._MAX_TOOL_NAME_LEN),
data=safe_assessment,
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
def record_decision(
@@ -666,6 +710,7 @@ def record_decision(
"risk_score": risk_score,
},
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
def record_execution(
@@ -702,6 +747,7 @@ def record_execution(
data={"result_summary": self._cap_record_dict_bytes(
safe_result, self._MAX_EXECUTION_RESULT_JSON_BYTES
)},
+ tenant_id=self._tenant_for(action_id),
))
def record_escalation(
@@ -731,6 +777,7 @@ def record_escalation(
"risk_score": risk_score,
},
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
def record_escalation_resolved(
@@ -760,6 +807,7 @@ def record_escalation_resolved(
"justification": self._cap_record_str(justification, self._MAX_JUSTIFICATION_LEN),
},
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
def record_outcome(
@@ -794,6 +842,7 @@ def record_outcome(
),
},
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
# Length caps for caller-controlled free-text fields on this direct
@@ -923,6 +972,7 @@ def record_policy_override(
"new_decision": new_decision,
},
regulatory_articles=articles,
+ tenant_id=self._tenant_for(action_id),
))
# ── Querying ──────────────────────────────────────────────────
diff --git a/src/vaara/cli.py b/src/vaara/cli.py
index 19f0ec2..9241e6e 100644
--- a/src/vaara/cli.py
+++ b/src/vaara/cli.py
@@ -1015,9 +1015,33 @@ def _cmd_serve(args: argparse.Namespace) -> int:
from vaara.server import create_app
- controller = None
policy_path = getattr(args, "policy", None)
- if policy_path:
+ policy_dir = getattr(args, "policy_dir", None)
+ if policy_path and policy_dir:
+ print(
+ "vaara serve: pass either --policy or --policy-dir, not both.",
+ file=sys.stderr,
+ )
+ return 2
+
+ controller = None
+ registry = None
+ if policy_dir:
+ from vaara.policy.registry import PolicyRegistry
+ from vaara.policy.schema import PolicyError
+
+ registry = PolicyRegistry()
+ try:
+ tenants = registry.load_directory(Path(policy_dir).expanduser())
+ except PolicyError as exc:
+ print(f"vaara serve: --policy-dir failed to load: {exc}", file=sys.stderr)
+ return 2
+ print(
+ f"vaara serve: loaded {len(tenants)} tenant policies "
+ f"(tenants={tenants!r})",
+ file=sys.stderr,
+ )
+ elif policy_path:
from vaara.policy.controller import PolicyController
from vaara.policy.validate import validate_source
@@ -1032,7 +1056,7 @@ def _cmd_serve(args: argparse.Namespace) -> int:
return 2
controller = PolicyController(policy_obj)
- app = create_app(policy_controller=controller)
+ app = create_app(policy_controller=controller, policy_registry=registry)
uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level)
return 0
@@ -1394,6 +1418,15 @@ def build_parser() -> argparse.ArgumentParser:
"sequence patterns are applied to the scorer at startup."
),
)
+ pserve.add_argument(
+ "--policy-dir",
+ default=None,
+ help=(
+ "Directory of per-tenant policy files. Each *.yaml/*.yml/*.json "
+ "file is loaded; filename stem becomes the tenant_id "
+ "(default.yaml -> fallback). Mutually exclusive with --policy."
+ ),
+ )
pserve.add_argument(
"--log-level",
default="info",
diff --git a/src/vaara/integrations/mcp_proxy.py b/src/vaara/integrations/mcp_proxy.py
index db88754..193441a 100644
--- a/src/vaara/integrations/mcp_proxy.py
+++ b/src/vaara/integrations/mcp_proxy.py
@@ -25,9 +25,11 @@
from __future__ import annotations
import argparse
+import contextvars
import json
import logging
import os
+import re
import sys
import threading
from pathlib import Path
@@ -48,9 +50,48 @@
from vaara.pipeline import InterceptionPipeline
from vaara.taxonomy.actions import ActionRequest
+# Optional dependency: only the streamable-HTTP transport needs FastAPI /
+# Starlette. Guard the import so the stdio path stays installable with
+# the base dependencies only.
+try: # pragma: no cover - import guard
+ from starlette.requests import Request as _StarletteRequest
+except ImportError: # pragma: no cover
+ _StarletteRequest = None # type: ignore[assignment]
+
logger = logging.getLogger(__name__)
+# v0.40 per-request scope. HTTP transport sets these per inbound
+# request so _handle_request and friends can look up the right upstream and
+# tag the audit/interception trail with the request's tenant_id without
+# threading the values through every helper signature.
+_REQUEST_UPSTREAM: contextvars.ContextVar[str] = contextvars.ContextVar(
+ "vaara_mcp_upstream", default="default",
+)
+_REQUEST_TENANT: contextvars.ContextVar[str] = contextvars.ContextVar(
+ "vaara_mcp_tenant", default="",
+)
+
+def _safe_log(value: Any, max_len: int = 200) -> str:
+ """Sanitise a user-supplied string for safe logging.
+
+ Strips CR/LF and other control characters so an attacker who controls
+ a tool / resource / prompt name can't inject fake log lines, and caps
+ length so a multi-megabyte name doesn't blow the log up.
+ """
+ if not isinstance(value, str):
+ value = str(value)
+ cleaned = "".join(c if c.isprintable() and c not in ("\r", "\n") else "?" for c in value)
+ return cleaned[:max_len]
+
+
+# Largest single MCP JSON-RPC message accepted on the /mcp HTTP endpoint.
+# Real tool calls and responses fit comfortably; the cap stops a malicious
+# client from exhausting memory at parse time. v0.41 can promote this to
+# a CLI flag if a real workload needs more headroom.
+_MCP_HTTP_MAX_BODY_BYTES = 1 * 1024 * 1024
+
+
class VaaraMCPProxy:
"""Transparent MCP proxy with Vaara interception on tool calls."""
@@ -58,7 +99,7 @@ class VaaraMCPProxy:
def __init__(
self,
- upstream_command: list[str],
+ upstream_command: Optional[list[str]] = None,
pipeline: Optional[InterceptionPipeline] = None,
db_path: Optional[Path] = None,
agent_id_default: str = "mcp-proxy-client",
@@ -69,6 +110,7 @@ def __init__(
prompt_allowlist: Optional[set[str]] = None,
prompt_denylist: Optional[set[str]] = None,
overt_emitter: Optional[OVERTReceiptEmitter] = None,
+ upstreams: Optional[dict[str, list[str]]] = None,
) -> None:
if pipeline is not None:
self._pipeline = pipeline
@@ -98,17 +140,86 @@ def __init__(
)
self._stdout_lock = threading.Lock()
self._overt = overt_emitter
- # progressToken -> (action_id, agent_id, tool_name). Populated when a
- # tools/call enters interception with params._meta.progressToken set,
- # consulted by the upstream-notification handler so progress events
+ # progressToken -> (action_id, agent_id, tool_name, tenant_id). Populated
+ # when a tools/call enters interception with params._meta.progressToken
+ # set, consulted by the upstream-notification handler so progress events
# arriving mid-call carry the originating action_id into the audit
- # record and into the OVERT envelope's non_content_metadata.
- self._inflight_progress: dict[Any, tuple[str, str, str]] = {}
+ # record and into the OVERT envelope's non_content_metadata. The tenant
+ # is captured at tools/call time because the upstream reader thread that
+ # delivers later notifications does not inherit the request ContextVars.
+ self._inflight_progress: dict[Any, tuple[str, str, str, str]] = {}
self._inflight_lock = threading.Lock()
- self._upstream = UpstreamMCPClient(
- command=upstream_command,
- on_notification=self._on_upstream_notification,
- )
+ # v0.40 fan-out: hold N upstream MCP servers in a name -> client map.
+ # The single-upstream legacy entry point (positional ``upstream_command``)
+ # lands under the "default" name. ``--upstream NAME=CMD`` via CLI or
+ # ``upstreams={"NAME": [cmd, ...]}`` populates the map directly. The
+ # HTTP transport reads X-Vaara-Upstream per inbound request to pick;
+ # stdio transport stays on "default".
+ if upstreams and upstream_command is not None:
+ raise ValueError(
+ "Pass either upstream_command (single-upstream legacy) or "
+ "upstreams (multi-upstream fan-out), not both.",
+ )
+ upstream_map: dict[str, list[str]] = {}
+ if upstreams:
+ upstream_map = {name: list(cmd) for name, cmd in upstreams.items()}
+ elif upstream_command is not None:
+ upstream_map = {"default": list(upstream_command)}
+ else:
+ raise ValueError(
+ "VaaraMCPProxy requires upstream_command or upstreams.",
+ )
+ default_alias_target: Optional[str] = None
+ if "default" not in upstream_map:
+ # Pick a stable fallback so requests without X-Vaara-Upstream
+ # still resolve. Lexicographic first keeps multi-tenant fleets
+ # deterministic across restarts. Alias the slot rather than
+ # cloning the command so we never spawn a duplicate subprocess.
+ default_alias_target = sorted(upstream_map)[0]
+ self._upstreams: dict[str, UpstreamMCPClient] = {
+ name: UpstreamMCPClient(
+ command=command,
+ on_notification=self._on_upstream_notification,
+ )
+ for name, command in upstream_map.items()
+ }
+ if default_alias_target is not None:
+ self._upstreams["default"] = self._upstreams[default_alias_target]
+ # ``self._upstream`` resolves to the per-request upstream via the
+ # ``_REQUEST_UPSTREAM`` ContextVar. stdio transport leaves the
+ # default ("default") in place, so existing single-upstream callers
+ # see exactly one client. HTTP transport sets the ctxvar per
+ # request to dispatch into the right fleet member.
+
+ @property
+ def _upstream(self) -> UpstreamMCPClient:
+ """Resolve the upstream MCP client for the current request scope.
+
+ HTTP transport sets ``_REQUEST_UPSTREAM`` per inbound request. stdio
+ transport never sets it; the default value "default" routes to the
+ legacy single-upstream slot. Unknown names raise ``ProxyError`` so
+ a client asking for a fleet member that does not exist gets a
+ loud failure instead of a silent reroute.
+ """
+ name = _REQUEST_UPSTREAM.get()
+ client = self._upstreams.get(name)
+ if client is None:
+ raise ProxyError(
+ f"No upstream named {name!r}; configured names: "
+ f"{sorted(self._upstreams)!r}",
+ )
+ return client
+
+ @_upstream.setter
+ def _upstream(self, client: UpstreamMCPClient) -> None:
+ """Replace the default-slot upstream client.
+
+ Test fixtures and embedders that previously assigned
+ ``proxy._upstream = MagicMock()`` keep working under the v0.40
+ fan-out shape — the assignment lands in the "default" slot and the
+ property reads it back.
+ """
+ self._upstreams["default"] = client
@staticmethod
def _is_filtered(name: object, allowlist: Optional[set[str]], denylist: set[str]) -> bool:
@@ -150,6 +261,163 @@ def run(self) -> None:
continue
self._write_to_client(self._handle_request(request))
+ def run_http(self, host: str, port: int, log_level: str = "info") -> None:
+ """Run the proxy on Streamable HTTP (MCP 2026 transport).
+
+ POST /mcp accepts one JSON-RPC message and returns one JSON
+ response. Notifications (no ``id``) return 202 Accepted and are
+ forwarded to the upstream without a reply. Multi-tenant /
+ fan-out scope is read per request from the
+ ``X-Vaara-Tenant`` and ``X-Vaara-Upstream`` headers and pushed
+ into the per-request ContextVars before dispatch.
+ """
+ try:
+ import uvicorn
+ from fastapi import FastAPI, Header, HTTPException, Response
+ from fastapi.responses import JSONResponse
+ except ImportError as exc:
+ raise RuntimeError(
+ "vaara-mcp-proxy --transport http requires the 'server' "
+ "extra. Install with: pip install 'vaara[server]'"
+ ) from exc
+ if _StarletteRequest is None:
+ raise RuntimeError(
+ "starlette is required for the streamable-HTTP transport. "
+ "Install with: pip install 'vaara[server]'"
+ )
+
+ proxy = self
+
+ app = FastAPI(
+ title="Vaara MCP Proxy",
+ version=_VAARA_VERSION,
+ description=(
+ "Streamable HTTP transport in front of one or more upstream "
+ "MCP servers, with Vaara runtime governance applied to every "
+ "tools/call."
+ ),
+ )
+
+ @app.get("/health")
+ async def health() -> dict:
+ return {
+ "status": "ok",
+ "proxy": proxy.PROXY_NAME,
+ "upstreams": sorted(proxy._upstreams.keys()),
+ }
+
+ @app.post("/mcp")
+ async def mcp_endpoint(
+ request: _StarletteRequest,
+ x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"),
+ x_vaara_upstream: Optional[str] = Header(default=None, alias="X-Vaara-Upstream"),
+ ) -> Response:
+ # 1 MiB cap on a single MCP JSON-RPC message. Real tool calls and
+ # responses fit comfortably; anything larger is either a misuse or
+ # a DoS attempt against the proxy's JSON parser. The cap is the
+ # same order as the MCP reference servers' limit. Hard-cap here
+ # before json.loads runs so a malicious payload cannot exhaust
+ # memory at parse time.
+ content_length = request.headers.get("content-length")
+ if content_length is not None:
+ try:
+ if int(content_length) > _MCP_HTTP_MAX_BODY_BYTES:
+ return JSONResponse(
+ status_code=413,
+ content={"error": {
+ "code": "payload_too_large",
+ "message": (
+ f"MCP message exceeds "
+ f"{_MCP_HTTP_MAX_BODY_BYTES} bytes"
+ ),
+ }},
+ )
+ except ValueError:
+ pass # bogus content-length, fall through to actual read
+ raw = await request.body()
+ if len(raw) > _MCP_HTTP_MAX_BODY_BYTES:
+ return JSONResponse(
+ status_code=413,
+ content={"error": {
+ "code": "payload_too_large",
+ "message": (
+ f"MCP message exceeds "
+ f"{_MCP_HTTP_MAX_BODY_BYTES} bytes"
+ ),
+ }},
+ )
+ try:
+ payload = json.loads(raw.decode("utf-8") or "{}")
+ except (UnicodeDecodeError, json.JSONDecodeError):
+ return JSONResponse(
+ status_code=400,
+ content=proxy._error_response(None, -32700, "Parse error"),
+ )
+
+ header_name = (x_vaara_upstream or "").strip()
+ # Real upstream slots are everything except the "default" alias.
+ # When the operator configured exactly one real slot (single-
+ # upstream deployment), silent fallback preserves the v0.39
+ # contract. When the operator configured a fleet, ambiguity is
+ # an error: missing X-Vaara-Upstream returns 400 with the list
+ # so the client knows which slots are available, instead of
+ # silently routing to whichever slot won the sort.
+ real_slots = sorted(n for n in proxy._upstreams if n != "default")
+ ambiguous_fanout = len(real_slots) > 1
+ if not header_name:
+ if ambiguous_fanout:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": {
+ "code": "upstream_required",
+ "message": (
+ "X-Vaara-Upstream header is required "
+ "when the proxy serves more than one "
+ "upstream. Available upstreams: "
+ f"{real_slots!r}"
+ ),
+ }
+ },
+ )
+ upstream_name = "default"
+ else:
+ upstream_name = header_name
+ if upstream_name not in proxy._upstreams:
+ raise HTTPException(
+ status_code=404,
+ detail={
+ "error": {
+ "code": "unknown_upstream",
+ "message": (
+ f"No upstream named {upstream_name!r}; "
+ f"configured: {sorted(proxy._upstreams)!r}"
+ ),
+ }
+ },
+ )
+
+ upstream_token = _REQUEST_UPSTREAM.set(upstream_name)
+ tenant_token = _REQUEST_TENANT.set((x_vaara_tenant or "").strip())
+ try:
+ if isinstance(payload, dict) and "id" not in payload:
+ try:
+ proxy._upstream.notify(payload)
+ except ProxyError:
+ logger.exception("Failed to forward HTTP notification")
+ return Response(status_code=202)
+ response = proxy._handle_request(payload)
+ return JSONResponse(content=response)
+ finally:
+ _REQUEST_UPSTREAM.reset(upstream_token)
+ _REQUEST_TENANT.reset(tenant_token)
+
+ logger.info(
+ "Vaara MCP proxy starting on http://%s:%d (%s, upstreams=%s)",
+ host, port, self.PROXY_NAME, sorted(self._upstreams.keys()),
+ )
+ uvicorn.run(app, host=host, port=port, log_level=log_level)
+
def _handle_request(self, request: Any) -> dict:
if not isinstance(request, dict):
return self._error_response(None, -32600, "Invalid Request: not a JSON object")
@@ -257,7 +525,8 @@ def _handle_tools_call(self, request: dict) -> dict:
arguments = {}
if self._tool_filtered(tool_name):
logger.warning(
- "tools/call rejected at perimeter (operator filter): %s", tool_name,
+ "tools/call rejected at perimeter (operator filter): %s",
+ _safe_log(tool_name),
)
block_payload = {
"vaara_blocked": True,
@@ -290,6 +559,7 @@ def _handle_tools_call(self, request: dict) -> dict:
# registry (fail-closed). Correct default for runtime governance.
result = self._pipeline.intercept(
agent_id=agent_id, tool_name=tool_name, parameters=arguments,
+ tenant_id=_REQUEST_TENANT.get(),
)
progress_token = self._progress_token(params)
if not result.allowed:
@@ -326,6 +596,7 @@ def _handle_tools_call(self, request: dict) -> dict:
str(getattr(result, "action_id", None) or ""),
agent_id,
tool_name,
+ _REQUEST_TENANT.get(),
)
try:
upstream_response = self._upstream.request(request)
@@ -364,7 +635,8 @@ def _handle_resources_read(self, request: dict) -> dict:
uri = ""
if self._resource_filtered(uri):
logger.warning(
- "resources/read rejected at perimeter (operator filter): %s", uri,
+ "resources/read rejected at perimeter (operator filter): %s",
+ _safe_log(uri),
)
self._overt_emit(
surface="mcp.resource.read",
@@ -415,7 +687,8 @@ def _handle_prompts_get(self, request: dict) -> dict:
arguments = {}
if self._prompt_filtered(name):
logger.warning(
- "prompts/get rejected at perimeter (operator filter): %s", name,
+ "prompts/get rejected at perimeter (operator filter): %s",
+ _safe_log(name),
)
self._overt_emit(
surface="mcp.prompt.get",
@@ -478,6 +751,7 @@ def _record_perimeter_audit(
parameters: dict,
decision: str,
reason: str,
+ tenant_id: Optional[str] = None,
) -> None:
"""Write a request+decision audit pair for a read-oriented MCP access.
@@ -488,9 +762,14 @@ def _record_perimeter_audit(
policy while still producing the two records that anchor every
access to the hash chain. Failures here are logged and
swallowed: a perimeter audit failure must not block legitimate
- upstream traffic.
+ upstream traffic. ``tenant_id`` defaults to the request
+ ContextVar; callers that run outside the originating request
+ context (e.g. the upstream notification reader thread) pass
+ the captured tenant explicitly.
"""
import time as _time
+ if tenant_id is None:
+ tenant_id = _REQUEST_TENANT.get()
try:
registry = self._pipeline.registry
action_type = registry.classify(tool_name, parameters)
@@ -500,6 +779,7 @@ def _record_perimeter_audit(
action_type=action_type,
parameters=parameters or {},
timestamp_utc=_time.strftime("%Y-%m-%dT%H:%M:%SZ", _time.gmtime()),
+ tenant_id=tenant_id,
)
action_id = self._pipeline.trail.record_action_requested(req)
self._pipeline.trail.record_decision(
@@ -524,12 +804,16 @@ def _overt_emit(
decision: str,
reason: str,
extra: Optional[dict] = None,
+ tenant_id: Optional[str] = None,
) -> None:
"""Emit one OVERT Base Envelope for an MCP interaction.
No-op when no emitter is configured. Failures are logged and
swallowed: an attestation-side failure must not block legitimate
- upstream traffic, mirroring the perimeter-audit rule.
+ upstream traffic, mirroring the perimeter-audit rule. ``tenant_id``
+ defaults to the request ContextVar; callers that run outside the
+ originating request context pass the captured tenant explicitly
+ so the OVERT claim is attributed to the right tenant.
"""
if self._overt is None:
return
@@ -540,6 +824,10 @@ def _overt_emit(
"decision": decision,
"reason": reason,
}
+ if tenant_id is None:
+ tenant_id = _REQUEST_TENANT.get()
+ if tenant_id:
+ non_content_metadata["tenant_id"] = tenant_id
if extra:
non_content_metadata.update(extra)
request_payload = strict_json_dumps(
@@ -604,11 +892,17 @@ def _audit_progress_notification(self, message: dict) -> None:
parent_action_id = ""
agent_id = self._agent_id_default
parent_tool = ""
+ # Notifications arrive on the upstream reader thread, which does
+ # not inherit the request ContextVars. Pull the tenant captured
+ # at tools/call time out of the inflight map instead of reading
+ # _REQUEST_TENANT here, otherwise the audit + OVERT claim land
+ # under empty tenant scope.
+ captured_tenant = ""
if token is not None:
with self._inflight_lock:
entry = self._inflight_progress.get(token)
if entry is not None:
- parent_action_id, agent_id, parent_tool = entry
+ parent_action_id, agent_id, parent_tool, captured_tenant = entry
self._record_perimeter_audit(
agent_id=agent_id,
tool_name="mcp.notification.progress",
@@ -619,6 +913,7 @@ def _audit_progress_notification(self, message: dict) -> None:
},
decision="observed",
reason="upstream progress notification observed",
+ tenant_id=captured_tenant,
)
self._overt_emit(
surface="mcp.notification.progress",
@@ -632,6 +927,7 @@ def _audit_progress_notification(self, message: dict) -> None:
"parent_action_id": parent_action_id,
"parent_tool": parent_tool,
},
+ tenant_id=captured_tenant,
)
def _audit_message_notification(self, message: dict) -> None:
@@ -644,12 +940,17 @@ def _audit_message_notification(self, message: dict) -> None:
log_logger = params.get("logger", "")
if not isinstance(log_logger, str):
log_logger = ""
+ # Log notifications carry no progressToken, so there is no way to
+ # recover the originating request's tenant from the reader thread.
+ # Pass tenant_id="" explicitly to make the fail-soft scope visible
+ # rather than reading _REQUEST_TENANT in a thread that never set it.
self._record_perimeter_audit(
agent_id=self._agent_id_default,
tool_name="mcp.notification.message",
parameters={"level": level, "logger": log_logger},
decision="observed",
reason="upstream log notification observed",
+ tenant_id="",
)
self._overt_emit(
surface="mcp.notification.message",
@@ -659,6 +960,7 @@ def _audit_message_notification(self, message: dict) -> None:
decision="observed",
reason="upstream log notification observed",
extra={"agent_id": self._agent_id_default},
+ tenant_id="",
)
def _write_to_client(self, payload: dict) -> None:
@@ -672,7 +974,11 @@ def _error_response(req_id: Any, code: int, message: str) -> dict:
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}
def close(self) -> None:
- self._upstream.close()
+ for client in self._upstreams.values():
+ try:
+ client.close()
+ except Exception:
+ logger.exception("Failed to close upstream MCP client")
if self._backend is not None:
self._backend.close()
@@ -680,11 +986,39 @@ def close(self) -> None:
def main(argv: Optional[list[str]] = None) -> None:
parser = argparse.ArgumentParser(
prog="vaara-mcp-proxy",
- description="Vaara runtime governance proxy in front of an upstream MCP server.",
+ description="Vaara runtime governance proxy in front of one or more upstream MCP servers.",
+ )
+ parser.add_argument(
+ "--upstream", action="append", default=[], dest="upstreams",
+ help=(
+ "Upstream MCP server command. Repeatable for v0.40 fan-out: "
+ "`--upstream NAME=CMD` registers under a named slot; bare "
+ "`--upstream CMD` lands under 'default'. The first slot (or "
+ "'default' when supplied) is the stdio fallback."
+ ),
)
- parser.add_argument("--upstream", required=True, help="Upstream MCP server command")
parser.add_argument("--upstream-arg", action="append", default=[], dest="upstream_args",
- help="Argument to pass to the upstream command (repeatable)")
+ help="Argument to pass to the (first) upstream command (repeatable)")
+ parser.add_argument(
+ "--transport",
+ choices=["stdio", "http"],
+ default="stdio",
+ help=(
+ "stdio (default) reads JSON-RPC from stdin/stdout, suitable for "
+ "in-process MCP clients. http exposes Streamable HTTP "
+ "(POST /mcp) for fleet / multi-tenant deployments and "
+ "requires the [server] extra."
+ ),
+ )
+ parser.add_argument("--http-host", default="127.0.0.1",
+ help="Bind host when --transport http (default 127.0.0.1)")
+ parser.add_argument("--http-port", type=int, default=8765,
+ help="Bind port when --transport http (default 8765)")
+ parser.add_argument(
+ "--http-log-level", default="info",
+ choices=["critical", "error", "warning", "info", "debug", "trace"],
+ help="uvicorn log level when --transport http (default info)",
+ )
parser.add_argument("--db", type=Path, default=None,
help="Audit database path (default: $VAARA_DB or ./vaara_audit.db)")
parser.add_argument("--agent-id", default="mcp-proxy-client",
@@ -743,8 +1077,19 @@ def main(argv: Optional[list[str]] = None) -> None:
),
)
+ upstreams = _parse_upstream_specs(args.upstreams, args.upstream_args)
+ if not upstreams:
+ parser.error(
+ "at least one --upstream is required (e.g. `--upstream "
+ "github=github-mcp-server` or `--upstream npx`).",
+ )
+
+ legacy_single = (
+ list(next(iter(upstreams.values()))) if len(upstreams) == 1 else None
+ )
proxy = VaaraMCPProxy(
- upstream_command=[args.upstream, *args.upstream_args],
+ upstream_command=legacy_single,
+ upstreams=upstreams if legacy_single is None else None,
db_path=args.db, agent_id_default=args.agent_id,
allowlist=tool_allow,
denylist=tool_deny if tool_deny else None,
@@ -755,11 +1100,61 @@ def main(argv: Optional[list[str]] = None) -> None:
overt_emitter=overt_emitter,
)
try:
- proxy.run()
+ if args.transport == "http":
+ proxy.run_http(
+ host=args.http_host,
+ port=args.http_port,
+ log_level=args.http_log_level,
+ )
+ else:
+ proxy.run()
finally:
proxy.close()
+# A fan-out slot name is a 1-64 character slug of ASCII letters, digits,
+# '_' and '-'. The narrow pattern stops _parse_upstream_specs from confusing
+# a command that contains '=' (e.g. ``python -m foo --bar=baz``) with NAME=CMD.
+_UPSTREAM_NAME_RE = re.compile(r"^[a-zA-Z0-9_\-]{1,64}$")
+
+
+def _parse_upstream_specs(
+ upstream_specs: list[str], legacy_args: list[str],
+) -> dict[str, list[str]]:
+ """Turn ``--upstream`` / ``--upstream-arg`` CLI input into a fan-out map.
+
+ Each ``--upstream`` is either ``NAME=CMD`` (named; NAME matches
+ ``_UPSTREAM_NAME_RE``) or ``CMD`` (lands under "default"). Commands that
+ contain ``=`` (e.g. ``python -m foo --bar=baz``) stay intact because
+ the NAME-prefix check rejects anything whose left-of-``=`` half isn't
+ a valid slug. Legacy ``--upstream-arg`` values append to the first
+ slot registered, named or not, for back-compat with single-upstream callers.
+ """
+ upstreams: dict[str, list[str]] = {}
+ first_name: Optional[str] = None
+ for spec in upstream_specs:
+ if "=" in spec:
+ candidate_name, _, candidate_cmd = spec.partition("=")
+ candidate_name = candidate_name.strip()
+ candidate_cmd = candidate_cmd.strip()
+ if _UPSTREAM_NAME_RE.match(candidate_name) and candidate_cmd:
+ name, command = candidate_name, candidate_cmd
+ else:
+ name, command = "default", spec
+ else:
+ name, command = "default", spec
+ if not name or not command:
+ raise SystemExit(
+ f"invalid --upstream value {spec!r}; expected NAME=CMD or CMD",
+ )
+ upstreams[name] = [command]
+ if first_name is None:
+ first_name = name
+ if legacy_args and first_name is not None:
+ upstreams[first_name].extend(legacy_args)
+ return upstreams
+
+
def _build_overt_emitter_from_args(
args: argparse.Namespace, *, policy_hash: bytes,
) -> Optional[OVERTReceiptEmitter]:
diff --git a/src/vaara/pipeline.py b/src/vaara/pipeline.py
index 3876987..30afc13 100644
--- a/src/vaara/pipeline.py
+++ b/src/vaara/pipeline.py
@@ -266,6 +266,7 @@ def intercept(
session_id: str = "",
parent_action_id: Optional[str] = None,
sequence_position: int = 0,
+ tenant_id: str = "",
) -> InterceptionResult:
"""Intercept an agent action request.
@@ -314,6 +315,7 @@ def intercept(
parent_action_id=parent_action_id,
sequence_position=sequence_position,
timestamp_utc=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
+ tenant_id=tenant_id,
)
# 3. Record the request in audit trail
diff --git a/src/vaara/policy/controller.py b/src/vaara/policy/controller.py
index 9908d18..b41efba 100644
--- a/src/vaara/policy/controller.py
+++ b/src/vaara/policy/controller.py
@@ -69,7 +69,7 @@ def add_listener(self, listener: PolicyListener) -> None:
listener(self._policy)
def reload(
- self, source: Union[str, Path, dict], *, format: Optional[str] = None
+ self, source: Union[str, Path, dict, Policy], *, format: Optional[str] = None
) -> ReloadResult:
"""Parse, validate, and apply a new policy.
@@ -77,10 +77,14 @@ def reload(
When omitted, ``.yaml``/``.yml`` paths use the YAML loader, dicts
bypass parsing, and everything else goes through JSON.
+ Already-validated ``Policy`` instances may be passed directly; the
+ registry path (``vaara.policy.registry.PolicyRegistry``) uses this
+ to swap a per-tenant policy that was parsed in bulk.
+
Raises ``PolicyError`` if the source is malformed; in that case
the previously loaded policy remains live.
"""
- new_policy = _load(source, format)
+ new_policy = source if isinstance(source, Policy) else _load(source, format)
with self._lock:
self._policy = new_policy
self._version += 1
diff --git a/src/vaara/policy/registry.py b/src/vaara/policy/registry.py
new file mode 100644
index 0000000..e48601c
--- /dev/null
+++ b/src/vaara/policy/registry.py
@@ -0,0 +1,149 @@
+"""Per-tenant policy registry.
+
+A ``PolicyRegistry`` owns one ``PolicyController`` per tenant_id, with the
+empty string ("") reserved for the default / fallback policy used when a
+request carries no tenant scope or no tenant-specific policy is loaded.
+
+Filename convention for ``load_directory``:
+
+* ``default.yaml`` / ``default.json`` → tenant_id ""
+* ``TENANT.yaml`` / ``TENANT.json`` → tenant_id "TENANT"
+
+This is the v0.40 multi-tenant policy plane. Single-tenant deployments
+keep using ``vaara serve --policy PATH``, which lands in the "" slot.
+"""
+
+from __future__ import annotations
+
+import threading
+from pathlib import Path
+from typing import Optional, Union
+
+from vaara.policy.controller import PolicyController, ReloadResult
+from vaara.policy.loader import from_json, from_yaml
+from vaara.policy.schema import Policy, PolicyError
+
+
+_DEFAULT_TENANT = ""
+_POLICY_SUFFIXES = (".yaml", ".yml", ".json")
+
+
+def _filename_to_tenant(stem: str) -> str:
+ return "" if stem.lower() == "default" else stem
+
+
+def _load_path(path: Path) -> Policy:
+ if path.suffix in (".yaml", ".yml"):
+ return from_yaml(path)
+ return from_json(path)
+
+
+class PolicyRegistry:
+ """Holds one PolicyController per tenant. Thread-safe."""
+
+ def __init__(self) -> None:
+ self._controllers: dict[str, PolicyController] = {}
+ self._lock = threading.RLock()
+
+ def __contains__(self, tenant_id: str) -> bool:
+ with self._lock:
+ return tenant_id in self._controllers
+
+ def tenants(self) -> list[str]:
+ with self._lock:
+ return sorted(self._controllers.keys())
+
+ def get(self, tenant_id: str) -> Optional[PolicyController]:
+ """Return the controller for ``tenant_id``, falling back to the
+ default ("") slot. Returns None if neither is registered.
+ """
+ with self._lock:
+ if tenant_id in self._controllers:
+ return self._controllers[tenant_id]
+ return self._controllers.get(_DEFAULT_TENANT)
+
+ def get_exact(self, tenant_id: str) -> Optional[PolicyController]:
+ """Return only an exact-match controller, no default fallback."""
+ with self._lock:
+ return self._controllers.get(tenant_id)
+
+ def register(self, tenant_id: str, controller: PolicyController) -> None:
+ with self._lock:
+ self._controllers[tenant_id] = controller
+
+ def reload(
+ self,
+ tenant_id: str,
+ source: Union[str, Path, dict],
+ *,
+ format: Optional[str] = None,
+ ) -> ReloadResult:
+ """Reload one tenant's policy. Creates the slot if missing."""
+ with self._lock:
+ controller = self._controllers.get(tenant_id)
+ if controller is None:
+ policy = _materialise(source, format)
+ controller = PolicyController(policy)
+ self._controllers[tenant_id] = controller
+ return ReloadResult(
+ version=controller.version,
+ thresholds_default_escalate=policy.thresholds_default.escalate,
+ thresholds_default_deny=policy.thresholds_default.deny,
+ sequence_count=len(policy.sequences),
+ action_class_count=len(policy.action_classes),
+ escalation_route_count=len(policy.escalation_routes),
+ )
+ return controller.reload(source, format=format)
+
+ def load_directory(self, directory: Union[str, Path]) -> list[str]:
+ """Load every ``*.yaml``/``*.yml``/``*.json`` file in ``directory``,
+ one tenant per file. Returns the list of tenant_ids loaded.
+
+ Raises ``PolicyError`` if any file fails to parse. All files are parsed
+ before any slot is swapped, so a failed load leaves the registry untouched.
+ """
+ directory = Path(directory)
+ if not directory.is_dir():
+ raise PolicyError(f"policy directory does not exist: {directory}")
+
+ candidates: list[tuple[str, Path]] = []
+ for entry in sorted(directory.iterdir()):
+ if entry.suffix not in _POLICY_SUFFIXES or not entry.is_file():
+ continue
+ candidates.append((_filename_to_tenant(entry.stem), entry))
+
+ if not candidates:
+ raise PolicyError(f"policy directory holds no policy files: {directory}")
+
+ parsed: list[tuple[str, Policy]] = [
+ (tenant_id, _load_path(path)) for tenant_id, path in candidates
+ ]
+ with self._lock:
+ for tenant_id, policy in parsed:
+ existing = self._controllers.get(tenant_id)
+ if existing is None:
+ self._controllers[tenant_id] = PolicyController(policy)
+ else:
+ existing.reload(policy)
+ return [tenant_id for tenant_id, _ in parsed]
+
+
+def _materialise(source: Union[str, Path, dict], fmt: Optional[str]) -> Policy:
+ """Mirror ``vaara.policy.controller._load`` for the new-tenant fast path."""
+ from vaara.policy.loader import from_dict
+ if isinstance(source, dict):
+ return from_dict(source)
+ if fmt == "yaml":
+ return from_yaml(source)
+ if fmt == "json":
+ return from_json(source)
+ if isinstance(source, Path):
+ return _load_path(source)
+ if isinstance(source, str):
+ if source.lstrip().startswith("{"):
+ return from_json(source)
+ return _load_path(Path(source))
+ raise PolicyError(f"unsupported policy source type: {type(source).__name__}")
+
+
+__all__ = ["PolicyRegistry"]
diff --git a/src/vaara/scorer/adaptive.py b/src/vaara/scorer/adaptive.py
index 929bc29..a2eed81 100644
--- a/src/vaara/scorer/adaptive.py
+++ b/src/vaara/scorer/adaptive.py
@@ -31,7 +31,7 @@
from collections import OrderedDict, deque
from dataclasses import dataclass, field
from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
from vaara.policy.schema import Policy
@@ -137,6 +137,8 @@ def to_backend_decision(self) -> dict:
"action": self.decision.value,
"reason": self.explanation,
"backend": "vaara_adaptive",
+ "threshold_allow": self.threshold_allow,
+ "threshold_deny": self.threshold_deny,
"raw_result": {
"point_estimate": self.point_estimate,
"conformal_interval": [self.conformal_lower, self.conformal_upper],
@@ -609,6 +611,7 @@ def __init__(
max_tracked_agents: int = 10_000,
pre_seed_calibration: bool = True,
mondrian_categories: bool = False,
+ policy_lookup: Optional[Callable[[str], Any]] = None,
) -> None:
"""
Args:
@@ -638,6 +641,17 @@ def __init__(
pre-Mondrian marginal behaviour. The default seed prior
always lands in the default bucket regardless of this
flag — synthetic benign pairs have no real category.
+ policy_lookup: Optional callable taking a tenant_id string and
+ returning that tenant's Policy (or None if the tenant has
+ no policy of its own). When set and the evaluate-time
+ context carries a non-empty tenant_id, the scorer uses
+ that tenant's allow/deny thresholds for THIS call instead
+ of the scorer-bound defaults. Tenants that resolve to
+ None fall back to the scorer-bound defaults; an empty
+ tenant_id skips the lookup entirely. The MWU expert
+ state, conformal calibrator, agent profiles, and sequence
+ detector remain shared across tenants — only threshold
+ application is per-tenant in v0.40.
"""
if threshold_allow >= threshold_deny:
raise ValueError(
@@ -647,6 +661,7 @@ def __init__(
)
self._threshold_allow = threshold_allow
self._threshold_deny = threshold_deny
+ self._policy_lookup = policy_lookup
self._burst_window = burst_window_seconds
self._burst_threshold = burst_threshold
self._max_tracked_agents = max_tracked_agents
@@ -735,6 +750,48 @@ def _seed_conformal_prior(self) -> None:
actual = 0.05
self._conformal.add_calibration_point(predicted, actual)
+ def set_policy_lookup(
+ self, policy_lookup: Optional[Callable[[str], Any]]
+ ) -> None:
+ """Late-binding setter for the per-tenant policy lookup.
+
+ ``ServerState`` constructs the scorer before the ``PolicyRegistry``
+ is wired up in some code paths, so the lookup needs to attach
+ after construction without going through __init__. Runs under
+ the scorer's RLock so an in-flight evaluate either sees the old
+ lookup (or None) or the new one — never a torn assignment.
+ """
+ with self._lock:
+ self._policy_lookup = policy_lookup
+
+ def _thresholds_for(self, tenant_id: str) -> tuple[float, float]:
+ """Resolve (allow, deny) thresholds for this call.
+
+ An empty tenant_id, or no policy_lookup configured, returns the
+ scorer-bound defaults (as last set by apply_policy on the default
+ ("") slot). A lookup that raises, returns None (tenant has no
+ policy of its own), or returns a non-Policy object also falls back
+ to the defaults. A tenant Policy maps thresholds_default.escalate
+ to allow and thresholds_default.deny to deny. Called from
+ _evaluate_locked with the scorer lock held; the lookup takes the
+ registry lock on top, so lock order is always scorer -> registry.
+ """
+ if not tenant_id or self._policy_lookup is None:
+ return (self._threshold_allow, self._threshold_deny)
+ try:
+ tenant_policy = self._policy_lookup(tenant_id)
+ except Exception:
+ return (self._threshold_allow, self._threshold_deny)
+ if tenant_policy is None:
+ return (self._threshold_allow, self._threshold_deny)
+ from vaara.policy.schema import Policy # local import to avoid cycles
+ if not isinstance(tenant_policy, Policy):
+ return (self._threshold_allow, self._threshold_deny)
+ return (
+ tenant_policy.thresholds_default.escalate,
+ tenant_policy.thresholds_default.deny,
+ )
+
def _calib_category(self, tool_name: str) -> Optional[str]:
"""Category to route through to the calibrator for this action.
@@ -771,11 +828,17 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any:
# Extract fields from context dict
tool_name = context.get("tool_name", "unknown")
agent_id = context.get("agent_id", "anonymous")
+ tenant_id = context.get("tenant_id", "") or ""
base_risk = _coerce_unit_float(context.get("base_risk_score", 0.5), 0.5)
agent_confidence = _coerce_optional_unit_float(context.get("agent_confidence"))
reversibility = context.get("reversibility", "partially_reversible")
blast_radius = context.get("blast_radius", "local")
+ # Per-tenant threshold dispatch. An empty tenant_id, or no lookup
+ # configured, yields the scorer-bound defaults last set by apply_policy
+ # on the default ("") slot.
+ threshold_allow, threshold_deny = self._thresholds_for(tenant_id)
+
# Build risk signals from each expert
signals = self._compute_signals(
tool_name=tool_name,
@@ -801,9 +864,9 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any:
# If the worst-case (within 1-alpha confidence) is safe, allow it.
# If the best-case is dangerous, deny it.
decision_score = upper
- if decision_score < self._threshold_allow:
+ if decision_score < threshold_allow:
decision = Decision.ALLOW
- elif decision_score > self._threshold_deny:
+ elif decision_score > threshold_deny:
decision = Decision.DENY
else:
decision = Decision.ESCALATE
@@ -825,7 +888,7 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any:
explanation = (
f"{decision.value}: risk={point_estimate:.3f} "
f"[{lower:.3f}, {upper:.3f}] "
- f"(threshold allow<{self._threshold_allow} deny>{self._threshold_deny})"
+ f"(threshold allow<{threshold_allow} deny>{threshold_deny})"
)
assessment = RiskAssessment(
@@ -837,8 +900,8 @@ def _evaluate_locked(self, context: dict[str, Any]) -> Any:
decision=decision,
signals=signals,
mwu_weights=self._mwu.weights,
- threshold_allow=self._threshold_allow,
- threshold_deny=self._threshold_deny,
+ threshold_allow=threshold_allow,
+ threshold_deny=threshold_deny,
sequence_risk=signals.get("sequence_pattern", 0.0),
calibration_size=self._conformal.calibration_size,
effective_alpha=eff_alpha,
@@ -869,11 +932,14 @@ def dry_run_evaluate(self, context: dict[str, Any]) -> Any:
def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any:
tool_name = context.get("tool_name", "unknown")
agent_id = context.get("agent_id", "anonymous")
+ tenant_id = context.get("tenant_id", "") or ""
base_risk = _coerce_unit_float(context.get("base_risk_score", 0.5), 0.5)
agent_confidence = _coerce_optional_unit_float(context.get("agent_confidence"))
reversibility = context.get("reversibility", "partially_reversible")
blast_radius = context.get("blast_radius", "local")
+ threshold_allow, threshold_deny = self._thresholds_for(tenant_id)
+
# _compute_signals is read-only for everything except the
# sequence detector's warning log — silence that temporarily.
seq_logger = logging.getLogger(
@@ -898,9 +964,9 @@ def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any:
lower, upper = self._conformal.predict_interval(
point_estimate, category=bucket,
)
- if upper < self._threshold_allow:
+ if upper < threshold_allow:
decision = Decision.ALLOW
- elif upper > self._threshold_deny:
+ elif upper > threshold_deny:
decision = Decision.DENY
else:
decision = Decision.ESCALATE
@@ -913,6 +979,8 @@ def _dry_run_evaluate_locked(self, context: dict[str, Any]) -> Any:
"calibration_size": self._conformal.calibration_size,
"effective_alpha": self._conformal.effective_alpha_for(bucket),
"bucket_category": bucket,
+ "threshold_allow": threshold_allow,
+ "threshold_deny": threshold_deny,
},
}
diff --git a/src/vaara/server/app.py b/src/vaara/server/app.py
index 17d24d8..047d834 100644
--- a/src/vaara/server/app.py
+++ b/src/vaara/server/app.py
@@ -17,6 +17,7 @@
from vaara.audit.trail import AuditTrail
from vaara.policy.controller import PolicyController
+from vaara.policy.registry import PolicyRegistry
from vaara.scorer.adaptive import AdaptiveScorer
from vaara.server.routes import register
from vaara.server.state import ServerState
@@ -26,6 +27,7 @@ def create_app(
scorer: Optional[AdaptiveScorer] = None,
audit: Optional[AuditTrail] = None,
policy_controller: Optional[PolicyController] = None,
+ policy_registry: Optional[PolicyRegistry] = None,
) -> FastAPI:
"""Build the FastAPI application.
@@ -34,11 +36,19 @@ def create_app(
audit: Pre-configured audit trail, or None for default in-memory.
policy_controller: Pre-loaded ``PolicyController``. When supplied,
the scorer is registered as a listener and ``POST
- /v1/policy/reload`` becomes available. When omitted, the
- reload endpoint returns ``409 policy_not_configured``.
+ /v1/policy/reload`` becomes available. When omitted (and no
+ ``policy_registry``), the reload endpoint returns
+ ``409 policy_not_configured``.
+ policy_registry: Pre-loaded ``PolicyRegistry`` for multi-tenant
+ deployments. Mutually exclusive with ``policy_controller`` —
+ single-controller callers are wrapped into a registry's ""
+ slot automatically by ``ServerState``.
"""
state = ServerState(
- scorer=scorer, audit=audit, policy_controller=policy_controller
+ scorer=scorer,
+ audit=audit,
+ policy_controller=policy_controller,
+ policy_registry=policy_registry,
)
app = FastAPI(
title="Vaara HTTP API",
diff --git a/src/vaara/server/routes.py b/src/vaara/server/routes.py
index c3fb5b8..e04a8ee 100644
--- a/src/vaara/server/routes.py
+++ b/src/vaara/server/routes.py
@@ -7,7 +7,7 @@
from datetime import datetime, timezone
from typing import Optional
-from fastapi import FastAPI, HTTPException, status
+from fastapi import FastAPI, Header, HTTPException, status
from fastapi.responses import JSONResponse
from vaara import __version__ as _vaara_version
@@ -31,6 +31,12 @@ def _iso(ts: float) -> str:
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
+def _resolve_tenant(body_value: str, header_value: Optional[str]) -> str:
+ body = (body_value or "").strip()
+ header = (header_value or "").strip()
+ return body or header
+
+
def register(app: FastAPI, state: ServerState) -> None:
@app.exception_handler(HTTPException)
@@ -65,8 +71,13 @@ async def server_info():
)
@app.post("/v1/score", response_model=S.ScoreResponse)
- async def score(req: S.ScoreRequest):
+ async def score(
+ req: S.ScoreRequest,
+ x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"),
+ ):
+ tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant)
ctx = req.model_dump(exclude_none=True)
+ ctx["tenant_id"] = tenant_id
try:
decision_dict = state.scorer.evaluate(ctx)
except Exception as exc:
@@ -84,6 +95,7 @@ async def score(req: S.ScoreRequest):
tool_name=req.tool_name,
predicted_risk=float(raw.get("point_estimate", 0.5) or 0.5),
signals=signals,
+ tenant_id=tenant_id,
)
return S.ScoreResponse(
@@ -99,8 +111,12 @@ async def score(req: S.ScoreRequest):
signals=signals,
mwu_weights={k: float(v) for k, v in state.scorer._mwu.weights.items()},
thresholds=S.Thresholds(
- allow=state.scorer._threshold_allow,
- deny=state.scorer._threshold_deny,
+ allow=float(
+ decision_dict.get("threshold_allow", state.scorer._threshold_allow)
+ ),
+ deny=float(
+ decision_dict.get("threshold_deny", state.scorer._threshold_deny)
+ ),
),
sequence_risk=float(raw.get("sequence_risk", 0.0) or 0.0),
calibration_size=int(raw.get("calibration_size", 0) or 0),
@@ -130,7 +146,10 @@ async def score_outcome(req: S.OutcomeRequest):
response_model=S.AuditEventResponse,
status_code=201,
)
- async def append_audit_event(req: S.AuditEventRequest):
+ async def append_audit_event(
+ req: S.AuditEventRequest,
+ x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"),
+ ):
try:
event_type = EventType(req.event_type)
except ValueError:
@@ -139,6 +158,12 @@ async def append_audit_event(req: S.AuditEventRequest):
status.HTTP_400_BAD_REQUEST,
)
+ tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant)
+ if not tenant_id:
+ info = state.lookup_action(req.action_id)
+ if info is not None:
+ tenant_id = info.tenant_id
+
record = AuditRecord(
record_id=str(uuid.uuid4()),
action_id=req.action_id,
@@ -148,6 +173,7 @@ async def append_audit_event(req: S.AuditEventRequest):
tool_name=req.tool_name or "",
data=req.payload or {},
regulatory_articles=[],
+ tenant_id=tenant_id,
)
state.audit._append(record)
return S.AuditEventResponse(
@@ -217,16 +243,26 @@ async def detect_pii_endpoint(req: S.DetectPIIRequest):
return S.DetectPIIResponse(**result.to_dict())
@app.post("/v1/policy/reload", response_model=S.PolicyReloadResponse)
- async def reload_policy(req: S.PolicyReloadRequest):
+ async def reload_policy(
+ req: S.PolicyReloadRequest,
+ x_vaara_tenant: Optional[str] = Header(default=None, alias="X-Vaara-Tenant"),
+ ):
from vaara.policy.schema import PolicyError
- controller = state.policy_controller
- if controller is None:
+ tenant_id = _resolve_tenant(req.tenant_id, x_vaara_tenant)
+ registry = state.policy_registry
+ controller = (
+ registry.get_exact(tenant_id) if registry is not None else None
+ )
+ if controller is None and not tenant_id:
+ controller = state.policy_controller
+ if registry is None and controller is None:
raise _error(
code="policy_not_configured",
message=(
- "Server has no PolicyController; start with "
- "`vaara serve --policy PATH` to enable reload."
+ "Server has no policy plane; start with "
+ "`vaara serve --policy PATH` or `--policy-dir DIR` to "
+ "enable reload."
),
http_status=status.HTTP_409_CONFLICT,
)
@@ -240,7 +276,10 @@ async def reload_policy(req: S.PolicyReloadRequest):
source = req.body if req.body is not None else req.path
try:
- result = controller.reload(source, format=req.format)
+ if registry is not None:
+ result = registry.reload(tenant_id, source, format=req.format)
+ else:
+ result = controller.reload(source, format=req.format)
except PolicyError as exc:
raise _error(
code="policy_invalid",
@@ -257,4 +296,5 @@ async def reload_policy(req: S.PolicyReloadRequest):
sequence_count=result.sequence_count,
action_class_count=result.action_class_count,
escalation_route_count=result.escalation_route_count,
+ tenant_id=tenant_id,
)
diff --git a/src/vaara/server/schemas.py b/src/vaara/server/schemas.py
index 74671ca..1fe92c6 100644
--- a/src/vaara/server/schemas.py
+++ b/src/vaara/server/schemas.py
@@ -41,6 +41,7 @@ class ScoreRequest(BaseModel):
blast_radius: Optional[_BlastRadius] = None
session_id: Optional[str] = Field(default=None, max_length=256)
parent_action_id: Optional[str] = Field(default=None, max_length=128)
+ tenant_id: str = Field(default="", max_length=256)
context: dict[str, Any] = Field(default_factory=dict)
@@ -85,6 +86,7 @@ class AuditEventRequest(BaseModel):
action_id: str
agent_id: Optional[str] = None
tool_name: Optional[str] = None
+ tenant_id: str = Field(default="", max_length=256)
payload: dict[str, Any] = Field(default_factory=dict)
@@ -203,6 +205,7 @@ class PolicyReloadRequest(BaseModel):
path: Optional[str] = Field(default=None, max_length=4096)
body: Optional[dict[str, Any]] = None
format: Optional[Literal["json", "yaml"]] = None
+ tenant_id: str = Field(default="", max_length=256)
class PolicyReloadResponse(BaseModel):
@@ -211,3 +214,4 @@ class PolicyReloadResponse(BaseModel):
sequence_count: int
action_class_count: int
escalation_route_count: int
+ tenant_id: str = ""
diff --git a/src/vaara/server/state.py b/src/vaara/server/state.py
index eca3baf..e34ad76 100644
--- a/src/vaara/server/state.py
+++ b/src/vaara/server/state.py
@@ -1,4 +1,4 @@
-"""Server state container — scorer + audit trail + policy controller singletons."""
+"""Server state container — scorer + audit trail + policy registry singletons."""
from __future__ import annotations
@@ -8,6 +8,7 @@
from vaara.audit.trail import AuditTrail
from vaara.policy.controller import PolicyController
+from vaara.policy.registry import PolicyRegistry
from vaara.scorer.adaptive import AdaptiveScorer
@@ -17,6 +18,7 @@ class _ActionInfo:
tool_name: str
predicted_risk: float
signals: dict[str, float] = field(default_factory=dict)
+ tenant_id: str = ""
class ServerState:
@@ -27,12 +29,43 @@ def __init__(
scorer: Optional[AdaptiveScorer] = None,
audit: Optional[AuditTrail] = None,
policy_controller: Optional[PolicyController] = None,
+ policy_registry: Optional[PolicyRegistry] = None,
) -> None:
+ if policy_controller is not None and policy_registry is not None:
+ raise ValueError(
+ "Pass either policy_controller (single-tenant legacy) or "
+ "policy_registry (multi-tenant v0.40), not both. Mixing "
+ "the two silently splits threshold sources between the "
+ "default slot and per-tenant overrides.",
+ )
self.scorer = scorer or AdaptiveScorer()
self.audit = audit or AuditTrail()
+ # v0.40: a single PolicyRegistry holds all tenant policies. The
+ # single-tenant entry point (`policy_controller=...`) lands in the
+ # empty-string "" slot for back-compat with v0.39 callers.
+ if policy_registry is None and policy_controller is not None:
+ policy_registry = PolicyRegistry()
+ policy_registry.register("", policy_controller)
+ self.policy_registry = policy_registry
self.policy_controller = policy_controller
if policy_controller is not None:
policy_controller.add_listener(self.scorer.apply_policy)
+ elif policy_registry is not None:
+ default = policy_registry.get("")
+ if default is not None:
+ default.add_listener(self.scorer.apply_policy)
+ self.policy_controller = default
+ # v0.40 per-tenant threshold dispatch: the scorer asks the
+ # registry for the calling tenant's policy on every evaluate() call.
+ # An exact-match miss falls back to the scorer-bound defaults
+ # (which the default-slot listener keeps fresh on reload). We
+ # use get_exact rather than get so the scorer's own default
+ # path stays the single fallback channel.
+ if policy_registry is not None:
+ def _lookup(tid: str):
+ ctrl = policy_registry.get_exact(tid)
+ return ctrl.policy if ctrl is not None else None
+ self.scorer.set_policy_lookup(_lookup)
self._lock = threading.Lock()
# action_id → info captured at score time so outcome reports can
# feed the MWU update without the client having to resend context.
@@ -45,6 +78,7 @@ def remember_action(
tool_name: str,
predicted_risk: float,
signals: dict[str, float],
+ tenant_id: str = "",
) -> None:
with self._lock:
self._actions[action_id] = _ActionInfo(
@@ -52,6 +86,7 @@ def remember_action(
tool_name=tool_name,
predicted_risk=predicted_risk,
signals=signals,
+ tenant_id=tenant_id,
)
def lookup_action(self, action_id: str) -> Optional[_ActionInfo]:
diff --git a/src/vaara/taxonomy/actions.py b/src/vaara/taxonomy/actions.py
index e4b9c7e..d086284 100644
--- a/src/vaara/taxonomy/actions.py
+++ b/src/vaara/taxonomy/actions.py
@@ -136,6 +136,7 @@ class ActionRequest:
parent_action_id: Optional[str] = None # For action chains
sequence_position: int = 0 # Position in current action sequence
timestamp_utc: str = ""
+ tenant_id: str = "" # v0.40 multi-tenant scope; "" = single-tenant
def to_policy_context(self) -> dict:
"""Convert to a plain dict of policy-evaluation fields.
@@ -159,6 +160,7 @@ def to_policy_context(self) -> dict:
"parent_action_id": self.parent_action_id,
"sequence_position": self.sequence_position,
"parameters": self.parameters,
+ "tenant_id": self.tenant_id,
}
diff --git a/tests/test_integrations_mcp_proxy.py b/tests/test_integrations_mcp_proxy.py
index b43e94a..aaea107 100644
--- a/tests/test_integrations_mcp_proxy.py
+++ b/tests/test_integrations_mcp_proxy.py
@@ -491,7 +491,7 @@ def upstream_request(req):
"_meta": {"progressToken": "tok-stream"},
},
})
- assert progress_seen == [{"tok-stream": ("parent-act-9", "mcp-proxy-client", "long_tool")}]
+ assert progress_seen == [{"tok-stream": ("parent-act-9", "mcp-proxy-client", "long_tool", "")}]
# Map is cleaned up after the call returns.
assert p._inflight_progress == {}
# Audit was called for the progress event with the parent correlation.
diff --git a/tests/test_v040_mcp_http_transport.py b/tests/test_v040_mcp_http_transport.py
new file mode 100644
index 0000000..5a67941
--- /dev/null
+++ b/tests/test_v040_mcp_http_transport.py
@@ -0,0 +1,262 @@
+"""v0.40 MCP proxy streamable-HTTP transport + fan-out routing."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+try:
+ from fastapi.testclient import TestClient
+except ImportError:
+ pytest.skip(
+ "server extra not installed (pip install 'vaara[server]')",
+ allow_module_level=True,
+ )
+
+from vaara.integrations import mcp_proxy
+from vaara.integrations.mcp_proxy import VaaraMCPProxy, _parse_upstream_specs
+
+
+# ── _parse_upstream_specs ──────────────────────────────────────────────────
+
+def test_parse_upstream_specs_bare_command_lands_in_default():
+ result = _parse_upstream_specs(["echo"], [])
+ assert result == {"default": ["echo"]}
+
+
+def test_parse_upstream_specs_named_slot():
+ result = _parse_upstream_specs(["github=gh-mcp-server"], [])
+ assert result == {"github": ["gh-mcp-server"]}
+
+
+def test_parse_upstream_specs_command_with_equals_stays_intact():
+ """A command like `python -m foo --bar=baz` must NOT be split at '='."""
+ result = _parse_upstream_specs(["python -m foo --bar=baz"], [])
+ assert result == {"default": ["python -m foo --bar=baz"]}
+
+
+def test_parse_upstream_specs_legacy_args_join_first_slot():
+ result = _parse_upstream_specs(["echo"], ["hello", "world"])
+ assert result == {"default": ["echo", "hello", "world"]}
+
+
+def test_parse_upstream_specs_multiple_named_fanout():
+ result = _parse_upstream_specs(["a=cmd-a", "b=cmd-b"], [])
+ assert result == {"a": ["cmd-a"], "b": ["cmd-b"]}
+
+
+def test_parse_upstream_specs_unknown_name_pattern_falls_to_default():
+ """Names that aren't simple slugs aren't treated as a NAME= prefix."""
+ result = _parse_upstream_specs(["python -m srv=foo"], [])
+ # The left side ("python -m srv") fails the slug regex, so the whole
+ # spec is treated as a bare command under "default".
+ assert result == {"default": ["python -m srv=foo"]}
+
+
+# ── VaaraMCPProxy multi-upstream constructor ───────────────────────────────
+
+@pytest.fixture
+def http_proxy(monkeypatch):
+ """A VaaraMCPProxy with mocked upstreams and pipeline."""
+ monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock())
+ pipeline = MagicMock()
+ proxy = VaaraMCPProxy(
+ upstreams={"alpha": ["cmd-alpha"], "beta": ["cmd-beta"]},
+ pipeline=pipeline,
+ )
+ return proxy
+
+
+def test_constructor_rejects_both_upstream_and_upstreams():
+ with pytest.raises(ValueError, match="Pass either"):
+ VaaraMCPProxy(
+ upstream_command=["echo"],
+ upstreams={"a": ["echo"]},
+ pipeline=MagicMock(),
+ )
+
+
+def test_constructor_requires_at_least_one_upstream():
+ with pytest.raises(ValueError, match="upstream"):
+ VaaraMCPProxy(pipeline=MagicMock())
+
+
+def test_multi_upstream_populates_default_slot(http_proxy):
+ assert "default" in http_proxy._upstreams
+ # When no explicit default is provided, the sorted-first name acts as fallback.
+ assert http_proxy._upstreams["default"] is not None
+
+
+def test_default_slot_aliases_first_upstream_not_duplicate(monkeypatch):
+ """Regression guard: default fallback must alias an existing upstream
+ client instead of spawning a duplicate subprocess. An earlier shape
+ cloned the command into upstream_map["default"] before constructing
+ the client map, which built a second UpstreamMCPClient for the same
+ command."""
+ constructed: list[MagicMock] = []
+
+ def make_client(*_args, **_kw):
+ instance = MagicMock()
+ constructed.append(instance)
+ return instance
+
+ monkeypatch.setattr(
+ mcp_proxy, "UpstreamMCPClient", MagicMock(side_effect=make_client),
+ )
+ proxy = VaaraMCPProxy(
+ upstreams={"alpha": ["cmd-alpha"], "beta": ["cmd-beta"]},
+ pipeline=MagicMock(),
+ )
+ # Exactly two clients constructed (alpha, beta), not three.
+ assert len(constructed) == 2
+ # "default" aliases the sorted-first real slot.
+ assert proxy._upstreams["default"] is proxy._upstreams["alpha"]
+ assert proxy._upstreams["default"] is not proxy._upstreams["beta"]
+
+
+def test_single_upstream_lands_under_default_slot(monkeypatch):
+ monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock())
+ proxy = VaaraMCPProxy(
+ upstream_command=["echo"], pipeline=MagicMock(),
+ )
+ assert list(proxy._upstreams) == ["default"]
+
+
+# ── HTTP transport endpoints ───────────────────────────────────────────────
+
+def _build_http_app(proxy):
+ """Construct the FastAPI app run_http() builds, without uvicorn.run()."""
+ import unittest.mock as um
+
+ with um.patch("uvicorn.run") as run_mock:
+ # Capture the app by intercepting uvicorn.run() — we call the same
+ # method the production CLI uses, but never block on the event loop.
+ captured: dict = {}
+
+ def fake_run(app, **kwargs):
+ captured["app"] = app
+
+ run_mock.side_effect = fake_run
+ proxy.run_http(host="127.0.0.1", port=0)
+ return captured["app"]
+
+
+def test_http_health_lists_upstreams(http_proxy):
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.get("/health")
+ assert resp.status_code == 200
+ assert set(resp.json()["upstreams"]) >= {"alpha", "beta", "default"}
+
+
+def test_http_mcp_post_routes_to_named_upstream(http_proxy):
+ http_proxy._upstreams["alpha"].request.return_value = {
+ "jsonrpc": "2.0", "id": 1, "result": {"tools": []},
+ }
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
+ headers={"X-Vaara-Upstream": "alpha"},
+ )
+ assert resp.status_code == 200
+ http_proxy._upstreams["alpha"].request.assert_called_once()
+
+
+def test_http_mcp_fanout_without_header_returns_400(http_proxy):
+ """Multi-upstream deployment must NOT silently default-route."""
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
+ )
+ assert resp.status_code == 400
+ body = resp.json()
+ assert body["detail"]["error"]["code"] == "upstream_required"
+ # Operator gets the list of valid slots in the error so a client UI
+ # can recover without an additional health probe round-trip.
+ assert "alpha" in body["detail"]["error"]["message"]
+ assert "beta" in body["detail"]["error"]["message"]
+
+
+def test_http_mcp_single_upstream_silent_default(monkeypatch):
+ """Single-upstream deployment keeps the v0.39 silent-default contract."""
+ monkeypatch.setattr(mcp_proxy, "UpstreamMCPClient", MagicMock())
+ proxy = VaaraMCPProxy(upstream_command=["echo"], pipeline=MagicMock())
+ proxy._upstreams["default"].request.return_value = {
+ "jsonrpc": "2.0", "id": 1, "result": {"tools": []},
+ }
+ app = _build_http_app(proxy)
+ client = TestClient(app)
+ resp = client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
+ )
+ assert resp.status_code == 200
+ proxy._upstreams["default"].request.assert_called_once()
+
+
+def test_http_mcp_unknown_upstream_returns_404(http_proxy):
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
+ headers={"X-Vaara-Upstream": "no-such-thing"},
+ )
+ assert resp.status_code == 404
+ assert resp.json()["detail"]["error"]["code"] == "unknown_upstream"
+
+
+def test_http_mcp_notification_returns_202(http_proxy):
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "notifications/initialized"},
+ headers={"X-Vaara-Upstream": "alpha"},
+ )
+ assert resp.status_code == 202
+
+
+def test_http_mcp_oversized_body_returns_413(http_proxy, monkeypatch):
+ monkeypatch.setattr(mcp_proxy, "_MCP_HTTP_MAX_BODY_BYTES", 64)
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ payload = {
+ "jsonrpc": "2.0", "id": 1, "method": "tools/list",
+ "params": {"x": "a" * 10_000},
+ }
+ resp = client.post("/mcp", json=payload)
+ assert resp.status_code == 413
+ assert resp.json()["error"]["code"] == "payload_too_large"
+
+
+def test_http_mcp_bad_json_returns_parse_error(http_proxy):
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ resp = client.post("/mcp", content=b"not json")
+ assert resp.status_code == 400
+ assert resp.json()["error"]["code"] == -32700
+
+
+def test_http_mcp_tenant_header_threads_into_overt(http_proxy):
+ """X-Vaara-Tenant is accepted on /mcp; the OVERT claim itself is not asserted here."""
+ http_proxy._overt = MagicMock()
+ http_proxy._upstreams["alpha"].request.return_value = {
+ "jsonrpc": "2.0", "id": 1, "result": {"tools": []},
+ }
+ app = _build_http_app(http_proxy)
+ client = TestClient(app)
+ client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
+ headers={"X-Vaara-Tenant": "tenant-q", "X-Vaara-Upstream": "alpha"},
+ )
+ # tools/list does not emit an OVERT envelope in the current proxy
+ # implementation, so this test exercises only that the request
+ # dispatched cleanly with the header set.
+ http_proxy._upstreams["alpha"].request.assert_called_once()
diff --git a/tests/test_v040_per_tenant_threshold.py b/tests/test_v040_per_tenant_threshold.py
new file mode 100644
index 0000000..2b2dcf2
--- /dev/null
+++ b/tests/test_v040_per_tenant_threshold.py
@@ -0,0 +1,212 @@
+"""v0.40 per-tenant threshold dispatch at evaluate() time.
+
+The scorer holds defaults bound from the default ("") slot via the
+standard apply_policy listener path, and on every evaluate it looks up
+the calling tenant's policy from the PolicyRegistry. A tenant with its
+own thresholds gets those thresholds applied to THIS call; a tenant
+with no policy of its own falls back to the scorer-bound defaults.
+"""
+
+from __future__ import annotations
+
+from vaara.policy.controller import PolicyController
+from vaara.policy.loader import from_dict
+from vaara.policy.registry import PolicyRegistry
+from vaara.scorer.adaptive import AdaptiveScorer
+
+
+def _policy_dict(escalate: float, deny: float) -> dict:
+ return {
+ "version": "0.1",
+ "thresholds": {"default": {"escalate": escalate, "deny": deny}},
+ }
+
+
+def _ctx(tenant_id: str = "", base_risk: float = 0.5) -> dict:
+ return {
+ "tool_name": "tool.test",
+ "agent_id": "agent-test",
+ "tenant_id": tenant_id,
+ "base_risk_score": base_risk,
+ "reversibility": "partially_reversible",
+ "blast_radius": "local",
+ }
+
+
+def _registry_lookup(registry: PolicyRegistry):
+ def _lookup(tid: str):
+ ctrl = registry.get_exact(tid)
+ return ctrl.policy if ctrl is not None else None
+ return _lookup
+
+
+def test_tenant_thresholds_override_default_for_that_call():
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+ registry.register(
+ "tenant-a", PolicyController(from_dict(_policy_dict(0.05, 0.10)))
+ )
+
+ scorer = AdaptiveScorer()
+ scorer.apply_policy(registry.get_exact("").policy)
+ scorer.set_policy_lookup(_registry_lookup(registry))
+
+ default_result = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5))
+ assert default_result["threshold_allow"] == 0.4
+ assert default_result["threshold_deny"] == 0.7
+
+ strict_result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert strict_result["threshold_allow"] == 0.05
+ assert strict_result["threshold_deny"] == 0.10
+
+ again = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5))
+ assert again["threshold_allow"] == 0.4
+ assert again["threshold_deny"] == 0.7
+
+
+def test_unknown_tenant_falls_back_to_scorer_defaults():
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+
+ scorer = AdaptiveScorer()
+ scorer.apply_policy(registry.get_exact("").policy)
+ scorer.set_policy_lookup(_registry_lookup(registry))
+
+ result = scorer.evaluate(_ctx(tenant_id="ghost", base_risk=0.5))
+ assert result["threshold_allow"] == 0.4
+ assert result["threshold_deny"] == 0.7
+
+
+def test_empty_tenant_id_skips_lookup_and_uses_defaults():
+ calls = []
+
+ def _lookup(tid: str):
+ calls.append(tid)
+ return None
+
+ scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7)
+ scorer.set_policy_lookup(_lookup)
+
+ result = scorer.evaluate(_ctx(tenant_id="", base_risk=0.5))
+ assert result["threshold_allow"] == 0.4
+ assert result["threshold_deny"] == 0.7
+ assert calls == []
+
+
+def test_lookup_exception_falls_back_to_defaults():
+ def _broken_lookup(tid: str):
+ raise RuntimeError("registry unreachable")
+
+ scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7)
+ scorer.set_policy_lookup(_broken_lookup)
+
+ result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert result["threshold_allow"] == 0.4
+ assert result["threshold_deny"] == 0.7
+
+
+def test_tenant_reload_visible_to_next_evaluate_without_listener():
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+ registry.register(
+ "tenant-a", PolicyController(from_dict(_policy_dict(0.30, 0.50)))
+ )
+
+ scorer = AdaptiveScorer()
+ scorer.apply_policy(registry.get_exact("").policy)
+ scorer.set_policy_lookup(_registry_lookup(registry))
+
+ first = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert first["threshold_allow"] == 0.30
+
+ registry.reload("tenant-a", _policy_dict(0.05, 0.10))
+ after = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert after["threshold_allow"] == 0.05
+ assert after["threshold_deny"] == 0.10
+
+
+def test_dry_run_evaluate_uses_per_tenant_thresholds():
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+ registry.register(
+ "tenant-a", PolicyController(from_dict(_policy_dict(0.05, 0.10)))
+ )
+
+ scorer = AdaptiveScorer()
+ scorer.apply_policy(registry.get_exact("").policy)
+ scorer.set_policy_lookup(_registry_lookup(registry))
+
+ default_dry = scorer.dry_run_evaluate(_ctx(tenant_id="", base_risk=0.5))
+ strict_dry = scorer.dry_run_evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+
+ assert default_dry["raw_result"]["threshold_allow"] == 0.4
+ assert strict_dry["raw_result"]["threshold_allow"] == 0.05
+ assert strict_dry["raw_result"]["threshold_deny"] == 0.10
+
+
+def test_server_state_wires_lookup_from_registry():
+ from vaara.server.state import ServerState
+
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+ registry.register(
+ "tenant-a", PolicyController(from_dict(_policy_dict(0.05, 0.10)))
+ )
+
+ state = ServerState(policy_registry=registry)
+ result = state.scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert result["threshold_allow"] == 0.05
+ assert result["threshold_deny"] == 0.10
+
+
+def test_score_response_surfaces_per_tenant_thresholds_over_http():
+ """The /v1/score response's `thresholds` block reflects the
+ per-call values, not the scorer's bound defaults. Regression test:
+ a smoke test caught the response surfacing 0.4/0.7 for every tenant
+ even though the decision itself already used per-tenant thresholds.
+ """
+ import pytest as _pytest
+ try:
+ from fastapi.testclient import TestClient
+ from vaara.server import create_app
+ except ImportError:
+ _pytest.skip("server extra not installed (pip install 'vaara[server]')")
+ return # pytest.skip raises; the return keeps static analysers happy
+
+ registry = PolicyRegistry()
+ registry.register("", PolicyController(from_dict(_policy_dict(0.4, 0.7))))
+ registry.register(
+ "tenant-strict",
+ PolicyController(from_dict(_policy_dict(0.05, 0.10))),
+ )
+ app = create_app(policy_registry=registry)
+ client = TestClient(app)
+ req = {"tool_name": "tool.read", "agent_id": "ag-1", "action_type": "data_read"}
+
+ default_resp = client.post("/v1/score", json=req).json()
+ assert default_resp["thresholds"]["allow"] == 0.4
+ assert default_resp["thresholds"]["deny"] == 0.7
+
+ strict_resp = client.post(
+ "/v1/score", json=req, headers={"X-Vaara-Tenant": "tenant-strict"}
+ ).json()
+ assert strict_resp["thresholds"]["allow"] == 0.05
+ assert strict_resp["thresholds"]["deny"] == 0.10
+
+ unknown_resp = client.post(
+ "/v1/score", json=req, headers={"X-Vaara-Tenant": "ghost"}
+ ).json()
+ assert unknown_resp["thresholds"]["allow"] == 0.4
+ assert unknown_resp["thresholds"]["deny"] == 0.7
+
+
+def test_non_policy_lookup_return_falls_back():
+ def _bad_typed_lookup(tid: str):
+ return {"not": "a Policy"}
+
+ scorer = AdaptiveScorer(threshold_allow=0.4, threshold_deny=0.7)
+ scorer.set_policy_lookup(_bad_typed_lookup)
+
+ result = scorer.evaluate(_ctx(tenant_id="tenant-a", base_risk=0.5))
+ assert result["threshold_allow"] == 0.4
+ assert result["threshold_deny"] == 0.7
diff --git a/tests/test_v040_policy_registry.py b/tests/test_v040_policy_registry.py
new file mode 100644
index 0000000..e32a3e7
--- /dev/null
+++ b/tests/test_v040_policy_registry.py
@@ -0,0 +1,121 @@
+"""v0.40 PolicyRegistry per-tenant policy plane + /v1/policy/reload routing."""
+
+from __future__ import annotations
+
+import pytest
+
+try:
+ from fastapi.testclient import TestClient
+
+ from vaara.server import create_app
+except ImportError:
+ pytest.skip(
+ "server extra not installed (pip install 'vaara[server]')",
+ allow_module_level=True,
+ )
+
+from vaara.policy.controller import PolicyController
+from vaara.policy.loader import from_dict
+from vaara.policy.registry import PolicyRegistry
+from vaara.policy.schema import PolicyError
+
+
+def _minimal_policy() -> dict:
+ return {
+ "version": "0.1",
+ "thresholds_default": {"escalate": 0.5, "deny": 0.9},
+ }
+
+
+def test_policy_registry_reload_creates_new_tenant_slot():
+ registry = PolicyRegistry()
+ result = registry.reload("tenant-a", _minimal_policy())
+ assert result.version == 1
+ assert "tenant-a" in registry
+
+
+def test_policy_registry_get_falls_back_to_default_slot():
+ registry = PolicyRegistry()
+ registry.reload("", _minimal_policy())
+ assert registry.get("nonexistent") is not None
+ assert registry.get_exact("nonexistent") is None
+
+
+def test_policy_registry_load_directory(tmp_path):
+ (tmp_path / "default.json").write_text(
+ '{"version": "0.1", "thresholds_default": {"escalate": 0.4, "deny": 0.8}}'
+ )
+ (tmp_path / "tenant-a.json").write_text(
+ '{"version": "0.1", "thresholds_default": {"escalate": 0.3, "deny": 0.7}}'
+ )
+ registry = PolicyRegistry()
+ tenants = registry.load_directory(tmp_path)
+ assert sorted(tenants) == ["", "tenant-a"]
+ assert "" in registry
+ assert "tenant-a" in registry
+
+
+def test_policy_registry_load_directory_rejects_empty(tmp_path):
+ registry = PolicyRegistry()
+ with pytest.raises(PolicyError):
+ registry.load_directory(tmp_path)
+
+
+def test_policy_registry_accepts_policy_instance_on_reload():
+ """Bulk-load path passes Policy directly to avoid re-parsing."""
+ registry = PolicyRegistry()
+ registry.reload("", _minimal_policy())
+ policy_obj = from_dict(_minimal_policy())
+ result = registry.reload("", policy_obj)
+ assert result.version == 2
+
+
+def test_policy_reload_per_tenant_via_body():
+ registry = PolicyRegistry()
+ registry.reload("", _minimal_policy())
+ app = create_app(policy_registry=registry)
+ client = TestClient(app)
+ resp = client.post(
+ "/v1/policy/reload",
+ json={
+ "body": {
+ "version": "0.1",
+ "thresholds_default": {"escalate": 0.2, "deny": 0.6},
+ },
+ "tenant_id": "tenant-x",
+ },
+ )
+ assert resp.status_code == 200, resp.text
+ assert resp.json()["tenant_id"] == "tenant-x"
+ assert "tenant-x" in registry
+
+
+def test_policy_reload_per_tenant_via_header():
+ registry = PolicyRegistry()
+ registry.reload("", _minimal_policy())
+ app = create_app(policy_registry=registry)
+ client = TestClient(app)
+ resp = client.post(
+ "/v1/policy/reload",
+ json={"body": _minimal_policy()},
+ headers={"X-Vaara-Tenant": "tenant-y"},
+ )
+ assert resp.status_code == 200
+ assert resp.json()["tenant_id"] == "tenant-y"
+
+
+def test_policy_reload_back_compat_single_controller():
+ controller = PolicyController(from_dict(_minimal_policy()))
+ app = create_app(policy_controller=controller)
+ client = TestClient(app)
+ resp = client.post("/v1/policy/reload", json={"body": _minimal_policy()})
+ assert resp.status_code == 200
+ assert resp.json()["tenant_id"] == ""
+
+
+def test_policy_reload_unconfigured_returns_409():
+ app = create_app()
+ client = TestClient(app)
+ resp = client.post("/v1/policy/reload", json={"body": _minimal_policy()})
+ assert resp.status_code == 409
+ assert resp.json()["error"]["code"] == "policy_not_configured"
diff --git a/tests/test_v040_tenant.py b/tests/test_v040_tenant.py
new file mode 100644
index 0000000..1685ba1
--- /dev/null
+++ b/tests/test_v040_tenant.py
@@ -0,0 +1,165 @@
+"""v0.40 tenant_id end-to-end propagation tests."""
+
+from __future__ import annotations
+
+import pytest
+
+try:
+ from fastapi.testclient import TestClient
+
+ from vaara.server import create_app
+except ImportError:
+ pytest.skip(
+ "server extra not installed (pip install 'vaara[server]')",
+ allow_module_level=True,
+ )
+
+from vaara.audit.trail import AuditRecord, AuditTrail, EventType
+from vaara.taxonomy.actions import (
+ ActionCategory,
+ ActionRequest,
+ ActionType,
+ BlastRadius,
+ Reversibility,
+)
+
+
+def _action_type(name: str = "t") -> ActionType:
+ return ActionType(
+ name=name,
+ category=ActionCategory.DATA,
+ reversibility=Reversibility.FULLY,
+ blast_radius=BlastRadius.LOCAL,
+ )
+
+
+def _minimal_policy() -> dict:
+ return {
+ "version": "0.1",
+ "thresholds_default": {"escalate": 0.5, "deny": 0.9},
+ }
+
+
+def test_score_body_tenant_id_lands_in_action_info():
+ app = create_app()
+ client = TestClient(app)
+ resp = client.post(
+ "/v1/score",
+ json={"tool_name": "search", "agent_id": "a", "tenant_id": "tenant-a"},
+ )
+ assert resp.status_code == 200
+ info = app.state.vaara.lookup_action(resp.json()["action_id"])
+ assert info is not None
+ assert info.tenant_id == "tenant-a"
+
+
+def test_score_header_tenant_id_used_when_body_empty():
+ app = create_app()
+ client = TestClient(app)
+ resp = client.post(
+ "/v1/score",
+ json={"tool_name": "search", "agent_id": "a"},
+ headers={"X-Vaara-Tenant": "tenant-b"},
+ )
+ assert resp.status_code == 200
+ info = app.state.vaara.lookup_action(resp.json()["action_id"])
+ assert info.tenant_id == "tenant-b"
+
+
+def test_score_body_tenant_wins_over_header():
+ app = create_app()
+ client = TestClient(app)
+ resp = client.post(
+ "/v1/score",
+ json={"tool_name": "search", "agent_id": "a", "tenant_id": "body-tenant"},
+ headers={"X-Vaara-Tenant": "header-tenant"},
+ )
+ info = app.state.vaara.lookup_action(resp.json()["action_id"])
+ assert info.tenant_id == "body-tenant"
+
+
+def test_audit_event_writes_with_tenant_from_action_lookup():
+ app = create_app()
+ client = TestClient(app)
+ score = client.post(
+ "/v1/score",
+ json={"tool_name": "search", "agent_id": "a", "tenant_id": "tenant-c"},
+ )
+ action_id = score.json()["action_id"]
+ resp = client.post(
+ "/v1/audit/events",
+ json={
+ "event_type": "action_executed",
+ "action_id": action_id,
+ "agent_id": "a",
+ "tool_name": "search",
+ "payload": {"result": "ok"},
+ },
+ )
+ assert resp.status_code == 201
+ records = [
+ r for r in app.state.vaara.audit._records if r.action_id == action_id
+ ]
+ assert any(r.tenant_id == "tenant-c" for r in records)
+
+
+def test_audit_trail_tenant_propagates_to_followup_records():
+ trail = AuditTrail()
+ action_type = _action_type()
+ req = ActionRequest(
+ agent_id="a", tool_name="t", action_type=action_type,
+ parameters={}, tenant_id="tenant-d",
+ )
+ action_id = trail.record_action_requested(req)
+ trail.record_decision(
+ action_id=action_id, agent_id="a", tool_name="t",
+ decision="allow", reason="ok", risk_score=0.1,
+ )
+ trail.record_execution(
+ action_id=action_id, agent_id="a", tool_name="t", result={"ok": True},
+ )
+ records = trail.get_action_trail(action_id)
+ assert len(records) == 3
+ assert all(r.tenant_id == "tenant-d" for r in records)
+
+
+def test_audit_trail_tenant_map_evicts_under_pressure():
+ trail = AuditTrail()
+ trail._MAX_ACTION_TENANT_MAP = 100
+ action_type = _action_type()
+ last_ids: list[str] = []
+ for i in range(150):
+ req = ActionRequest(
+ agent_id="a", tool_name="t", action_type=action_type,
+ tenant_id=f"t{i}",
+ )
+ last_ids.append(trail.record_action_requested(req))
+ assert len(trail._tenant_for_action) <= 100
+ assert trail._tenant_for_action.get(last_ids[-1]) == "t149"
+ assert last_ids[0] not in trail._tenant_for_action
+
+
+def test_audit_trail_caps_tenant_id_length():
+ trail = AuditTrail()
+ action_type = _action_type()
+ huge = "x" * 10_000
+ req = ActionRequest(
+ agent_id="a", tool_name="t", action_type=action_type,
+ tenant_id=huge,
+ )
+ action_id = trail.record_action_requested(req)
+ record = trail.get_action_trail(action_id)[0]
+ assert len(record.tenant_id) <= trail._MAX_TENANT_ID_LEN
+
+
+def test_audit_record_hash_excludes_tenant_id():
+ """tenant_id is NOT part of compute_hash so pre-v0.40 chains re-verify."""
+ rec_no = AuditRecord(
+ record_id="r1", action_id="a1", event_type=EventType.ACTION_REQUESTED,
+ timestamp=1.0, agent_id="a", tool_name="t",
+ )
+ rec_with = AuditRecord(
+ record_id="r1", action_id="a1", event_type=EventType.ACTION_REQUESTED,
+ timestamp=1.0, agent_id="a", tool_name="t", tenant_id="tenant-x",
+ )
+ assert rec_no.compute_hash() == rec_with.compute_hash()