diff --git a/docs/my-website/docs/mcp_zero_trust.md b/docs/my-website/docs/mcp_zero_trust.md new file mode 100644 index 00000000000..8f431523cb8 --- /dev/null +++ b/docs/my-website/docs/mcp_zero_trust.md @@ -0,0 +1,294 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# MCP Zero Trust Auth (JWT Signer) + +![Zero Trust MCP Gateway](/img/mcp_zero_trust_gateway.png) + +MCP servers have no built-in way to verify that a request actually came through LiteLLM. Without this guardrail, any client that can reach your MCP server directly can call tools — bypassing your access controls entirely. + +`MCPJWTSigner` fixes this. It signs every outbound tool call with a short-lived RS256 JWT. Your MCP server verifies the signature against LiteLLM's public key. Requests that didn't go through LiteLLM have no valid signature and are rejected. + +--- + +## Basic setup + +Add the guardrail to your config and point your MCP server at LiteLLM's JWKS endpoint. Every tool call gets a signed JWT automatically — no changes needed on the client side. + +```yaml title="config.yaml" +mcp_servers: + - server_name: weather + url: http://localhost:8000/mcp + transport: http + +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + issuer: "https://my-litellm.example.com" # defaults to request base URL + audience: "mcp" # default: "mcp" + ttl_seconds: 300 # default: 300 +``` + +**Bring your own signing key** — recommended for production. Auto-generated keys are lost on restart. + +```bash +export MCP_JWT_SIGNING_KEY="-----BEGIN RSA PRIVATE KEY-----\n..." 
+# or point to a file +export MCP_JWT_SIGNING_KEY="file:///secrets/mcp-signing-key.pem" +``` + +**Build a verified MCP server with [FastMCP](https://gofastmcp.com):** + +```python title="weather_server.py" +from fastmcp import FastMCP, Context +from fastmcp.server.auth.providers.jwt import JWTVerifier + +auth = JWTVerifier( + jwks_uri="https://my-litellm.example.com/.well-known/jwks.json", + issuer="https://my-litellm.example.com", + audience="mcp", + algorithm="RS256", +) + +mcp = FastMCP("weather-server", auth=auth) + +@mcp.tool() +async def get_weather(city: str, ctx: Context) -> str: + caller = ctx.client_id # JWT `sub` — the verified user identity + return f"Weather in {city}: sunny, 72°F (requested by {caller})" + +if __name__ == "__main__": + mcp.run(transport="http", host="0.0.0.0", port=8000) +``` + +FastMCP fetches the JWKS automatically and re-fetches when the signing key changes. + +LiteLLM publishes OIDC discovery so MCP servers find the key without any manual configuration: + +``` +GET /.well-known/openid-configuration → { "jwks_uri": "https://<your-litellm-host>/.well-known/jwks.json" } +GET /.well-known/jwks.json → { "keys": [{ "kty": "RSA", "alg": "RS256", ... }] } +``` + +> **Read further only if you need to:** thread a corporate IdP identity into the JWT, enforce specific claims on callers, add custom metadata, use AWS Bedrock AgentCore Gateway, or debug JWT rejections. + +--- + +## Thread IdP identity into MCP JWTs + +By default the outbound JWT `sub` is LiteLLM's internal `user_id`. If your users authenticate with Okta, Azure AD, or another IdP, the MCP server sees a LiteLLM-internal ID — not the user's email or employee ID. + +With verify+re-sign, LiteLLM validates the incoming IdP token first, then builds the outbound JWT using the real identity claims from that token. The MCP server gets the user's actual identity without ever having to trust the original IdP directly. 
+ +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + issuer: "https://my-litellm.example.com" + + # Validate the incoming Bearer token against the IdP + access_token_discovery_uri: "https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration" + verify_issuer: "https://login.microsoftonline.com/{tenant}/v2.0" + verify_audience: "api://my-app" + + # Which claim to use for `sub` in the outbound JWT — first non-empty value wins + end_user_claim_sources: + - "token:sub" # from the verified incoming JWT + - "token:email" # fallback to email + - "litellm:user_id" # last resort: LiteLLM's internal user_id +``` + +If the incoming token is **opaque** (not a JWT — some IdPs issue these), add an introspection endpoint. LiteLLM will POST the token to it (RFC 7662) and use the returned claims: + +```yaml + token_introspection_endpoint: "https://idp.example.com/oauth2/introspect" +``` + +**Supported `end_user_claim_sources` values:** + +| Source | Resolves to | +|--------|-------------| +| `token:` | Any claim from the verified incoming JWT (e.g. `token:sub`, `token:email`, `token:oid`) | +| `litellm:user_id` | LiteLLM's internal user ID | +| `litellm:email` | User email from LiteLLM auth context | +| `litellm:end_user_id` | End-user ID if set separately | +| `litellm:team_id` | Team ID from LiteLLM auth context | + +--- + +## Block callers missing required attributes + +Some MCP servers expose sensitive operations that should only be reachable by verified employees — not service accounts, not external API keys. You can enforce this at the LiteLLM layer so the MCP server never receives the request at all. + +`required_claims` rejects with `403` if the incoming token is missing any listed claim. `optional_claims` forwards claims that are useful but not mandatory. 
+ +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + + access_token_discovery_uri: "https://idp.example.com/.well-known/openid-configuration" + + # Service accounts without `employee_id` are blocked before the tool runs + required_claims: + - "sub" + - "employee_id" + + # Forward these into the outbound JWT when present — skipped silently if absent + optional_claims: + - "groups" + - "department" +``` + +**What the client sees when blocked:** +```json +HTTP 403 +{ "error": "MCPJWTSigner: incoming token is missing required claims: ['employee_id']. Configure the IdP to include these claims." } +``` + +--- + +## Add custom metadata to every JWT + +Your MCP server may need context that LiteLLM doesn't carry natively — which deployment sent the request, a tenant ID, an environment tag. Use claim operations to inject, override, or strip claims from the outbound JWT. + +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + + # add: insert only when the key is not already in the JWT + add_claims: + deployment_id: "prod-us-east-1" + tenant_id: "acme-corp" + + # set: always override — even if the claim came from the incoming token + set_claims: + env: "production" + + # remove: strip claims the MCP server shouldn't see + remove_claims: + - "nbf" # some validators reject nbf; remove it if yours does +``` + +Operations run in order — `add_claims` → `set_claims` → `remove_claims`. `set_claims` always wins over `add_claims`; `remove_claims` beats both. + +--- + +## AWS Bedrock AgentCore Gateway + +Bedrock AgentCore Gateway uses two separate JWTs: one to authenticate the transport connection and another to authorize tool calls. They need different `aud` values and TTLs — a single JWT won't work for both. 
LiteLLM can issue both in one hook and inject them into separate headers: + +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + issuer: "https://my-litellm.example.com" + audience: "mcp-resource" # for the MCP resource layer + ttl_seconds: 300 + + # Second JWT for the transport channel — same sub/act/scope, different aud + TTL + channel_token_audience: "bedrock-agentcore-gateway" + channel_token_ttl: 60 # transport tokens should be short-lived +``` + +LiteLLM injects two headers on every tool call: +- `Authorization: Bearer <resource_jwt>` — audience `mcp-resource`, TTL 300s +- `x-mcp-channel-token: Bearer <channel_jwt>` — audience `bedrock-agentcore-gateway`, TTL 60s + +Both tokens are signed with the same LiteLLM key, so your MCP server only needs to trust one JWKS endpoint. + +--- + +## Control which scopes go into the JWT + +By default LiteLLM generates least-privilege scopes per request: +- Tool call → `mcp:tools/call mcp:tools/{name}:call` +- List tools → `mcp:tools/call mcp:tools/list` + +If your MCP server does its own scope enforcement and needs a specific format, set `allowed_scopes` to replace auto-generation entirely: + +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + + allowed_scopes: + - "mcp:tools/call" + - "mcp:tools/list" + - "mcp:admin" +``` + +Every JWT carries exactly those scopes regardless of which tool is being called. + +--- + +## Debug JWT rejections + +Your MCP server is returning 401 and you're not sure what's in the JWT. 
Enable `debug_headers` and LiteLLM adds a `x-litellm-mcp-debug` response header with the key claims that were signed: + +```yaml title="config.yaml" +guardrails: + - guardrail_name: mcp-jwt-signer + litellm_params: + guardrail: mcp_jwt_signer + mode: pre_mcp_call + default_on: true + debug_headers: true +``` + +Response header: +``` +x-litellm-mcp-debug: v=1; kid=a3f1b2c4d5e6f708; sub=alice@corp.com; iss=https://my-litellm.example.com; exp=1712345678; scope=mcp:tools/call mcp:tools/get_weather:call +``` + +Check that `kid` matches what the MCP server fetched from JWKS, `iss`/`aud` match your server's expected values, and `exp` hasn't passed. Disable in production — the header leaks claim metadata. + +--- + +## JWT claims reference + +| Claim | Value | +|-------|-------| +| `iss` | `issuer` config value (or request base URL) | +| `aud` | `audience` config value (default: `"mcp"`) | +| `sub` | Resolved via `end_user_claim_sources` (default: `user_id` → api-key hash → `"litellm-proxy"`) | +| `act.sub` | `team_id` → `org_id` → `"litellm-proxy"` (RFC 8693 delegation) | +| `email` | `user_email` from LiteLLM auth context (when available) | +| `scope` | Auto-generated per tool call, or `allowed_scopes` when set | +| `iat`, `exp`, `nbf` | Standard timing claims (RFC 7519) | + +--- + +## Limitations + +- **OpenAPI-backed MCP servers** (`spec_path` set) do not support JWT injection. LiteLLM logs a warning and skips the header. Use SSE/HTTP transport servers to get full JWT injection. +- The keypair is **in-memory by default** and rotated on each restart unless `MCP_JWT_SIGNING_KEY` is set. FastMCP's `JWTVerifier` handles key rotation transparently via JWKS key ID matching. 
+ +--- + +## Related + +- [MCP Guardrails](./mcp_guardrail) — PII masking and blocking for MCP calls +- [MCP OAuth](./mcp_oauth) — upstream OAuth2 for MCP server access +- [MCP AWS SigV4](./mcp_aws_sigv4) — AWS-signed requests to MCP servers diff --git a/docs/my-website/img/mcp_zero_trust_gateway.png b/docs/my-website/img/mcp_zero_trust_gateway.png new file mode 100644 index 00000000000..3955cef0553 Binary files /dev/null and b/docs/my-website/img/mcp_zero_trust_gateway.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1362745a91f..46e1bd041e6 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -636,6 +636,7 @@ const sidebars = { "mcp_control", "mcp_cost", "mcp_guardrail", + "mcp_zero_trust", "mcp_troubleshoot", ] }, diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index af3a715051b..3385e7feef6 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -677,7 +677,60 @@ async def oauth_authorization_server_mcp( # Alias for standard OpenID discovery @router.get("/.well-known/openid-configuration") async def openid_configuration(request: Request): - return await oauth_authorization_server_mcp(request) + response = await oauth_authorization_server_mcp(request) + + # If MCPJWTSigner is active, augment the discovery doc with JWKS fields so + # MCP servers and gateways (e.g. AWS Bedrock AgentCore Gateway) can resolve + # the signing keys and verify liteLLM-issued tokens. 
+ try: + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + signer = get_mcp_jwt_signer() + if signer is not None: + request_base_url = get_request_base_url(request) + if isinstance(response, dict): + response = { + **response, + "jwks_uri": f"{request_base_url}/.well-known/jwks.json", + "id_token_signing_alg_values_supported": ["RS256"], + } + except ImportError: + pass + + return response + + +@router.get("/.well-known/jwks.json") +async def jwks_json(request: Request): + """ + JSON Web Key Set endpoint. + + Returns the RSA public key used by MCPJWTSigner to sign outbound MCP tokens. + MCP servers and gateways use this endpoint to verify liteLLM-issued JWTs. + + Returns an empty key set if MCPJWTSigner is not configured. + """ + try: + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + signer = get_mcp_jwt_signer() + if signer is not None: + return JSONResponse( + content=signer.get_jwks(), + headers={"Cache-Control": f"public, max-age={signer.jwks_max_age}"}, + ) + except ImportError: + pass + + # No signer active — return empty key set; short cache so activation is picked up quickly. + return JSONResponse( + content={"keys": []}, + headers={"Cache-Control": "public, max-age=60"}, + ) # Additional legacy pattern support diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 43fe54fdfb7..a8e5e60f228 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -1908,7 +1908,15 @@ async def pre_call_tool_check( user_api_key_auth: Optional[UserAPIKeyAuth], proxy_logging_obj: ProxyLogging, server: MCPServer, - ): + raw_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Run pre-call checks and guardrail hooks for an MCP tool call. 
+ + Returns a dict that may contain: + - "arguments": hook-modified tool arguments (only if changed) + - "extra_headers": headers injected by pre_mcp_call guardrail hooks + """ ## check if the tool is allowed or banned for the given server if not self.check_allowed_or_banned_tools(name, server): raise HTTPException( @@ -1932,6 +1940,14 @@ async def pre_call_tool_check( server=server, ) + # Extract incoming Bearer token from raw request headers so + # guardrails like MCPJWTSigner can verify + re-sign it (FR-5). + normalized_raw = {k.lower(): v for k, v in (raw_headers or {}).items()} + incoming_bearer_token: Optional[str] = None + auth_hdr = normalized_raw.get("authorization", "") + if auth_hdr.lower().startswith("bearer "): + incoming_bearer_token = auth_hdr[len("bearer "):] + pre_hook_kwargs = { "name": name, "arguments": arguments, @@ -1957,6 +1973,7 @@ async def pre_call_tool_check( if user_api_key_auth else None ), + "incoming_bearer_token": incoming_bearer_token, } # Create MCP request object for processing @@ -1969,6 +1986,7 @@ async def pre_call_tool_check( mcp_request_obj, pre_hook_kwargs ) + hook_result: Dict[str, Any] = {} try: # Use standard pre_call_hook modified_data = await proxy_logging_obj.pre_call_hook( @@ -1984,7 +2002,9 @@ async def pre_call_tool_check( ) ) if modified_kwargs.get("arguments") != arguments: - arguments = modified_kwargs["arguments"] + hook_result["arguments"] = modified_kwargs["arguments"] + if modified_kwargs.get("extra_headers"): + hook_result["extra_headers"] = modified_kwargs["extra_headers"] except ( BlockedPiiEntityError, @@ -1995,6 +2015,8 @@ async def pre_call_tool_check( verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}") raise e + return hook_result + def _create_during_hook_task( self, name: str, @@ -2047,6 +2069,7 @@ async def _call_regular_mcp_tool( raw_headers: Optional[Dict[str, str]], proxy_logging_obj: Optional[ProxyLogging], host_progress_callback: Optional[Callable] = None, + 
hook_extra_headers: Optional[Dict[str, str]] = None, ) -> CallToolResult: """ Call a regular MCP tool using the MCP client. @@ -2061,6 +2084,9 @@ async def _call_regular_mcp_tool( oauth2_headers: Optional OAuth2 headers raw_headers: Optional raw headers from the request proxy_logging_obj: Optional ProxyLogging object for hook integration + host_progress_callback: Optional callback for progress updates + hook_extra_headers: Optional headers injected by pre_mcp_call guardrail + hooks. Merged last (highest priority) into outbound request headers. Returns: CallToolResult from the MCP server @@ -2116,6 +2142,31 @@ async def _call_regular_mcp_tool( extra_headers = {} extra_headers.update(mcp_server.static_headers) + if hook_extra_headers: + if extra_headers is None: + extra_headers = {} + if "Authorization" in hook_extra_headers: + if "Authorization" in extra_headers: + verbose_logger.warning( + "MCPServerManager: hook_extra_headers 'Authorization' will overwrite " + "the existing Authorization header from static_headers. " + "The hook JWT will take precedence." + ) + elif server_auth_header is not None: + # server_auth_header is passed separately to _create_mcp_client as + # auth_value. Both will reach the upstream server — warn so admins + # know two Authorization credentials are being sent. + verbose_logger.warning( + "MCPServerManager: hook_extra_headers injects 'Authorization' while " + "server '%s' already has a configured authentication_token. " + "Both credentials will be sent; the hook header is in extra_headers " + "and the server token is in auth_value — the upstream server decides " + "which one wins. 
Consider unsetting authentication_token if you want " + "the hook JWT to be the sole credential.", + mcp_server.server_name or mcp_server.name, + ) + extra_headers.update(hook_extra_headers) + stdio_env = self._build_stdio_env(mcp_server, raw_headers) client = await self._create_mcp_client( @@ -2201,15 +2252,19 @@ async def call_tool( # Allow validation and modification of tool calls before execution # Using standard pre_call_hook ######################################################### + hook_result: Dict[str, Any] = {} if proxy_logging_obj: - await self.pre_call_tool_check( + hook_result = await self.pre_call_tool_check( name=name, arguments=arguments, server_name=server_name, user_api_key_auth=user_api_key_auth, proxy_logging_obj=proxy_logging_obj, server=mcp_server, + raw_headers=raw_headers, ) + if "arguments" in hook_result: + arguments = hook_result["arguments"] # Prepare tasks for during hooks tasks = [] @@ -2227,8 +2282,16 @@ async def call_tool( # For OpenAPI servers, call the tool handler directly instead of via MCP client if mcp_server.spec_path: verbose_logger.debug( - f"Calling OpenAPI tool {name} directly via HTTP handler" + "Calling OpenAPI tool %s directly via HTTP handler", name ) + if hook_result.get("extra_headers"): + verbose_logger.warning( + "pre_mcp_call hook returned extra_headers for OpenAPI-backed " + "MCP server '%s' — header injection is not supported for " + "OpenAPI servers; headers will be ignored. 
Use SSE/HTTP " + "transport to enable hook header injection.", + server_name, + ) tasks.append( asyncio.create_task( self._call_openapi_tool_handler(mcp_server, name, arguments) @@ -2247,6 +2310,7 @@ async def call_tool( raw_headers=raw_headers, proxy_logging_obj=proxy_logging_obj, host_progress_callback=host_progress_callback, + hook_extra_headers=hook_result.get("extra_headers"), ) # For OpenAPI tools, await outside the client context diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ecbd7314cd7..9e86680e355 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2471,6 +2471,9 @@ class UserAPIKeyAuth( Any ] = None # Expanded created_by user when expand=user is used end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + # Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery + # and forwarded into outbound tokens by guardrails such as MCPJWTSigner. + jwt_claims: Optional[Dict] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 376048e7a13..451ed56339d 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -700,6 +700,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ) if valid_token is not None: api_key = valid_token.token or "" + valid_token.jwt_claims = jwt_claims do_standard_jwt_auth = False # Fall through to virtual key checks @@ -729,6 +730,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_membership: Optional[LiteLLM_TeamMembership] = result.get( "team_membership", None ) + jwt_claims: Optional[dict] = result.get("jwt_claims", None) global_proxy_spend = await get_global_proxy_spend( litellm_proxy_admin_name=litellm_proxy_admin_name, @@ -757,6 +759,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 org_id=org_id, end_user_id=end_user_id, parent_otel_span=parent_otel_span, + 
jwt_claims=jwt_claims, ) valid_token = UserAPIKeyAuth( @@ -803,6 +806,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_metadata=( team_object.metadata if team_object is not None else None ), + jwt_claims=jwt_claims, ) # Check if model has zero cost - if so, skip all budget checks diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py new file mode 100644 index 00000000000..abea9014a11 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py @@ -0,0 +1,84 @@ +"""MCP JWT Signer guardrail — built-in LiteLLM guardrail for zero trust MCP auth.""" + +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + +from .mcp_jwt_signer import MCPJWTSigner, get_mcp_jwt_signer + +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail( + litellm_params: "LitellmParams", guardrail: "Guardrail" +) -> MCPJWTSigner: + import litellm + + guardrail_name = guardrail.get("guardrail_name") + if not guardrail_name: + raise ValueError("MCPJWTSigner guardrail requires a guardrail_name") + + mode = litellm_params.mode + if mode != "pre_mcp_call": + raise ValueError( + f"MCPJWTSigner guardrail '{guardrail_name}' has mode='{mode}' but must use " + "mode='pre_mcp_call'. JWT injection only fires for MCP tool calls." 
+ ) + + optional_params = getattr(litellm_params, "optional_params", None) + + def _get(key): # type: ignore[no-untyped-def] + if optional_params is not None: + v = getattr(optional_params, key, None) + if v is not None: + return v + return getattr(litellm_params, key, None) + + signer = MCPJWTSigner( + guardrail_name=guardrail_name, + event_hook=litellm_params.mode, + default_on=litellm_params.default_on, + # Core signing + issuer=_get("issuer"), + audience=_get("audience"), + ttl_seconds=_get("ttl_seconds"), + # FR-5: verify + re-sign + access_token_discovery_uri=_get("access_token_discovery_uri"), + token_introspection_endpoint=_get("token_introspection_endpoint"), + verify_issuer=_get("verify_issuer"), + verify_audience=_get("verify_audience"), + # FR-12: end-user identity mapping + end_user_claim_sources=_get("end_user_claim_sources"), + # FR-13: claim operations + add_claims=_get("add_claims"), + set_claims=_get("set_claims"), + remove_claims=_get("remove_claims"), + # FR-14: two-token model + channel_token_audience=_get("channel_token_audience"), + channel_token_ttl=_get("channel_token_ttl"), + # FR-15: incoming claim validation + required_claims=_get("required_claims"), + optional_claims=_get("optional_claims"), + # FR-9: debug headers + debug_headers=_get("debug_headers") or False, + # FR-10: configurable scopes + allowed_scopes=_get("allowed_scopes"), + ) + litellm.logging_callback_manager.add_litellm_callback(signer) + return signer + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: initialize_guardrail, +} + +guardrail_class_registry = { + SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: MCPJWTSigner, +} + +__all__ = [ + "MCPJWTSigner", + "initialize_guardrail", + "get_mcp_jwt_signer", +] diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py new file mode 100644 index 00000000000..63997808aa5 --- 
/dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py @@ -0,0 +1,889 @@ +""" +MCPJWTSigner — Built-in LiteLLM guardrail for zero trust MCP authentication. + +Signs outbound MCP requests with a LiteLLM-issued RS256 JWT so that MCP servers +can trust a single signing authority (liteLLM) instead of every upstream IdP. + +Usage in config.yaml: + + guardrails: + - guardrail_name: "mcp-jwt-signer" + litellm_params: + guardrail: mcp_jwt_signer + mode: "pre_mcp_call" + default_on: true + + # Core signing config + issuer: "https://my-litellm.example.com" # optional + audience: "mcp" # optional + ttl_seconds: 300 # optional + + # FR-5: Verify + re-sign — validate incoming Bearer token before signing + access_token_discovery_uri: "https://idp.example.com/.well-known/openid-configuration" + token_introspection_endpoint: "https://idp.example.com/introspect" # opaque tokens + verify_issuer: "https://idp.example.com" # expected iss in incoming JWT + verify_audience: "api://my-app" # expected aud in incoming JWT + + # FR-12: End-user identity mapping — ordered resolution chain + # Supported: token:, litellm:user_id, litellm:email, + # litellm:end_user_id, litellm:team_id + end_user_claim_sources: + - "token:sub" + - "token:email" + - "litellm:user_id" + + # FR-13: Claim operations + add_claims: # add if key not already present in the JWT + deployment_id: "prod-001" + set_claims: # always set (overrides computed value) + env: "production" + remove_claims: # remove from final JWT + - "nbf" + + # FR-14: Two-token model — issue a second JWT for the MCP transport channel + channel_token_audience: "bedrock-gateway" + channel_token_ttl: 60 + + # FR-15: Incoming claim validation — enforce required IdP claims + required_claims: + - "sub" + - "email" + optional_claims: # pass through from jwt_claims into outbound JWT + - "groups" + - "roles" + + # FR-9: Debug headers + debug_headers: false # emit x-litellm-mcp-debug header when true + + # FR-10: 
Configurable scopes — explicit list replaces auto-generation + allowed_scopes: + - "mcp:tools/call" + - "mcp:tools/list" + +MCP servers verify tokens via: + GET /.well-known/openid-configuration → { jwks_uri: ".../.well-known/jwks.json" } + GET /.well-known/jwks.json → RSA public key in JWKS format + +Optionally set MCP_JWT_SIGNING_KEY env var (PEM string or file:///path) to use +your own RSA keypair. If unset, an RSA-2048 keypair is auto-generated at startup. +""" + +import base64 +import hashlib +import os +import re +import time +from typing import Any, Dict, List, Optional, Union + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.utils import CallTypesLiteral + +# Module-level singleton for the JWKS discovery endpoint to access. +_mcp_jwt_signer_instance: Optional["MCPJWTSigner"] = None + +# Simple in-memory JWKS cache: keyed by JWKS URI → (keys_list, fetched_at). +_jwks_cache: Dict[str, tuple] = {} +_JWKS_CACHE_TTL = 3600 # 1 hour + + +def get_mcp_jwt_signer() -> Optional["MCPJWTSigner"]: + """Return the active MCPJWTSigner singleton, or None if not initialized.""" + return _mcp_jwt_signer_instance + + +def _load_private_key_from_env(env_var: str) -> RSAPrivateKey: + """Load an RSA private key from an env var (PEM string or file:// path).""" + key_material = os.environ.get(env_var, "") + if not key_material: + raise ValueError( + f"MCPJWTSigner: environment variable '{env_var}' is set but empty." 
+ ) + if key_material.startswith("file://"): + path = key_material[len("file://"):] + with open(path, "rb") as f: + key_bytes = f.read() + else: + key_bytes = key_material.encode("utf-8") + return serialization.load_pem_private_key(key_bytes, password=None) # type: ignore[return-value] + + +def _generate_rsa_key_pair() -> RSAPrivateKey: + """Generate a new RSA-2048 private key.""" + return rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + +def _int_to_base64url(n: int) -> str: + """Encode an integer as a base64url string (no padding).""" + byte_length = (n.bit_length() + 7) // 8 + return ( + base64.urlsafe_b64encode(n.to_bytes(byte_length, byteorder="big")) + .rstrip(b"=") + .decode("ascii") + ) + + +def _compute_kid(public_key: Any) -> str: + """Derive a key ID from the public key's DER encoding (SHA-256, first 16 hex chars).""" + der_bytes = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return hashlib.sha256(der_bytes).hexdigest()[:16] + + +async def _fetch_jwks(jwks_uri: str) -> List[Dict[str, Any]]: + """ + Fetch and cache a JWKS from the given URI. + + Results are cached for _JWKS_CACHE_TTL seconds to avoid hammering the IdP. 
+ """ + now = time.time() + cached = _jwks_cache.get(jwks_uri) + if cached is not None: + keys, fetched_at = cached + if now - fetched_at < _JWKS_CACHE_TTL: + return keys # type: ignore[return-value] + + from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, + ) + + client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) + resp = await client.get(jwks_uri, headers={"Accept": "application/json"}) + resp.raise_for_status() + keys = resp.json().get("keys", []) + _jwks_cache[jwks_uri] = (keys, now) + return keys # type: ignore[return-value] + + +async def _fetch_oidc_discovery(discovery_uri: str) -> Dict[str, Any]: + """Fetch an OIDC discovery document and return its parsed JSON.""" + from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, + ) + + client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) + resp = await client.get(discovery_uri, headers={"Accept": "application/json"}) + resp.raise_for_status() + return resp.json() # type: ignore[return-value] + + +class MCPJWTSigner(CustomGuardrail): + """ + Built-in LiteLLM guardrail that signs outbound MCP requests with a + LiteLLM-issued RS256 JWT, enabling zero trust authentication. + + MCP servers verify tokens using liteLLM's OIDC discovery endpoint and + JWKS endpoint rather than trusting each upstream IdP directly. 
+ + The signed JWT carries: + - iss: LiteLLM issuer identifier + - aud: MCP audience (configurable) + - sub: End-user identity (resolved via end_user_claim_sources, RFC 8693) + - act: Actor/agent identity (team_id or org_id, RFC 8693 delegation) + - scope: Tool-level access scopes (configurable via allowed_scopes) + - iat, exp, nbf: Standard timing claims + + Feature set: + FR-5: Verify + re-sign (access_token_discovery_uri, token_introspection_endpoint) + FR-9: Debug headers (debug_headers) + FR-10: Configurable scopes (allowed_scopes) + FR-12: Configurable end-user identity mapping (end_user_claim_sources) + FR-13: Claim operations (add_claims, set_claims, remove_claims) + FR-14: Two-token model (channel_token_audience, channel_token_ttl) + FR-15: Incoming claim validation (required_claims, optional_claims) + """ + + ALGORITHM = "RS256" + DEFAULT_TTL = 300 + DEFAULT_AUDIENCE = "mcp" + SIGNING_KEY_ENV = "MCP_JWT_SIGNING_KEY" + + def __init__( + self, + # Core signing config + issuer: Optional[str] = None, + audience: Optional[str] = None, + ttl_seconds: Optional[int] = None, + # FR-5: Verify + re-sign + access_token_discovery_uri: Optional[str] = None, + token_introspection_endpoint: Optional[str] = None, + verify_issuer: Optional[str] = None, + verify_audience: Optional[str] = None, + # FR-12: End-user identity mapping + end_user_claim_sources: Optional[List[str]] = None, + # FR-13: Claim operations + add_claims: Optional[Dict[str, Any]] = None, + set_claims: Optional[Dict[str, Any]] = None, + remove_claims: Optional[List[str]] = None, + # FR-14: Two-token model + channel_token_audience: Optional[str] = None, + channel_token_ttl: Optional[int] = None, + # FR-15: Incoming claim validation + required_claims: Optional[List[str]] = None, + optional_claims: Optional[List[str]] = None, + # FR-9: Debug headers + debug_headers: bool = False, + # FR-10: Configurable scopes + allowed_scopes: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + 
super().__init__(**kwargs) + + # --- Signing key setup --- + key_material = os.environ.get(self.SIGNING_KEY_ENV) + if key_material: + self._private_key = _load_private_key_from_env(self.SIGNING_KEY_ENV) + self._persistent_key: bool = True + verbose_proxy_logger.info( + "MCPJWTSigner: loaded RSA key from env var %s", self.SIGNING_KEY_ENV + ) + else: + self._private_key = _generate_rsa_key_pair() + self._persistent_key = False + verbose_proxy_logger.info( + "MCPJWTSigner: auto-generated RSA-2048 keypair (set %s to use your own key)", + self.SIGNING_KEY_ENV, + ) + + self._public_key = self._private_key.public_key() + self._kid = _compute_kid(self._public_key) + + # --- Core config --- + self.issuer: str = ( + issuer + or os.environ.get("MCP_JWT_ISSUER") + or os.environ.get("LITELLM_EXTERNAL_URL") + or "litellm" + ) + self.audience: str = ( + audience + or os.environ.get("MCP_JWT_AUDIENCE") + or self.DEFAULT_AUDIENCE + ) + resolved_ttl = int( + ttl_seconds + if ttl_seconds is not None + else os.environ.get("MCP_JWT_TTL_SECONDS", str(self.DEFAULT_TTL)) + ) + if resolved_ttl <= 0: + raise ValueError( + f"MCPJWTSigner: ttl_seconds must be > 0, got {resolved_ttl}" + ) + self.ttl_seconds: int = resolved_ttl + + # --- FR-5: Verify + re-sign --- + self.access_token_discovery_uri: Optional[str] = access_token_discovery_uri + self.token_introspection_endpoint: Optional[str] = token_introspection_endpoint + self.verify_issuer: Optional[str] = verify_issuer + self.verify_audience: Optional[str] = verify_audience + # Cached OIDC discovery document (fetched lazily, TTL = 24 h) + self._oidc_discovery_doc: Optional[Dict[str, Any]] = None + self._oidc_discovery_fetched_at: float = 0.0 + + # --- FR-12: End-user identity mapping --- + # Default chain: try incoming JWT sub, fall back to litellm user_id + self.end_user_claim_sources: List[str] = end_user_claim_sources or [ + "token:sub", + "litellm:user_id", + ] + + # --- FR-13: Claim operations --- + self.add_claims: Dict[str, Any] = 
add_claims or {} + self.set_claims: Dict[str, Any] = set_claims or {} + self.remove_claims: List[str] = remove_claims or [] + + # --- FR-14: Two-token model --- + self.channel_token_audience: Optional[str] = channel_token_audience + self.channel_token_ttl: int = ( + channel_token_ttl if channel_token_ttl is not None else self.ttl_seconds + ) + + # --- FR-15: Incoming claim validation --- + self.required_claims: List[str] = required_claims or [] + self.optional_claims: List[str] = optional_claims or [] + + # --- FR-9: Debug headers --- + self.debug_headers: bool = debug_headers + + # --- FR-10: Configurable scopes --- + self.allowed_scopes: Optional[List[str]] = allowed_scopes + + # Register singleton for JWKS/OIDC discovery endpoints. + global _mcp_jwt_signer_instance + if _mcp_jwt_signer_instance is not None: + verbose_proxy_logger.warning( + "MCPJWTSigner: replacing existing singleton — previously issued tokens " + "signed with the old key will fail JWKS verification. " + "Avoid configuring multiple mcp_jwt_signer guardrails." + ) + _mcp_jwt_signer_instance = self + + verbose_proxy_logger.info( + "MCPJWTSigner initialized: issuer=%s audience=%s ttl=%ds kid=%s " + "verify=%s channel_token=%s debug=%s", + self.issuer, + self.audience, + self.ttl_seconds, + self._kid, + bool(self.access_token_discovery_uri), + bool(self.channel_token_audience), + self.debug_headers, + ) + + # ------------------------------------------------------------------ + # Public helpers (used by /.well-known/jwks.json endpoint) + # ------------------------------------------------------------------ + + @property + def jwks_max_age(self) -> int: + """ + Recommended Cache-Control max-age for the JWKS response (seconds). + + 1 hour for persistent keys; 5 minutes for auto-generated keys so MCP + servers re-fetch quickly after a proxy restart. + """ + return 3600 if self._persistent_key else 300 + + def get_jwks(self) -> Dict[str, Any]: + """ + Return the JWKS for the RSA public key. 
+ Used by GET /.well-known/jwks.json so MCP servers can verify tokens. + """ + public_numbers = self._public_key.public_numbers() + return { + "keys": [ + { + "kty": "RSA", + "alg": self.ALGORITHM, + "use": "sig", + "kid": self._kid, + "n": _int_to_base64url(public_numbers.n), + "e": _int_to_base64url(public_numbers.e), + } + ] + } + + # ------------------------------------------------------------------ + # FR-5: Verify + re-sign helpers + # ------------------------------------------------------------------ + + # 24-hour TTL for the OIDC discovery doc — long enough to avoid hammering + # the IdP, short enough to pick up jwks_uri changes after key rotation. + _OIDC_DISCOVERY_TTL = 86400 + + async def _get_oidc_discovery(self) -> Dict[str, Any]: + """Fetch and cache the OIDC discovery document with a 24-hour TTL. + + Only caches when the doc contains a 'jwks_uri' so that a transient or + malformed response doesn't permanently disable JWT verification. + """ + now = time.time() + cache_expired = (now - self._oidc_discovery_fetched_at) >= self._OIDC_DISCOVERY_TTL + if (self._oidc_discovery_doc is None or cache_expired) and self.access_token_discovery_uri: + doc = await _fetch_oidc_discovery(self.access_token_discovery_uri) + if "jwks_uri" in doc: + self._oidc_discovery_doc = doc + self._oidc_discovery_fetched_at = now + else: + return doc + return self._oidc_discovery_doc or {} + + async def _verify_incoming_jwt(self, raw_token: str) -> Dict[str, Any]: + """ + Verify an incoming Bearer JWT against the configured IdP's JWKS. + + Returns the verified payload claims dict. + Raises jwt.PyJWTError (or subclass) if verification fails. + """ + discovery = await self._get_oidc_discovery() + jwks_uri = discovery.get("jwks_uri") + if not jwks_uri: + raise ValueError( + "MCPJWTSigner: access_token_discovery_uri discovery document " + f"at {self.access_token_discovery_uri!r} has no 'jwks_uri'." 
+ ) + + jwks_keys = await _fetch_jwks(jwks_uri) + + # Only read `kid` from the unverified header — never `alg`. + # Reading `alg` from an attacker-controlled header enables algorithm + # confusion attacks (e.g. alg:none, HS256 with the public key as secret). + # The algorithm is determined from the JWKS key entry instead. + unverified_header = jwt.get_unverified_header(raw_token) + kid = unverified_header.get("kid") + + # Build a JWKS object and pick the matching key. + # PyJWT's PyJWKSet handles key-type parsing and kid matching correctly. + from jwt import PyJWKSet + + try: + jwks_set = PyJWKSet.from_dict({"keys": jwks_keys}) + except Exception as exc: + raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined] + f"Failed to parse JWKS from {jwks_uri!r}: {exc}" + ) from exc + + signing_jwk = None + for jwk_obj in jwks_set.keys: + if not kid or jwk_obj.key_id == kid: + signing_jwk = jwk_obj + break + + if signing_jwk is None: + raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined] + f"No JWKS key matching kid={kid!r} at {jwks_uri!r}" + ) + + # Use the algorithm declared by the JWKS key entry, not the token header. + # PyJWT populates algorithm_name from the key's `alg` field; when absent + # it infers from the key type (RSAPublicKey → RS256). + alg = getattr(signing_jwk, "algorithm_name", None) or "RS256" + + decode_options: Dict[str, Any] = {"verify_exp": True} + decode_kwargs: Dict[str, Any] = { + "algorithms": [alg], + "options": decode_options, + } + if self.verify_audience: + decode_kwargs["audience"] = self.verify_audience + else: + decode_options["verify_aud"] = False + + if self.verify_issuer: + decode_kwargs["issuer"] = self.verify_issuer + + payload: Dict[str, Any] = jwt.decode( + raw_token, signing_jwk.key, **decode_kwargs + ) + return payload + + async def _introspect_opaque_token(self, token: str) -> Dict[str, Any]: + """ + Perform RFC 7662 token introspection for opaque (non-JWT) tokens. + + Returns the introspection response dict. 
Raises on HTTP error or + inactive token. + """ + if not self.token_introspection_endpoint: + raise ValueError( + "MCPJWTSigner: token_introspection_endpoint is required for " + "opaque token verification but is not configured." + ) + + from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, + ) + + client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) + resp = await client.post( + self.token_introspection_endpoint, + data={"token": token}, + headers={"Accept": "application/json"}, + ) + resp.raise_for_status() + result: Dict[str, Any] = resp.json() + if not result.get("active", False): + raise jwt.exceptions.ExpiredSignatureError( # type: ignore[attr-defined] + "MCPJWTSigner: incoming token is inactive (introspection returned active=false)" + ) + return result + + # ------------------------------------------------------------------ + # FR-15: Incoming claim validation + # ------------------------------------------------------------------ + + def _validate_required_claims( + self, + jwt_claims: Optional[Dict[str, Any]], + ) -> None: + """ + Raise HTTP 403 if any required_claims are absent from the verified + incoming token claims. + """ + if not self.required_claims: + return + + from fastapi import HTTPException + + missing = [c for c in self.required_claims if not (jwt_claims or {}).get(c)] + if missing: + raise HTTPException( + status_code=403, + detail={ + "error": ( + f"MCPJWTSigner: incoming token is missing required claims: " + f"{missing}. Configure the IdP to include these claims." 
+ ) + }, + ) + + # ------------------------------------------------------------------ + # FR-12: End-user identity mapping + # ------------------------------------------------------------------ + + def _resolve_end_user_identity( + self, + user_api_key_dict: UserAPIKeyAuth, + jwt_claims: Optional[Dict[str, Any]], + ) -> str: + """ + Resolve the outbound JWT 'sub' using the ordered end_user_claim_sources list. + + Supported source prefixes: + token: — from verified incoming JWT / introspection claims + litellm:user_id — from UserAPIKeyAuth.user_id + litellm:email — from UserAPIKeyAuth.user_email + litellm:end_user_id — from UserAPIKeyAuth.end_user_id + litellm:team_id — from UserAPIKeyAuth.team_id + + Falls back to a stable hash of the API token for service-account callers. + """ + for source in self.end_user_claim_sources: + value: Optional[str] = None + + if source.startswith("token:"): + claim_name = source[len("token:"):] + raw = (jwt_claims or {}).get(claim_name) + value = str(raw) if raw else None + + elif source == "litellm:user_id": + uid = getattr(user_api_key_dict, "user_id", None) + value = str(uid) if uid else None + + elif source == "litellm:email": + email = getattr(user_api_key_dict, "user_email", None) + value = str(email) if email else None + + elif source == "litellm:end_user_id": + eid = getattr(user_api_key_dict, "end_user_id", None) + value = str(eid) if eid else None + + elif source == "litellm:team_id": + tid = getattr(user_api_key_dict, "team_id", None) + value = str(tid) if tid else None + + else: + verbose_proxy_logger.warning( + "MCPJWTSigner: unknown end_user_claim_source %r — skipping", source + ) + continue + + if value: + return value + + # Final fallback for service accounts with no user identity + token = getattr(user_api_key_dict, "token", None) or getattr( + user_api_key_dict, "api_key", None + ) + if token: + return "apikey:" + hashlib.sha256(str(token).encode()).hexdigest()[:16] + return "litellm-proxy" + + # 
------------------------------------------------------------------ + # FR-10: Scope building + # ------------------------------------------------------------------ + + def _build_scope(self, raw_tool_name: str) -> str: + """ + Build the JWT scope string. + + When allowed_scopes is configured: join them verbatim. + Otherwise auto-generate minimal, least-privilege scopes: + - Tool call → mcp:tools/call mcp:tools/:call + - No tool → mcp:tools/call mcp:tools/list + + NOTE: tools/list is intentionally NOT granted on tool-call JWTs to + prevent callers from enumerating tools they didn't ask to use. + """ + if self.allowed_scopes is not None: + return " ".join(self.allowed_scopes) + + tool_name = ( + re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else "" + ) + if tool_name: + scopes = ["mcp:tools/call", f"mcp:tools/{tool_name}:call"] + else: + scopes = ["mcp:tools/call", "mcp:tools/list"] + return " ".join(scopes) + + # ------------------------------------------------------------------ + # FR-13: Claim operations + # ------------------------------------------------------------------ + + def _apply_claim_operations(self, claims: Dict[str, Any]) -> Dict[str, Any]: + """Apply add_claims, set_claims, and remove_claims to the claim dict.""" + # add_claims: insert only when key is absent + for k, v in self.add_claims.items(): + if k not in claims: + claims[k] = v + + # set_claims: always override (highest priority) + claims = {**claims, **self.set_claims} + + # remove_claims: delete listed keys + for k in self.remove_claims: + claims.pop(k, None) + + return claims + + # ------------------------------------------------------------------ + # FR-15: optional_claims passthrough + # ------------------------------------------------------------------ + + def _passthrough_optional_claims( + self, + claims: Dict[str, Any], + jwt_claims: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + """Forward optional_claims from verified incoming token into the outbound JWT.""" + 
if not self.optional_claims or not jwt_claims: + return claims + for claim in self.optional_claims: + if claim in jwt_claims and claim not in claims: + claims[claim] = jwt_claims[claim] + return claims + + # ------------------------------------------------------------------ + # Core JWT builder + # ------------------------------------------------------------------ + + def _build_claims( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + jwt_claims: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Build JWT claims for the outbound MCP access token. + + Args: + user_api_key_dict: LiteLLM auth context for the current request. + data: Pre-call hook data dict (contains mcp_tool_name etc.). + jwt_claims: Verified incoming IdP claims (FR-5), or LiteLLM-decoded + jwt_claims if available. None for pure API-key requests. + """ + now = int(time.time()) + claims: Dict[str, Any] = { + "iss": self.issuer, + "aud": self.audience, + "iat": now, + "exp": now + self.ttl_seconds, + "nbf": now, + } + + # sub — resolved via ordered claim sources (FR-12) + claims["sub"] = self._resolve_end_user_identity(user_api_key_dict, jwt_claims) + + # email passthrough when available from LiteLLM context + user_email = getattr(user_api_key_dict, "user_email", None) + if user_email: + claims["email"] = user_email + + # act — RFC 8693 delegation claim (team/org context) + team_id = getattr(user_api_key_dict, "team_id", None) + org_id = getattr(user_api_key_dict, "org_id", None) + act_sub = team_id or org_id or "litellm-proxy" + claims["act"] = {"sub": act_sub} + + # end_user_id when set separately from user_id + end_user_id = getattr(user_api_key_dict, "end_user_id", None) + if end_user_id: + claims["end_user_id"] = end_user_id + + # scope (FR-10) + raw_tool_name: str = data.get("mcp_tool_name", "") + claims["scope"] = self._build_scope(raw_tool_name) + + # optional_claims passthrough (FR-15) + claims = self._passthrough_optional_claims(claims, jwt_claims) + + # Claim 
operations — applied last so admin overrides take effect (FR-13) + claims = self._apply_claim_operations(claims) + + return claims + + def _build_channel_token_claims( + self, + base_claims: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Build claims for the channel token (FR-14 two-token model). + + Inherits sub/act/scope from the access token but uses a separate + audience and TTL so the transport layer and resource layer receive + purpose-bound credentials. + """ + now = int(time.time()) + return { + **base_claims, + "aud": self.channel_token_audience, + "iat": now, + "exp": now + self.channel_token_ttl, + "nbf": now, + } + + # ------------------------------------------------------------------ + # FR-9: Debug header + # ------------------------------------------------------------------ + + @staticmethod + def _build_debug_header(claims: Dict[str, Any], kid: str) -> str: + """ + Build the x-litellm-mcp-debug header value. + + Format: v=1; kid=; sub=; iss=; exp=; scope= + Scope is truncated to 80 chars for header safety. + """ + sub = claims.get("sub", "") + iss = claims.get("iss", "") + exp = claims.get("exp", 0) + scope = claims.get("scope", "") + if len(scope) > 80: + scope = scope[:77] + "..." + return f"v=1; kid={kid}; sub={sub}; iss={iss}; exp={exp}; scope={scope}" + + # ------------------------------------------------------------------ + # Guardrail hook + # ------------------------------------------------------------------ + + @log_guardrail_information + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: CallTypesLiteral, + ) -> Optional[Union[Exception, str, dict]]: + """ + Verifies the incoming token (when configured), validates required claims, + then signs an outbound JWT and injects it as the Authorization header. + + All non-MCP call types pass through unchanged. 
+ """ + if call_type != "call_mcp_tool": + return data + + # ------------------------------------------------------------------ + # FR-5: Verify incoming token before re-signing + # ------------------------------------------------------------------ + jwt_claims: Optional[Dict[str, Any]] = None + raw_token: Optional[str] = data.get("incoming_bearer_token") + + if self.access_token_discovery_uri and raw_token: + # Three-dot pattern → JWT; otherwise opaque. + is_jwt = raw_token.count(".") == 2 + try: + if is_jwt: + jwt_claims = await self._verify_incoming_jwt(raw_token) + elif self.token_introspection_endpoint: + jwt_claims = await self._introspect_opaque_token(raw_token) + else: + verbose_proxy_logger.warning( + "MCPJWTSigner: access_token_discovery_uri is set but the " + "incoming token appears to be opaque and no " + "token_introspection_endpoint is configured. " + "Proceeding without incoming token verification." + ) + except Exception as exc: + verbose_proxy_logger.error( + "MCPJWTSigner: incoming token verification failed: %s", exc + ) + from fastapi import HTTPException + + raise HTTPException( + status_code=401, + detail={ + "error": ( + f"MCPJWTSigner: incoming token verification failed: {exc}" + ) + }, + ) + elif not raw_token and self.access_token_discovery_uri: + verbose_proxy_logger.debug( + "MCPJWTSigner: access_token_discovery_uri configured but no Bearer " + "token found in request (API-key auth request — skipping verification)." + ) + + # Fall back to LiteLLM-decoded JWT claims (available when proxy uses JWT auth). 
+ if jwt_claims is None: + jwt_claims = getattr(user_api_key_dict, "jwt_claims", None) + + # ------------------------------------------------------------------ + # FR-15: Validate required claims + # ------------------------------------------------------------------ + self._validate_required_claims(jwt_claims) + + # ------------------------------------------------------------------ + # Build outbound access token + # ------------------------------------------------------------------ + claims = self._build_claims(user_api_key_dict, data, jwt_claims) + + signed_token = jwt.encode( + claims, + self._private_key, + algorithm=self.ALGORITHM, + headers={"kid": self._kid}, + ) + + # Merge into existing extra_headers — a prior guardrail in the chain may + # have already injected tracing headers or correlation IDs. + existing_headers: Dict[str, str] = data.get("extra_headers") or {} + new_headers: Dict[str, str] = { + **existing_headers, + "Authorization": f"Bearer {signed_token}", + } + + # ------------------------------------------------------------------ + # FR-14: Two-token model — channel token + # ------------------------------------------------------------------ + if self.channel_token_audience: + channel_claims = self._build_channel_token_claims(claims) + channel_token = jwt.encode( + channel_claims, + self._private_key, + algorithm=self.ALGORITHM, + headers={"kid": self._kid}, + ) + new_headers["x-mcp-channel-token"] = f"Bearer {channel_token}" + + # ------------------------------------------------------------------ + # FR-9: Debug header + # ------------------------------------------------------------------ + if self.debug_headers: + new_headers["x-litellm-mcp-debug"] = self._build_debug_header( + claims, self._kid + ) + + data["extra_headers"] = new_headers + + verbose_proxy_logger.debug( + "MCPJWTSigner: signed JWT sub=%s act=%s tool=%s exp=%d " + "verified=%s channel=%s", + claims.get("sub"), + claims.get("act", {}).get("sub"), + data.get("mcp_tool_name"), + 
claims["exp"], + jwt_claims is not None, + bool(self.channel_token_audience), + ) + + return data diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 01a0f55aac7..b2ec8899db9 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -454,8 +454,6 @@ def _add_proxy_hooks(self, llm_router: Optional[Router] = None): for hook in PROXY_HOOKS: proxy_hook = get_proxy_hook(hook) - import inspect - expected_args = inspect.getfullargspec(proxy_hook).args passed_in_args: Dict[str, Any] = {} if "internal_usage_cache" in expected_args: @@ -559,6 +557,10 @@ def _convert_mcp_to_llm_format(self, request_obj, kwargs: dict) -> dict: "user_api_key_request_route": kwargs.get("user_api_key_request_route"), "mcp_tool_name": request_obj.tool_name, # Keep original for reference "mcp_arguments": request_obj.arguments, # Keep original for reference + # Raw Bearer token from the original HTTP request — allows guardrails + # (e.g. MCPJWTSigner) to independently verify the caller's identity + # before re-signing an outbound token (FR-5 verify+re-sign). + "incoming_bearer_token": kwargs.get("incoming_bearer_token"), } return synthetic_data @@ -824,17 +826,27 @@ def _convert_mcp_hook_response_to_kwargs( ) -> dict: """ Helper function to convert pre_call_hook response back to kwargs for MCP usage. + + Supports: + - modified_arguments: Override tool call arguments + - extra_headers: Inject custom headers into the outbound MCP request """ if not response_data: return original_kwargs - # Apply any argument modifications from the hook response modified_kwargs = original_kwargs.copy() - # If the response contains modified arguments, apply them if response_data.get("modified_arguments"): modified_kwargs["arguments"] = response_data["modified_arguments"] + if response_data.get("extra_headers"): + # Merge rather than replace — a prior guardrail in the chain may have + # already injected headers (e.g. tracing IDs). 
Later guardrails win on + # key collisions so that the most-specific guardrail (e.g. JWT signer) + # takes precedence over earlier ones. + existing = modified_kwargs.get("extra_headers") or {} + modified_kwargs["extra_headers"] = {**existing, **response_data["extra_headers"]} + return modified_kwargs async def process_pre_call_hook_response(self, response, data, call_type): diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 27fa27e6da3..f798f05380d 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -79,6 +79,7 @@ class SupportedGuardrailIntegrations(Enum): SEMANTIC_GUARD = "semantic_guard" MCP_END_USER_PERMISSION = "mcp_end_user_permission" BLOCK_CODE_EXECUTION = "block_code_execution" + MCP_JWT_SIGNER = "mcp_jwt_signer" class Role(Enum): diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py new file mode 100644 index 00000000000..32f3a340855 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py @@ -0,0 +1,707 @@ +""" +Tests for pre_mcp_call guardrail hook header mutation support. + +Validates that: +1. _convert_mcp_hook_response_to_kwargs extracts extra_headers from hook response +2. pre_call_tool_check returns hook-provided extra_headers AND modified arguments +3. call_tool flows hook headers and modified arguments downstream +4. Hook-provided headers take highest priority (merge after static_headers) +5. OpenAPI-backed servers log a warning and continue (skip injection) when hook headers are present +6. JWT claims are propagated in both standard and virtual-key fast paths +7. 
Backward compatibility: hooks without extra_headers continue to work +""" + +import asyncio +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.utils import ProxyLogging +from litellm.types.mcp import MCPAuth, MCPTransport +from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +class TestConvertMcpHookResponseToKwargs: + """Tests for ProxyLogging._convert_mcp_hook_response_to_kwargs""" + + def setup_method(self): + self.proxy_logging = ProxyLogging(user_api_key_cache=MagicMock()) + + def test_returns_original_kwargs_when_response_is_none(self): + original = {"arguments": {"key": "val"}, "name": "tool"} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + None, original + ) + assert result == original + + def test_returns_original_kwargs_when_response_is_empty_dict(self): + original = {"arguments": {"key": "val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs({}, original) + assert result == original + + def test_extracts_modified_arguments(self): + original = {"arguments": {"old": "value"}} + response = {"modified_arguments": {"new": "value"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + + def test_extracts_extra_headers(self): + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {"Authorization": "Bearer signed-jwt"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["extra_headers"] == {"Authorization": "Bearer signed-jwt"} + + def test_extracts_both_arguments_and_headers(self): + original = {"arguments": {"old": "value"}} + response = { + "modified_arguments": {"new": "value"}, + 
"extra_headers": {"X-Custom": "header-val"}, + } + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + assert result["extra_headers"] == {"X-Custom": "header-val"} + + def test_no_extra_headers_key_preserves_original(self): + """Backward compat: hooks that only return modified_arguments still work.""" + original = {"arguments": {"key": "val"}} + response = {"modified_arguments": {"key": "new_val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + assert result["arguments"] == {"key": "new_val"} + + def test_empty_extra_headers_not_set(self): + """Empty dict for extra_headers is falsy and should not be set.""" + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + + +class TestPreCallToolCheckReturnsHeaders: + """Tests that pre_call_tool_check returns hook-provided headers.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_has_no_headers(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": {"key": "val"}} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}} + ) + + with patch.object(manager, 
"check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_extra_headers_from_hook(self): + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt", "X-Trace-Id": "abc123"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"extra_headers": hook_headers} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["extra_headers"] == hook_headers + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_returns_none(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = 
MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value=None) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_modified_arguments_from_hook(self): + """Modified arguments from the hook must be returned so the caller can use them.""" + manager = MCPServerManager() + server = self._make_server() + + original_args = {"key": "original"} + modified_args = {"key": "modified", "extra": "added"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": modified_args} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments=original_args, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + + @pytest.mark.asyncio + async def test_returns_both_modified_arguments_and_headers(self): + """Hook can modify both arguments and 
inject headers simultaneously.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified"} + hook_headers = {"Authorization": "Bearer jwt"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value={"dummy": True}) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "original"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + assert result["extra_headers"] == hook_headers + + +class TestCallToolFlowsHookHeaders: + """Tests that call_tool passes hook_extra_headers to _call_regular_mcp_tool.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_hook_headers_passed_to_call_regular_mcp_tool(self): + """Verify that hook_extra_headers kwarg is forwarded.""" + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + 
return_value={"extra_headers": hook_headers}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") == hook_headers + + @pytest.mark.asyncio + async def test_no_hook_headers_when_no_proxy_logging(self): + """Without proxy_logging_obj, no pre_call_tool_check runs.""" + manager = MCPServerManager() + server = self._make_server() + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=None, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") is None + + @pytest.mark.asyncio + async def test_modified_arguments_passed_to_downstream(self): + """Hook-modified arguments must be used for the actual tool call.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified_by_hook"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"arguments": modified_args}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + 
manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "original"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("arguments") == modified_args + + @pytest.mark.asyncio + async def test_openapi_server_warns_and_continues_on_hook_headers(self): + """OpenAPI-backed servers log a warning and continue when hook injects headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"extra_headers": {"Authorization": "Bearer jwt"}}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_openapi_tool_handler", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + import litellm.proxy._experimental.mcp_server.mcp_server_manager as mgr_mod + + proxy_logging = MagicMock(spec=ProxyLogging) + + with patch.object(mgr_mod, "verbose_logger") as mock_logger: + # Should NOT raise — just warn and proceed + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + mock_logger.warning.assert_called_once() + assert "header injection is not supported" in mock_logger.warning.call_args[0][0] + + @pytest.mark.asyncio + async def test_openapi_server_no_error_without_hook_headers(self): + """No exception when OpenAPI 
server has no hook-injected headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_openapi_tool_handler", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + + +class TestHookHeaderMergePriority: + """Tests that hook-provided headers have highest priority in _call_regular_mcp_tool.""" + + def _make_server( + self, + static_headers: Optional[Dict[str, str]] = None, + extra_headers_config: Optional[list] = None, + ): + return MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + static_headers=static_headers, + extra_headers=extra_headers_config, + ) + + @pytest.mark.asyncio + async def test_hook_headers_override_static_headers(self): + """Hook headers should take precedence over static_headers.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"Authorization": "Bearer static-token", "X-Static": "yes"} + ) + + hook_headers = {"Authorization": "Bearer hook-signed-jwt"} + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = 
extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=hook_headers, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-signed-jwt" + assert headers["X-Static"] == "yes" + + @pytest.mark.asyncio + async def test_no_hook_headers_preserves_existing_behavior(self): + """When hook_extra_headers is None, existing header logic is unchanged.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"X-Static": "static-value"} + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=None, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers == {"X-Static": "static-value"} + + @pytest.mark.asyncio + 
async def test_hook_headers_merge_with_oauth2(self): + """Hook headers merge on top of OAuth2 headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers={ + "Authorization": "Bearer oauth2-token", + "X-OAuth": "yes", + }, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers={ + "Authorization": "Bearer hook-jwt", + "X-Trace-Id": "trace-123", + }, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-jwt" + assert headers["X-OAuth"] == "yes" + assert headers["X-Trace-Id"] == "trace-123" + + +class TestUserAPIKeyAuthJwtClaims: + """Tests that UserAPIKeyAuth correctly carries jwt_claims.""" + + def test_jwt_claims_field_defaults_to_none(self): + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + def test_jwt_claims_field_accepts_dict(self): + claims = {"sub": "user-123", "iss": "litellm", "exp": 9999999999} + auth = UserAPIKeyAuth(api_key="test-key", jwt_claims=claims) + assert auth.jwt_claims == claims + assert auth.jwt_claims["sub"] == "user-123" + + def 
test_jwt_claims_backward_compatible_without_field(self): + """Existing code that doesn't pass jwt_claims should still work.""" + auth = UserAPIKeyAuth( + api_key="test-key", + user_id="user-1", + team_id="team-1", + ) + assert auth.jwt_claims is None + assert auth.user_id == "user-1" + + def test_jwt_claims_set_after_construction(self): + """Virtual-key fast path sets jwt_claims after the object is created.""" + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + claims = {"sub": "user-456", "iss": "okta", "groups": ["admin"]} + auth.jwt_claims = claims + assert auth.jwt_claims == claims + assert auth.jwt_claims["groups"] == ["admin"] diff --git a/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py b/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py new file mode 100644 index 00000000000..247fe7b5764 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py @@ -0,0 +1,1103 @@ +""" +Tests for the MCPJWTSigner built-in guardrail. 
+ +Tests cover: + - RSA key generation and loading + - JWT signing and JWKS format + - Claim building (sub, act, scope) + - Hook fires for call_mcp_tool, skips other call types + - get_mcp_jwt_signer() singleton pattern +""" + +import base64 +import time +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_user_api_key_dict( + user_id: str = "user-123", + team_id: str = "team-abc", + user_email: str = "user@example.com", + end_user_id: Optional[str] = None, +) -> MagicMock: + mock = MagicMock() + mock.user_id = user_id + mock.team_id = team_id + mock.user_email = user_email + mock.end_user_id = end_user_id + mock.org_id = None + mock.token = None + mock.api_key = None + # Explicit None so MagicMock doesn't auto-create a truthy proxy attribute + mock.jwt_claims = None + return mock + + +def _decode_unverified(token: str) -> Dict[str, Any]: + return jwt.decode(token, options={"verify_signature": False}) + + +# --------------------------------------------------------------------------- +# Import target (inline so we can reset the singleton between tests) +# --------------------------------------------------------------------------- + + +def _make_signer(**kwargs: Any): + # Reset singleton before each signer creation to avoid cross-test pollution + import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod + + mod._mcp_jwt_signer_instance = None + + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + MCPJWTSigner, + ) + + return MCPJWTSigner( + guardrail_name="test-jwt-signer", + event_hook="pre_mcp_call", + default_on=True, + **kwargs, + ) + 

# ---------------------------------------------------------------------------
# Key generation tests
# ---------------------------------------------------------------------------


def test_auto_generates_rsa_keypair():
    """MCPJWTSigner auto-generates an RSA-2048 keypair when env var is unset."""
    signer = _make_signer()
    assert signer._private_key is not None
    assert signer._public_key is not None
    # kid is derived from the key material; 16 hex chars by construction
    assert signer._kid is not None and len(signer._kid) == 16


def test_kid_is_deterministic():
    """Two signers built from the same key have the same kid."""
    signer1 = _make_signer()
    # Export signer1's (auto-generated) private key as PEM so a second signer
    # can be constructed from the identical key material via the env var.
    private_pem = signer1._private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": private_pem}):
        signer2 = _make_signer()

    assert signer1._kid == signer2._kid


def test_load_key_from_env_var():
    """MCPJWTSigner loads a user-provided RSA key from the env var."""
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": pem}):
        signer = _make_signer()

    assert signer._kid is not None


# ---------------------------------------------------------------------------
# JWKS tests
# ---------------------------------------------------------------------------


def test_get_jwks_format():
    """get_jwks() returns a valid JWKS dict with RSA fields."""
    signer = _make_signer()
    jwks = signer.get_jwks()

    assert "keys" in jwks
    assert len(jwks["keys"]) == 1
    key = jwks["keys"][0]

    assert key["kty"] == "RSA"
    assert key["alg"] == "RS256"
    assert key["use"] == "sig"
    assert key["kid"] == signer._kid
    assert "n" in key and len(key["n"]) > 0
    assert "e" in key and key["e"] == "AQAB"  # 65537 in base64url


def test_jwks_public_key_can_verify_signed_jwt():
    """A JWT signed by MCPJWTSigner can be verified using the JWKS public key."""
    signer = _make_signer(issuer="https://litellm.example.com", audience="mcp")
    now = int(time.time())
    claims = {"iss": "https://litellm.example.com", "aud": "mcp", "iat": now, "exp": now + 300}

    token = jwt.encode(claims, signer._private_key, algorithm="RS256")

    # Reconstruct public key from JWKS
    # NOTE: appending "==" over-pads the unpadded base64url values; CPython's
    # base64 decoder tolerates surplus padding, so this is safe.
    jwks = signer.get_jwks()
    key_data = jwks["keys"][0]
    n = int.from_bytes(base64.urlsafe_b64decode(key_data["n"] + "=="), byteorder="big")
    e = int.from_bytes(base64.urlsafe_b64decode(key_data["e"] + "=="), byteorder="big")
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
    pub_key = RSAPublicNumbers(e=e, n=n).public_key()

    decoded = jwt.decode(
        token,
        pub_key,
        algorithms=["RS256"],
        audience="mcp",
        issuer="https://litellm.example.com",
    )
    assert decoded["iss"] == "https://litellm.example.com"


# ---------------------------------------------------------------------------
# Claim building tests
# ---------------------------------------------------------------------------


def test_build_claims_standard_fields():
    """_build_claims() populates iss, aud, iat, exp, nbf."""
    signer = _make_signer(issuer="https://litellm.example.com", audience="mcp", ttl_seconds=300)
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "get_weather"}

    claims = signer._build_claims(user_dict, data)

    assert claims["iss"] == "https://litellm.example.com"
    assert claims["aud"] == "mcp"
    assert "iat" in claims
    assert "exp" in claims
    # exp - iat must equal the configured TTL exactly
    assert claims["exp"] - claims["iat"] == 300
    assert "nbf" in claims


def test_build_claims_identity():
    """_build_claims() sets sub from user_id and act from team_id (RFC 8693)."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict(user_id="user-xyz", team_id="team-eng")
    data: Dict[str, Any] = {}

    claims = signer._build_claims(user_dict, data)

    assert claims["sub"] == "user-xyz"
    assert claims["act"]["sub"] == "team-eng"
    assert claims["email"] == "user@example.com"


def test_build_claims_scope_with_tool():
    """_build_claims() encodes tool-specific scope when mcp_tool_name is set."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "search_web"}

    claims = signer._build_claims(user_dict, data)

    scopes = set(claims["scope"].split())
    assert "mcp:tools/call" in scopes
    assert "mcp:tools/search_web:call" in scopes
    # Tool-call JWTs must NOT carry mcp:tools/list — least-privilege
    assert "mcp:tools/list" not in scopes


def test_build_claims_scope_without_tool():
    """_build_claims() includes mcp:tools/list when no specific tool is called."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data: Dict[str, Any] = {}

    claims = signer._build_claims(user_dict, data)

    scopes = set(claims["scope"].split())
    assert "mcp:tools/call" in scopes
    assert "mcp:tools/list" in scopes
    # No per-tool call scope when no tool name was given
    assert not any(s.endswith(":call") and s != "mcp:tools/call" for s in scopes)


def test_build_claims_act_fallback_to_litellm_proxy():
    """_build_claims() falls back to 'litellm-proxy' when team_id and org_id are absent."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    user_dict.team_id = None
    user_dict.org_id = None

    claims = signer._build_claims(user_dict, {})

    assert claims["act"]["sub"] == "litellm-proxy"


def test_build_claims_sub_fallback_to_token_hash():
    """_build_claims() sets sub to an apikey: hash when user_id is absent."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict(user_id="")
    user_dict.user_id = None
    user_dict.token = "sk-test-api-key-abc123"

    claims = signer._build_claims(user_dict, {})

    assert claims["sub"].startswith("apikey:")
    assert len(claims["sub"]) == len("apikey:") + 16  # sha256 hex[:16]


def test_build_claims_sub_fallback_to_litellm_proxy_when_no_token():
    """_build_claims() falls back to 'litellm-proxy' when user_id and token are both absent."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict(user_id="")
    user_dict.user_id = None
    user_dict.token = None
    user_dict.api_key = None

    claims = signer._build_claims(user_dict, {})

    assert claims["sub"] == "litellm-proxy"


def test_init_raises_on_zero_ttl():
    """MCPJWTSigner raises ValueError when ttl_seconds is 0."""
    with pytest.raises(ValueError, match="ttl_seconds must be > 0"):
        _make_signer(ttl_seconds=0)


def test_init_raises_on_negative_ttl():
    """MCPJWTSigner raises ValueError when ttl_seconds is negative."""
    with pytest.raises(ValueError, match="ttl_seconds must be > 0"):
        _make_signer(ttl_seconds=-60)


def test_jwks_max_age_persistent_key():
    """jwks_max_age is 3600 when key loaded from env var."""
    # Local imports aliased to avoid shadowing the module-level names
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa as crsa

    private_key = crsa.generate_private_key(public_exponent=65537, key_size=2048)
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": pem}):
        signer = _make_signer()

    assert signer.jwks_max_age == 3600


def test_jwks_max_age_auto_generated_key():
    """jwks_max_age is 300 for auto-generated (ephemeral) keys."""
    signer = _make_signer()
    assert signer.jwks_max_age == 300


# ---------------------------------------------------------------------------
# Hook dispatch tests
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_hook_fires_for_call_mcp_tool():
    """async_pre_call_hook() injects Authorization header for call_mcp_tool."""
    signer = _make_signer(issuer="https://litellm.example.com", audience="mcp")
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "do_thing"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    assert "extra_headers" in result
    assert result["extra_headers"]["Authorization"].startswith("Bearer ")


@pytest.mark.asyncio
async def test_hook_skips_non_mcp_call_types():
    """async_pre_call_hook() leaves data unchanged for non-MCP call types."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data = {"messages": [{"role": "user", "content": "hello"}]}

    # list_mcp_tools is deliberately included: only call_mcp_tool gets a JWT
    for call_type in ("completion", "acompletion", "embedding", "list_mcp_tools"):
        original_data = {**data}
        result = await signer.async_pre_call_hook(
            user_api_key_dict=user_dict,
            cache=MagicMock(),
            data=original_data,
            call_type=call_type,  # type: ignore[arg-type]
        )
        assert "extra_headers" not in (result or {}), f"extra_headers should not be set for {call_type}"


@pytest.mark.asyncio
async def test_signed_token_is_verifiable():
    """The JWT injected by the hook can be verified against the JWKS public key."""
    signer = _make_signer(issuer="https://litellm.example.com", audience="mcp", ttl_seconds=300)
    user_dict = _make_user_api_key_dict(user_id="alice", team_id="backend")
    data = {"mcp_tool_name": "search"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    token = result["extra_headers"]["Authorization"].removeprefix("Bearer ")

    decoded = _decode_unverified(token)
    assert decoded["sub"] == "alice"
    assert decoded["act"]["sub"] == "backend"
    assert "mcp:tools/search:call" in decoded["scope"]
    assert decoded["iss"] == "https://litellm.example.com"
    assert decoded["aud"] == "mcp"


# ---------------------------------------------------------------------------
# Singleton tests
# ---------------------------------------------------------------------------


def test_get_mcp_jwt_signer_returns_none_before_init():
    """get_mcp_jwt_signer() returns None before any MCPJWTSigner is created."""
    import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod

    # Reset the module-level singleton so this test is order-independent
    mod._mcp_jwt_signer_instance = None

    from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
        get_mcp_jwt_signer,
    )

    assert get_mcp_jwt_signer() is None


def test_get_mcp_jwt_signer_returns_instance_after_init():
    """get_mcp_jwt_signer() returns the initialized signer instance."""
    from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
        get_mcp_jwt_signer,
    )

    signer = _make_signer()
    assert get_mcp_jwt_signer() is signer


# ---------------------------------------------------------------------------
# FR-10: Configurable scopes
# ---------------------------------------------------------------------------


def test_allowed_scopes_replaces_auto_generation():
    """When allowed_scopes is set it is used verbatim instead of auto-generating."""
    signer = _make_signer(allowed_scopes=["mcp:admin", "mcp:tools/call"])
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "some_tool"}

    claims = signer._build_claims(user_dict, data)

    assert claims["scope"] == "mcp:admin mcp:tools/call"


def test_tool_call_scope_no_list_permission():
    """Tool-call JWTs must NOT carry mcp:tools/list (least-privilege)."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "my_tool"}

    claims = signer._build_claims(user_dict, data)

    scopes = set(claims["scope"].split())
    assert "mcp:tools/list" not in scopes
    assert "mcp:tools/call" in scopes
    assert "mcp:tools/my_tool:call" in scopes


# ---------------------------------------------------------------------------
# FR-12: End-user identity mapping
# ---------------------------------------------------------------------------


def test_end_user_claim_sources_token_sub():
    """end_user_claim_sources resolves sub from incoming JWT claims."""
    signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"])
    user_dict = _make_user_api_key_dict(user_id="litellm-user")
    jwt_claims = {"sub": "idp-user-123", "email": "idp@example.com"}

    claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims)

    # token:sub is listed first, so the IdP identity wins over litellm:user_id
    assert claims["sub"] == "idp-user-123"


def test_end_user_claim_sources_falls_back_to_litellm_user_id():
    """Falls back to litellm:user_id when token:sub is absent."""
    signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"])
    user_dict = _make_user_api_key_dict(user_id="litellm-user")
    jwt_claims: Dict[str, Any] = {}  # no sub

    claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims)

    assert claims["sub"] == "litellm-user"


def test_end_user_claim_sources_email_source():
    """token:email resolves correctly."""
    signer = _make_signer(end_user_claim_sources=["token:email"])
    user_dict = _make_user_api_key_dict(user_id="")
    user_dict.user_id = None
    jwt_claims = {"email": "alice@corp.com"}

    claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims)

    assert claims["sub"] == "alice@corp.com"


def test_end_user_claim_sources_litellm_email():
    """litellm:email resolves from UserAPIKeyAuth.user_email."""
    signer = _make_signer(end_user_claim_sources=["litellm:email"])
    user_dict = _make_user_api_key_dict(user_email="proxy-user@example.com")
    user_dict.user_id = None

    claims = signer._build_claims(user_dict, {})

    assert claims["sub"] == "proxy-user@example.com"


# ---------------------------------------------------------------------------
--------------------------------------------------------------------------- +# FR-13: Claim operations +# --------------------------------------------------------------------------- + + +def test_add_claims_inserts_when_absent(): + """add_claims inserts key when it is not already in the JWT.""" + signer = _make_signer(add_claims={"deployment_id": "prod-001"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert claims["deployment_id"] == "prod-001" + + +def test_add_claims_does_not_overwrite_existing(): + """add_claims does NOT overwrite an existing claim (use set_claims for that).""" + signer = _make_signer(add_claims={"iss": "should-not-win"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + # iss should be the configured issuer, not overwritten + assert claims["iss"] != "should-not-win" + + +def test_set_claims_always_overrides(): + """set_claims always overrides computed claims.""" + signer = _make_signer(set_claims={"iss": "override-issuer", "custom": "x"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert claims["iss"] == "override-issuer" + assert claims["custom"] == "x" + + +def test_remove_claims_deletes_keys(): + """remove_claims deletes specified keys from the final JWT.""" + signer = _make_signer(remove_claims=["nbf", "email"]) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert "nbf" not in claims + assert "email" not in claims + + +def test_claim_operations_order_add_then_set_then_remove(): + """add → set → remove is applied in order: set wins over add, remove beats both.""" + signer = _make_signer( + add_claims={"x": "from-add"}, + set_claims={"x": "from-set"}, + remove_claims=["x"], + ) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert "x" not in claims # remove wins + + +# 
# ---------------------------------------------------------------------------
# FR-14: Two-token model
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_channel_token_injected_when_configured():
    """When channel_token_audience is set, x-mcp-channel-token header is injected."""
    signer = _make_signer(
        channel_token_audience="bedrock-gateway",
        channel_token_ttl=60,
    )
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "list_tables"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    assert "x-mcp-channel-token" in result["extra_headers"]
    channel_token = result["extra_headers"]["x-mcp-channel-token"].removeprefix("Bearer ")
    channel_payload = _decode_unverified(channel_token)
    assert channel_payload["aud"] == "bedrock-gateway"


@pytest.mark.asyncio
async def test_channel_token_absent_when_not_configured():
    """x-mcp-channel-token is not injected when channel_token_audience is unset."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "tool"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    assert "x-mcp-channel-token" not in result["extra_headers"]


# ---------------------------------------------------------------------------
# FR-15: Incoming claim validation
# ---------------------------------------------------------------------------


def test_required_claims_pass_when_present():
    """_validate_required_claims() passes when all required claims are present."""
    signer = _make_signer(required_claims=["sub", "email"])
    # Should not raise
    signer._validate_required_claims({"sub": "user", "email": "u@example.com"})


def test_required_claims_raise_403_when_missing():
    """_validate_required_claims() raises HTTP 403 when a required claim is missing."""
    from fastapi import HTTPException

    signer = _make_signer(required_claims=["sub", "email"])
    with pytest.raises(HTTPException) as exc_info:
        signer._validate_required_claims({"sub": "user"})  # email missing

    assert exc_info.value.status_code == 403
    # The error detail should name the missing claim so callers can fix config.
    assert "email" in str(exc_info.value.detail)


def test_required_claims_raise_when_no_jwt_claims():
    """_validate_required_claims() raises when jwt_claims is None and claims are required."""
    from fastapi import HTTPException

    signer = _make_signer(required_claims=["sub"])
    with pytest.raises(HTTPException):
        signer._validate_required_claims(None)


def test_optional_claims_passed_through():
    """optional_claims are forwarded from incoming jwt_claims into the outbound JWT."""
    signer = _make_signer(optional_claims=["groups", "roles"])
    user_dict = _make_user_api_key_dict()
    jwt_claims = {"sub": "u", "groups": ["admin"], "roles": ["editor"]}

    claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims)

    assert claims["groups"] == ["admin"]
    assert claims["roles"] == ["editor"]


def test_optional_claims_not_injected_if_absent():
    """optional_claims are silently skipped when absent in incoming jwt_claims."""
    signer = _make_signer(optional_claims=["groups"])
    user_dict = _make_user_api_key_dict()
    jwt_claims: Dict[str, Any] = {"sub": "u"}  # no groups

    claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims)

    assert "groups" not in claims


# ---------------------------------------------------------------------------
# FR-9: Debug headers
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_debug_header_injected_when_enabled():
    """x-litellm-mcp-debug header is injected when debug_headers=True."""
    signer = _make_signer(debug_headers=True)
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "my_tool"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    assert "x-litellm-mcp-debug" in result["extra_headers"]
    debug_val = result["extra_headers"]["x-litellm-mcp-debug"]
    assert "v=1" in debug_val
    assert "kid=" in debug_val
    assert "sub=" in debug_val


@pytest.mark.asyncio
async def test_debug_header_absent_when_disabled():
    """x-litellm-mcp-debug is NOT injected when debug_headers=False (default)."""
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    data = {"mcp_tool_name": "tool"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    assert "x-litellm-mcp-debug" not in result["extra_headers"]


# ---------------------------------------------------------------------------
# P1 fix: extra_headers merging (multi-guardrail chains)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_extra_headers_are_merged_not_replaced():
    """
    Existing extra_headers from a prior guardrail are preserved — only
    Authorization is added/overwritten, other keys survive.
    """
    signer = _make_signer()
    user_dict = _make_user_api_key_dict()
    # Simulate a prior guardrail having injected a tracing header
    data = {
        "mcp_tool_name": "list",
        "extra_headers": {"x-trace-id": "abc123", "x-correlation-id": "xyz"},
    }

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    headers = result["extra_headers"]
    # Prior headers preserved
    assert headers.get("x-trace-id") == "abc123"
    assert headers.get("x-correlation-id") == "xyz"
    # Authorization injected
    assert "Authorization" in headers


# ---------------------------------------------------------------------------
# FR-5: Verify + re-sign — jwt_claims fallback from UserAPIKeyAuth
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_sub_resolved_from_user_api_key_dict_jwt_claims():
    """
    When no raw token is present but UserAPIKeyAuth.jwt_claims has a sub,
    the guardrail resolves sub from jwt_claims (LiteLLM-decoded JWT path).
    """
    signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"])
    user_dict = _make_user_api_key_dict(user_id="litellm-fallback")
    # jwt_claims populated by LiteLLM's JWT auth machinery
    user_dict.jwt_claims = {"sub": "idp-alice", "email": "alice@idp.com"}
    data = {"mcp_tool_name": "query"}

    result = await signer.async_pre_call_hook(
        user_api_key_dict=user_dict,
        cache=MagicMock(),
        data=data,
        call_type="call_mcp_tool",
    )

    assert isinstance(result, dict)
    token = result["extra_headers"]["Authorization"].removeprefix("Bearer ")
    payload = _decode_unverified(token)
    assert payload["sub"] == "idp-alice"


# ---------------------------------------------------------------------------
# initialize_guardrail factory — regression test for config.yaml wire-up
# ---------------------------------------------------------------------------


def test_initialize_guardrail_passes_all_params():
    """
    initialize_guardrail must wire every documented config.yaml param through
    to MCPJWTSigner. Previously only issuer/audience/ttl_seconds were passed;
    all FR-5/9/10/12/13/14/15 params were silently dropped.
    """
    import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod

    # Reset the module-level singleton so this test always builds a fresh signer.
    mod._mcp_jwt_signer_instance = None

    from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer import (
        initialize_guardrail,
    )

    litellm_params = MagicMock()
    litellm_params.mode = "pre_mcp_call"
    litellm_params.default_on = True
    litellm_params.optional_params = None
    # Set every non-default param directly on litellm_params
    litellm_params.issuer = "https://litellm.example.com"
    litellm_params.audience = "mcp-test"
    litellm_params.ttl_seconds = 120
    litellm_params.access_token_discovery_uri = "https://idp.example.com/.well-known/openid-configuration"
    litellm_params.token_introspection_endpoint = "https://idp.example.com/introspect"
    litellm_params.verify_issuer = "https://idp.example.com"
    litellm_params.verify_audience = "api://test"
    litellm_params.end_user_claim_sources = ["token:email", "litellm:user_id"]
    litellm_params.add_claims = {"deployment_id": "prod"}
    litellm_params.set_claims = {"env": "production"}
    litellm_params.remove_claims = ["nbf"]
    litellm_params.channel_token_audience = "bedrock-gateway"
    litellm_params.channel_token_ttl = 60
    litellm_params.required_claims = ["sub", "email"]
    litellm_params.optional_claims = ["groups"]
    litellm_params.debug_headers = True
    litellm_params.allowed_scopes = ["mcp:tools/call"]

    guardrail = {"guardrail_name": "mcp-jwt-signer"}

    with patch("litellm.logging_callback_manager.add_litellm_callback"):
        signer = initialize_guardrail(litellm_params, guardrail)

    assert signer.issuer == "https://litellm.example.com"
    assert signer.audience == "mcp-test"
    assert signer.ttl_seconds == 120
    assert signer.access_token_discovery_uri == "https://idp.example.com/.well-known/openid-configuration"
    assert signer.token_introspection_endpoint == "https://idp.example.com/introspect"
    assert signer.verify_issuer == "https://idp.example.com"
    assert signer.verify_audience == "api://test"
    assert signer.end_user_claim_sources == ["token:email", "litellm:user_id"]
    assert signer.add_claims == {"deployment_id": "prod"}
    assert signer.set_claims == {"env": "production"}
    assert signer.remove_claims == ["nbf"]
    assert signer.channel_token_audience == "bedrock-gateway"
    assert signer.channel_token_ttl == 60
    assert signer.required_claims == ["sub", "email"]
    assert signer.optional_claims == ["groups"]
    assert signer.debug_headers is True
    assert signer.allowed_scopes == ["mcp:tools/call"]


# ---------------------------------------------------------------------------
# FR-5: _fetch_jwks, _get_oidc_discovery, _verify_incoming_jwt,
# _introspect_opaque_token
# ---------------------------------------------------------------------------

import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as _signer_mod


def _make_httpx_response(json_body: dict, status_code: int = 200):
    """Build a minimal fake httpx Response object.

    raise_for_status() becomes a no-op for 2xx/3xx codes and raises
    HTTPStatusError for 4xx/5xx, mirroring real httpx behavior.
    """
    mock_resp = MagicMock()
    mock_resp.status_code = status_code
    mock_resp.json.return_value = json_body
    mock_resp.raise_for_status = MagicMock()
    if status_code >= 400:
        # Only HTTPStatusError itself is needed; the request/response
        # arguments are mocks, so don't import httpx.Request/Response.
        from httpx import HTTPStatusError

        mock_resp.raise_for_status.side_effect = HTTPStatusError(
            "error", request=MagicMock(), response=MagicMock()
        )
    return mock_resp


# --- _fetch_jwks ---


@pytest.mark.asyncio
async def test_fetch_jwks_returns_keys_and_caches():
    """_fetch_jwks returns keys from the remote JWKS URI and caches the result."""
    _signer_mod._jwks_cache.clear()
    try:
        fake_keys = [{"kty": "RSA", "kid": "k1", "n": "abc", "e": "AQAB"}]
        fake_resp = _make_httpx_response({"keys": fake_keys})

        mock_client = MagicMock()
        mock_client.get = AsyncMock(return_value=fake_resp)

        with patch(
            "litellm.llms.custom_httpx.http_handler.get_async_httpx_client",
            return_value=mock_client,
        ):
            keys = await _signer_mod._fetch_jwks("https://idp.example.com/jwks")

        assert keys == fake_keys
        assert "https://idp.example.com/jwks" in _signer_mod._jwks_cache
    finally:
        # Always clear the module-level cache — even when an assertion fails —
        # so state never leaks into other tests.
        _signer_mod._jwks_cache.clear()


@pytest.mark.asyncio
async def test_fetch_jwks_uses_cache_on_second_call():
    """_fetch_jwks returns the cached value without a second HTTP call."""
    _signer_mod._jwks_cache.clear()
    try:
        fake_keys = [{"kty": "RSA", "kid": "k1"}]
        _signer_mod._jwks_cache["https://idp.example.com/jwks"] = (
            fake_keys,
            time.time(),
        )

        mock_client = MagicMock()
        mock_client.get = AsyncMock()

        with patch(
            "litellm.llms.custom_httpx.http_handler.get_async_httpx_client",
            return_value=mock_client,
        ):
            keys = await _signer_mod._fetch_jwks("https://idp.example.com/jwks")

        mock_client.get.assert_not_called()
        assert keys == fake_keys
    finally:
        # Always clear the module-level cache — even when an assertion fails.
        _signer_mod._jwks_cache.clear()


# --- _get_oidc_discovery ---


@pytest.mark.asyncio
async def test_get_oidc_discovery_caches_when_jwks_uri_present():
    """_get_oidc_discovery caches the doc when jwks_uri is in the response."""
    signer = _make_signer(
        access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration"
    )
    signer._oidc_discovery_doc = None  # ensure fresh

    discovery_doc = {
        "issuer": "https://idp.example.com",
        "jwks_uri": "https://idp.example.com/jwks",
    }

    with patch(
        "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_oidc_discovery",
        new_callable=AsyncMock,
        return_value=discovery_doc,
    ):
        result = await signer._get_oidc_discovery()

    assert result["jwks_uri"] == "https://idp.example.com/jwks"
    assert signer._oidc_discovery_doc == discovery_doc


@pytest.mark.asyncio
async def test_get_oidc_discovery_does_not_cache_when_jwks_uri_absent():
    """_get_oidc_discovery does NOT cache a doc that is missing jwks_uri."""
    signer = _make_signer(
        access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration"
    )
    signer._oidc_discovery_doc = None

    bad_doc = {"issuer": "https://idp.example.com"}  # no jwks_uri

    with patch(
        "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_oidc_discovery",
        new_callable=AsyncMock,
        return_value=bad_doc,
    ) as mock_fetch:
        result1 = await signer._get_oidc_discovery()
        result2 = await signer._get_oidc_discovery()

    # Returns the bad doc each time without caching it
    assert "jwks_uri" not in result1
    assert signer._oidc_discovery_doc is None  # never cached
    assert mock_fetch.call_count == 2  # retried on second call


# --- _verify_incoming_jwt ---


@pytest.mark.asyncio
async def test_verify_incoming_jwt_returns_payload_on_valid_token():
    """_verify_incoming_jwt decodes and returns claims from a valid JWT."""
    # Build a signer to get a real RSA key pair; use its key to mint the "incoming" JWT
    signer = _make_signer(
        access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration",
        verify_audience="api://test",
        verify_issuer="https://idp.example.com",
    )
    # Mint a JWT with signer's own key — we'll pretend it came from the IdP
    now = int(time.time())
    incoming_claims = {
        "sub": "idp-user-42",
        "iss": "https://idp.example.com",
        "aud": "api://test",
        "iat": now,
        "exp": now + 300,
    }
    incoming_token = jwt.encode(
        incoming_claims,
        signer._private_key,
        algorithm="RS256",
        headers={"kid": signer._kid},
    )

    # Build a JWKS from the same public key so verification passes
    jwks = signer.get_jwks()

    with patch.object(
        signer,
        "_get_oidc_discovery",
        new_callable=AsyncMock,
        return_value={"jwks_uri": "https://idp.example.com/jwks"},
    ):
        with patch(
            "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_jwks",
            new_callable=AsyncMock,
            return_value=jwks["keys"],
        ):
            payload = await signer._verify_incoming_jwt(incoming_token)

    assert payload["sub"] == "idp-user-42"


@pytest.mark.asyncio
async def test_verify_incoming_jwt_raises_on_expired_token():
    """_verify_incoming_jwt raises PyJWTError on an expired token."""
    signer = _make_signer(
        access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration",
    )
    expired_claims = {
        "sub": "idp-user",
        "iss": "https://idp.example.com",
        "aud": "api://test",
        "iat": int(time.time()) - 600,
        "exp": int(time.time()) - 300,  # expired
    }
    # Include the kid header (as the valid-token test does) so the failure is
    # guaranteed to come from expiry, not from kid-based key lookup.
    expired_token = jwt.encode(
        expired_claims,
        signer._private_key,
        algorithm="RS256",
        headers={"kid": signer._kid},
    )
    jwks = signer.get_jwks()

    with patch.object(
        signer,
        "_get_oidc_discovery",
        new_callable=AsyncMock,
        return_value={"jwks_uri": "https://idp.example.com/jwks"},
    ):
        with patch(
            "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_jwks",
            new_callable=AsyncMock,
            return_value=jwks["keys"],
        ):
            with pytest.raises(jwt.PyJWTError):
                await signer._verify_incoming_jwt(expired_token)


# --- _introspect_opaque_token ---


@pytest.mark.asyncio
async def test_introspect_opaque_token_returns_claims_when_active():
    """_introspect_opaque_token returns the introspection payload for active tokens."""
    signer = _make_signer(
        token_introspection_endpoint="https://idp.example.com/introspect"
    )

    introspection_response = {
        "active": True,
        "sub": "service-account",
        "scope": "read write",
    }
    fake_resp = _make_httpx_response(introspection_response)
    mock_client = MagicMock()
    mock_client.post = AsyncMock(return_value=fake_resp)

    with patch(
        "litellm.llms.custom_httpx.http_handler.get_async_httpx_client",
        return_value=mock_client,
    ):
        result = await signer._introspect_opaque_token("opaque-token-abc")

    assert result["sub"] == "service-account"
    assert result["active"] is True


@pytest.mark.asyncio
async def test_introspect_opaque_token_raises_on_inactive_token():
    """_introspect_opaque_token raises ExpiredSignatureError when active=false."""
    signer = _make_signer(
        token_introspection_endpoint="https://idp.example.com/introspect"
    )

    fake_resp = _make_httpx_response({"active": False})
    mock_client = MagicMock()
    mock_client.post = AsyncMock(return_value=fake_resp)

    with patch(
        "litellm.llms.custom_httpx.http_handler.get_async_httpx_client",
        return_value=mock_client,
    ):
        with pytest.raises(jwt.ExpiredSignatureError):
            await signer._introspect_opaque_token("opaque-token-xyz")


@pytest.mark.asyncio
async def test_introspect_opaque_token_raises_without_endpoint_configured():
    """_introspect_opaque_token raises ValueError when no endpoint is set."""
    signer = _make_signer()  # no token_introspection_endpoint

    with pytest.raises(ValueError, match="token_introspection_endpoint"):
        await signer._introspect_opaque_token("some-token")


# --- FR-5 end-to-end hook path ---


@pytest.mark.asyncio
async def test_hook_raises_401_when_jwt_verification_fails():
    """async_pre_call_hook raises HTTP 401 when incoming JWT verification fails."""
    from fastapi import HTTPException

    signer = _make_signer(
        access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration"
    )

    with patch.object(
        signer,
        "_verify_incoming_jwt",
        new_callable=AsyncMock,
        side_effect=jwt.InvalidSignatureError("bad signature"),
    ):
        with patch.object(
            signer,
            "_get_oidc_discovery",
            new_callable=AsyncMock,
            return_value={"jwks_uri": "https://idp.example.com/jwks"},
        ):
            with pytest.raises(HTTPException) as exc_info:
                await signer.async_pre_call_hook(
                    user_api_key_dict=_make_user_api_key_dict(),
                    cache=MagicMock(),
                    data={"mcp_tool_name": "tool", "incoming_bearer_token": "hdr.pld.sig"},
                    call_type="call_mcp_tool",
                )

    assert exc_info.value.status_code == 401