diff --git a/docs/clients/auth/cimd.mdx b/docs/clients/auth/cimd.mdx new file mode 100644 index 0000000000..6980c66f2d --- /dev/null +++ b/docs/clients/auth/cimd.mdx @@ -0,0 +1,137 @@ +--- +title: CIMD Authentication +sidebarTitle: CIMD +description: Use Client ID Metadata Documents for verifiable, domain-based client identity. +icon: id-badge +--- + +import { VersionBadge } from "/snippets/version-badge.mdx" + + + + +CIMD authentication is only relevant for HTTP-based transports and requires a server that advertises CIMD support. + + +With standard OAuth, your client registers dynamically with every server it connects to, receiving a fresh `client_id` each time. This works, but the server has no way to verify *who* your client actually is — any client can claim any name during registration. + +CIMD (Client ID Metadata Documents) flips this around. You host a small JSON document at an HTTPS URL you control, and that URL becomes your `client_id`. When your client connects to a server, the server fetches your metadata document and can verify your identity through your domain ownership. Users see a verified domain badge in the consent screen instead of an unverified client name. + +## Client Usage + +Pass your CIMD document URL to the `client_metadata_url` parameter of `OAuth`: + +```python +from fastmcp import Client +from fastmcp.client.auth import OAuth + +async with Client( + "https://mcp-server.example.com/mcp", + auth=OAuth( + client_metadata_url="https://myapp.example.com/oauth/client.json", + ), +) as client: + await client.ping() +``` + +When the server supports CIMD, the client uses your metadata URL as its `client_id` instead of performing Dynamic Client Registration. The server fetches your document, validates it, and proceeds with the standard OAuth authorization flow. + + +You don't need to pass `mcp_url` when using `OAuth` with `Client(auth=...)` — the transport provides the server URL automatically. + + +## Creating a CIMD Document + +A CIMD document is a JSON file that describes your client. The most important field is `client_id`, which must exactly match the URL where you host the document. + +Use the FastMCP CLI to generate one: + +```bash +fastmcp auth cimd create \ + --name "My Application" \ + --redirect-uri "http://localhost:*/callback" \ + --client-id "https://myapp.example.com/oauth/client.json" +``` + +This produces: + +```json +{ + "client_id": "https://myapp.example.com/oauth/client.json", + "client_name": "My Application", + "redirect_uris": ["http://localhost:*/callback"], + "token_endpoint_auth_method": "none", + "grant_types": ["authorization_code"], + "response_types": ["code"] +} +``` + +If you omit `--client-id`, the CLI generates a placeholder value and reminds you to update it before hosting. + +### CLI Options + +The `create` command accepts these flags: + +| Flag | Description | +|------|-------------| +| `--name` | Human-readable client name (required) | +| `--redirect-uri`, `-r` | Allowed redirect URIs — can be specified multiple times (required) | +| `--client-id` | The URL where you'll host this document (sets `client_id` directly) | +| `--output`, `-o` | Write to a file instead of stdout | +| `--scope` | Space-separated list of scopes the client may request | +| `--client-uri` | URL of the client's home page | +| `--logo-uri` | URL of the client's logo image | +| `--no-pretty` | Output compact JSON | + +### Redirect URIs + +The `redirect_uris` field supports wildcard port matching for localhost. The pattern `http://localhost:*/callback` matches any port, which is useful for development clients that bind to random available ports (which is what FastMCP's `OAuth` helper does by default). + +## Hosting Requirements + +CIMD documents must be hosted at a publicly accessible HTTPS URL with a non-root path: + +- **HTTPS required** — HTTP URLs are rejected for security +- **Non-root path** — The URL must have a path component (e.g., `/oauth/client.json`, not just `/`) +- **Public accessibility** — The server must be able to fetch the document over the internet +- **Matching `client_id`** — The `client_id` field in the document must exactly match the hosting URL + +Common hosting options include static file hosting services like GitHub Pages, Cloudflare Pages, Vercel, or S3 — anywhere you can serve a JSON file over HTTPS. + +## Validating Your Document + +Before deploying, verify your hosted document passes validation: + +```bash +fastmcp auth cimd validate https://myapp.example.com/oauth/client.json +``` + +The validator fetches the document and checks that: +- The URL is valid (HTTPS, non-root path) +- The document is well-formed JSON conforming to the CIMD schema +- The `client_id` in the document matches the URL it was fetched from + +## How It Works + +When your client connects to a CIMD-enabled server, the flow works like this: + + + +Your client sends its `client_metadata_url` as the `client_id` in the OAuth authorization request. + + +The server sees that the `client_id` is an HTTPS URL with a path — the signature of a CIMD client — and skips Dynamic Client Registration. + + +The server fetches your JSON document from the URL, validates that `client_id` matches the URL, and extracts your client metadata (name, redirect URIs, scopes). + + +The standard OAuth flow continues: browser opens for user consent, authorization code exchange, token issuance. The consent screen shows your verified domain. + + + +The server caches your CIMD document according to HTTP cache headers, so subsequent requests don't require re-fetching. + +## Server Configuration + +CIMD is a server-side feature that your MCP server must support. FastMCP's OAuth proxy providers (GitHub, Google, Auth0, etc.) support CIMD by default. See the [OAuth Proxy CIMD documentation](/servers/auth/oauth-proxy#cimd-support) for server-side configuration, including private key JWT authentication and security details. diff --git a/docs/clients/auth/oauth.mdx b/docs/clients/auth/oauth.mdx index 6a60fb0f86..25804adc3f 100644 --- a/docs/clients/auth/oauth.mdx +++ b/docs/clients/auth/oauth.mdx @@ -41,20 +41,25 @@ To fully configure the OAuth flow, use the `OAuth` helper and pass it to the `au from fastmcp import Client from fastmcp.client.auth import OAuth -oauth = OAuth(mcp_url="https://your-server.fastmcp.app/mcp") +oauth = OAuth(scopes=["user"]) async with Client("https://your-server.fastmcp.app/mcp", auth=oauth) as client: await client.ping() ``` + +You don't need to pass `mcp_url` when using `OAuth` with `Client(auth=...)` — the transport provides the server URL automatically. + + #### `OAuth` Parameters -- **`mcp_url`** (`str`): The full URL of the target MCP server endpoint. Used to discover OAuth server metadata - **`scopes`** (`str | list[str]`, optional): OAuth scopes to request. Can be space-separated string or list of strings - **`client_name`** (`str`, optional): Client name for dynamic registration. Defaults to `"FastMCP Client"` +- **`client_metadata_url`** (`str`, optional): URL-based client identity (CIMD). See [CIMD Authentication](/clients/auth/cimd) for details - **`token_storage`** (`AsyncKeyValue`, optional): Storage backend for persisting OAuth tokens. Defaults to in-memory storage (tokens lost on restart). See [Token Storage](#token-storage) for encrypted storage options - **`additional_client_metadata`** (`dict[str, Any]`, optional): Extra metadata for client registration - **`callback_port`** (`int`, optional): Fixed port for OAuth callback server. If not specified, uses a random available port +- **`httpx_client_factory`** (`McpHttpClientFactory`, optional): Factory for creating httpx clients ## OAuth Flow @@ -68,8 +73,8 @@ The client first checks the configured `token_storage` backend for existing, val If no valid tokens exist, the client attempts to discover the OAuth server's endpoints using a well-known URI (e.g., `/.well-known/oauth-authorization-server`) based on the `mcp_url`. - -If the OAuth server supports it and the client isn't already registered (or credentials aren't cached), the client performs dynamic client registration according to RFC 7591. + +If the OAuth server supports it and the client isn't already registered (or credentials aren't cached), the client performs dynamic client registration according to RFC 7591. Alternatively, if a `client_metadata_url` is configured and the server supports CIMD, the client uses its metadata URL as its identity instead of registering. A temporary local HTTP server is started on an available port (or the port specified via `callback_port`). This server's address (e.g., `http://127.0.0.1:/callback`) acts as the `redirect_uri` for the OAuth flow. @@ -115,10 +120,7 @@ encrypted_storage = FernetEncryptionWrapper( fernet=Fernet(os.environ["OAUTH_STORAGE_ENCRYPTION_KEY"]) ) -oauth = OAuth( - mcp_url="https://your-server.fastmcp.app/mcp", - token_storage=encrypted_storage -) +oauth = OAuth(token_storage=encrypted_storage) async with Client("https://your-server.fastmcp.app/mcp", auth=oauth) as client: await client.ping() @@ -129,3 +131,24 @@ You can use any `AsyncKeyValue`-compatible backend from the [key-value library]( When selecting a storage backend, review the [py-key-value documentation](https://github.com/strawgate/py-key-value) to understand the maturity level and limitations of your chosen backend. Some backends may be in preview or have constraints that affect production suitability. + +## CIMD Authentication + + + +Client ID Metadata Documents (CIMD) provide an alternative to Dynamic Client Registration. Instead of registering with each server, your client hosts a static JSON document at an HTTPS URL. That URL becomes your client's identity, and servers can verify who you are through your domain ownership. + +```python +from fastmcp import Client +from fastmcp.client.auth import OAuth + +async with Client( + "https://mcp-server.example.com/mcp", + auth=OAuth( + client_metadata_url="https://myapp.example.com/oauth/client.json", + ), +) as client: + await client.ping() +``` + +See the [CIMD Authentication](/clients/auth/cimd) page for complete documentation on creating, hosting, and validating CIMD documents. diff --git a/docs/development/v3-notes/v3-features.mdx b/docs/development/v3-notes/v3-features.mdx index 2fa032e6eb..05399fce69 100644 --- a/docs/development/v3-notes/v3-features.mdx +++ b/docs/development/v3-notes/v3-features.mdx @@ -73,6 +73,53 @@ fastmcp install stdio server.py The command automatically detects the project directory and generates the appropriate `uv run` invocation, making it easy to integrate FastMCP servers with MCP clients. +### CIMD (Client ID Metadata Documents) + +CIMD provides an alternative to Dynamic Client Registration for OAuth-authenticated MCP servers. Instead of registering with each server dynamically, clients host a static JSON document at an HTTPS URL. That URL becomes the client's `client_id`, and servers verify identity through domain ownership. + +**Client usage:** + +```python +from fastmcp import Client +from fastmcp.client.auth import OAuth + +async with Client( + "https://mcp-server.example.com/mcp", + auth=OAuth( + client_metadata_url="https://myapp.example.com/oauth/client.json", + ), +) as client: + await client.ping() +``` + +The `OAuth` helper now supports deferred binding — `mcp_url` is optional when using `OAuth` with `Client(auth=...)`, since the transport provides the server URL automatically. + +**CLI tools for document management:** + +```bash +# Generate a CIMD document +fastmcp auth cimd create --name "My App" \ + --redirect-uri "http://localhost:*/callback" \ + --client-id "https://myapp.example.com/oauth/client.json" \ + --output client.json + +# Validate a hosted document +fastmcp auth cimd validate https://myapp.example.com/oauth/client.json +``` + +**Server-side support:** + +CIMD is enabled by default on `OAuthProxy` and its provider subclasses (GitHub, Google, etc.). The server-side implementation includes SSRF-hardened document fetching with DNS pinning, dual redirect URI validation (both CIMD document patterns and proxy patterns must match), HTTP cache-aware revalidation, and `private_key_jwt` assertion validation for clients that need stronger authentication than public client auth. + +Key details: +- CIMD URLs must be HTTPS with a non-root path +- `token_endpoint_auth_method` limited to `none` or `private_key_jwt` (no shared secrets) +- `redirect_uris` in CIMD documents support wildcard port patterns (`http://localhost:*/callback`) +- Servers fetch and cache documents with standard HTTP caching (ETag, Last-Modified, Cache-Control) +- CIMD is a protocol-level feature — any auth provider implementing the spec can support it + +Documentation: [CIMD Authentication](/clients/auth/cimd), [OAuth Proxy CIMD config](/servers/auth/oauth-proxy#cimd-support) + ### MCP Apps (SDK Compatibility) Support for [MCP Apps](https://modelcontextprotocol.io/specification/2025-06-18/server/apps) — the spec extension that lets MCP servers deliver interactive UIs via sandboxed iframes. Extension negotiation, typed UI metadata on tools and resources, and the `ui://` resource scheme. No component DSL, renderer, or `FastMCPApp` class yet — those are future phases. diff --git a/docs/docs.json b/docs/docs.json index f20343ca1e..b9c1bc4dd6 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -198,6 +198,7 @@ "icon": "key", "pages": [ "clients/auth/oauth", + "clients/auth/cimd", "clients/auth/bearer" ] } diff --git a/docs/patterns/cli.mdx b/docs/patterns/cli.mdx index 72badb8703..8ccf9d01f2 100644 --- a/docs/patterns/cli.mdx +++ b/docs/patterns/cli.mdx @@ -25,6 +25,7 @@ fastmcp --help | `install` | Install a server in MCP client applications | **Supports:** Local files and fastmcp.json configs. **Deps:** Creates an isolated environment; dependencies must be explicitly specified with `--with` and/or `--with-editable`. With fastmcp.json: Uses configured dependencies | | `inspect` | Generate a JSON report about a FastMCP server | **Supports:** Local files and fastmcp.json configs. **Deps:** Uses your current environment; you are responsible for ensuring all dependencies are available | | `project prepare` | Create a persistent uv project from fastmcp.json environment config | **Supports:** fastmcp.json configs only. **Deps:** Creates a uv project directory with all dependencies pre-installed for reuse with `--project` flag | +| `auth cimd` | Create and validate CIMD documents for OAuth authentication | N/A | | `version` | Display version information | N/A | ## `fastmcp list` @@ -750,6 +751,87 @@ The prepare command creates a uv project with: This is useful when you want to separate environment setup from server execution, such as in deployment scenarios where dependencies are installed once and the server is run multiple times. +## `fastmcp auth` + + + +Authentication-related utilities and configuration commands. + +### `fastmcp auth cimd create` + +Generate a CIMD (Client ID Metadata Document) for hosting. This creates a JSON document that you can host at an HTTPS URL to use as your OAuth client identity. + +```bash +fastmcp auth cimd create --name "My App" --redirect-uri "http://localhost:*/callback" +``` + +#### Options + +| Option | Flag | Description | +| ------ | ---- | ----------- | +| Name | `--name` | **Required.** Human-readable name of the client application | +| Redirect URI | `--redirect-uri` | **Required.** Allowed redirect URIs (can specify multiple) | +| Client URI | `--client-uri` | URL of the client's home page | +| Logo URI | `--logo-uri` | URL of the client's logo image | +| Scope | `--scope` | Space-separated list of scopes the client may request | +| Output | `--output`, `-o` | Output file path (default: stdout) | +| Pretty | `--pretty` | Pretty-print JSON output (default: true) | + +#### Example + +```bash +# Generate document to stdout +fastmcp auth cimd create \ + --name "My Production App" \ + --redirect-uri "http://localhost:*/callback" \ + --redirect-uri "https://myapp.example.com/callback" \ + --client-uri "https://myapp.example.com" \ + --scope "read write" + +# Save to file +fastmcp auth cimd create \ + --name "My App" \ + --redirect-uri "http://localhost:*/callback" \ + --output client.json +``` + +The generated document includes a placeholder `client_id` that you must update to match the URL where you'll host the document before deploying. + +### `fastmcp auth cimd validate` + +Validate a hosted CIMD document by fetching it from its URL and checking that it conforms to the CIMD specification. + +```bash +fastmcp auth cimd validate https://myapp.example.com/oauth/client.json +``` + +#### Options + +| Option | Flag | Description | +| ------ | ---- | ----------- | +| Timeout | `--timeout`, `-t` | HTTP request timeout in seconds (default: 10) | + +The validator checks: + +- The URL is a valid CIMD URL (HTTPS with non-root path) +- The document is valid JSON and conforms to the CIMD schema +- The `client_id` field in the document matches the URL +- No shared-secret authentication methods are used + +On success, it displays the document details: + +``` +→ Fetching https://myapp.example.com/oauth/client.json... +✓ Valid CIMD document + +Document details: + client_id: https://myapp.example.com/oauth/client.json + client_name: My App + token_endpoint_auth_method: none + redirect_uris: + • http://localhost:*/callback +``` + ## `fastmcp version` Display version information about FastMCP and related components. diff --git a/docs/servers/auth/oauth-proxy.mdx b/docs/servers/auth/oauth-proxy.mdx index 3c5b328a76..86a8865f30 100644 --- a/docs/servers/auth/oauth-proxy.mdx +++ b/docs/servers/auth/oauth-proxy.mdx @@ -524,6 +524,74 @@ auth = OAuthProxy( Check your server logs for "Client registered with redirect_uri" messages to identify what URLs your clients use. +## CIMD Support + + + +The OAuth proxy supports **Client ID Metadata Documents (CIMD)**, an alternative to Dynamic Client Registration where clients host a static JSON document at an HTTPS URL. Instead of registering dynamically, clients simply provide their CIMD URL as their `client_id`, and the server fetches and validates the metadata. + +CIMD clients appear in the consent screen with a verified domain badge, giving users confidence about which application is requesting access. This provides stronger identity verification than DCR, where any client can claim any name. + +### How CIMD Works + +When a client presents an HTTPS URL as its `client_id` (for example, `https://myapp.example.com/oauth/client.json`), the OAuth proxy recognizes it as a CIMD client and: + +1. Fetches the JSON document from that URL +2. Validates that the document's `client_id` field matches the URL +3. Extracts client metadata (name, redirect URIs, scopes, etc.) +4. Stores the client persistently alongside DCR clients +5. Shows the verified domain in the consent screen + +This flow happens transparently. MCP clients that support CIMD simply provide their metadata URL instead of registering, and the OAuth proxy handles the rest. + +### CIMD Configuration + +CIMD support is enabled by default for `OAuthProxy`. + + + + Whether to accept CIMD URLs as client identifiers. When enabled, clients can use HTTPS URLs pointing to metadata documents as their `client_id` instead of registering via DCR. + + + +### Private Key JWT Authentication + +CIMD clients can authenticate using `private_key_jwt` instead of the default `none` authentication method. This provides cryptographic proof of client identity by signing JWT assertions with a private key, while the server verifies using the client's public key from their CIMD document. + +To use `private_key_jwt`, the CIMD document must include either a `jwks_uri` (URL to fetch the public key set) or inline `jwks` (the key set directly in the document): + +```json +{ + "client_id": "https://myapp.example.com/oauth/client.json", + "client_name": "My Secure App", + "redirect_uris": ["http://localhost:*/callback"], + "token_endpoint_auth_method": "private_key_jwt", + "jwks_uri": "https://myapp.example.com/.well-known/jwks.json" +} +``` + +The OAuth proxy validates JWT assertions according to RFC 7523, checking the signature, issuer, audience, subject claims, and preventing replay attacks via JTI tracking. + +### Security Considerations + +CIMD provides several security advantages over DCR: + +- **Verified identity**: The domain in the `client_id` URL is verified by HTTPS, so users know which organization is requesting access +- **No registration required**: Clients don't need to store or manage dynamically-issued credentials +- **Redirect URI enforcement**: CIMD documents must declare `redirect_uris`, which are enforced by the proxy (wildcard patterns supported) +- **SSRF protection**: The OAuth proxy blocks fetches to localhost, private IPs, and reserved addresses +- **Replay prevention**: For `private_key_jwt` clients, JTI claims are tracked to prevent assertion replay +- **Cache-aware fetching**: CIMD documents are cached according to HTTP cache headers and revalidated when required + +To disable CIMD support entirely (for example, to require all clients to register via DCR): + +```python +auth = OAuthProxy( + ..., + enable_cimd=False, +) +``` + ## Security ### Key and Storage Management diff --git a/docs/servers/auth/oidc-proxy.mdx b/docs/servers/auth/oidc-proxy.mdx index 6941c7f79f..86661bc16d 100644 --- a/docs/servers/auth/oidc-proxy.mdx +++ b/docs/servers/auth/oidc-proxy.mdx @@ -232,6 +232,22 @@ OAuth scopes are configured with `required_scopes` to automatically request the Dynamic clients created by the proxy will automatically include these scopes in their authorization requests. +## CIMD Support + + + +The OIDC proxy inherits full CIMD (Client ID Metadata Document) support from `OAuthProxy`. Clients can use HTTPS URLs as their `client_id` instead of registering dynamically, and the proxy will fetch and validate their metadata document. + +See the [OAuth Proxy CIMD documentation](/servers/auth/oauth-proxy#cimd-support) for complete details on how CIMD works, including private key JWT authentication and security considerations. + +The CIMD-related parameters available on `OIDCProxy` are: + + + + Whether to accept CIMD URLs as client identifiers. + + + ## Production Configuration For production deployments, load sensitive credentials from environment variables: diff --git a/examples/auth/github_oauth/client.py b/examples/auth/github_oauth/client.py index 5f1f39bb2f..a7ab5c47ee 100644 --- a/examples/auth/github_oauth/client.py +++ b/examples/auth/github_oauth/client.py @@ -8,14 +8,20 @@ import asyncio -from fastmcp.client import Client +from fastmcp.client import Client, OAuth -SERVER_URL = "http://127.0.0.1:8000/mcp" +SERVER_URL = "http://localhost:8000/mcp" async def main(): try: - async with Client(SERVER_URL, auth="oauth") as client: + async with Client( + SERVER_URL, + auth=OAuth( + # Replace with your own CIMD document URL + client_metadata_url="https://www.jlowin.dev/mcp-client.json", + ), + ) as client: assert await client.ping() print("✅ Successfully authenticated!") diff --git a/loq.toml b/loq.toml index 4c6d7e28ba..bbfe81827b 100644 --- a/loq.toml +++ b/loq.toml @@ -76,7 +76,7 @@ max_lines = 1584 [[rules]] path = "src/fastmcp/server/auth/oauth_proxy/proxy.py" -max_lines = 1600 +max_lines = 1740 [[rules]] path = "tests/server/test_dependencies.py" diff --git a/src/fastmcp/cli/auth.py b/src/fastmcp/cli/auth.py new file mode 100644 index 0000000000..4ea401b041 --- /dev/null +++ b/src/fastmcp/cli/auth.py @@ -0,0 +1,13 @@ +"""Authentication-related CLI commands.""" + +import cyclopts + +from fastmcp.cli.cimd import cimd_app + +auth_app = cyclopts.App( + name="auth", + help="Authentication-related utilities and configuration.", +) + +# Nest CIMD commands under auth +auth_app.command(cimd_app) diff --git a/src/fastmcp/cli/cimd.py b/src/fastmcp/cli/cimd.py new file mode 100644 index 0000000000..d2def490cf --- /dev/null +++ b/src/fastmcp/cli/cimd.py @@ -0,0 +1,218 @@ +"""CIMD (Client ID Metadata Document) CLI commands.""" + +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path +from typing import Annotated + +import cyclopts +from rich.console import Console + +from fastmcp.server.auth.cimd import ( + CIMDFetcher, + CIMDFetchError, + CIMDValidationError, +) +from fastmcp.utilities.logging import get_logger + +logger = get_logger("cli.cimd") +console = Console() + + +cimd_app = cyclopts.App( + name="cimd", + help="CIMD (Client ID Metadata Document) utilities for OAuth authentication.", +) + + +@cimd_app.command(name="create") +def create_command( + *, + name: Annotated[ + str, + cyclopts.Parameter(help="Human-readable name of the client application"), + ], + redirect_uri: Annotated[ + list[str], + cyclopts.Parameter( + name=["--redirect-uri", "-r"], + help="Allowed redirect URIs (can specify multiple)", + ), + ], + client_id: Annotated[ + str | None, + cyclopts.Parameter( + name="--client-id", + help="The URL where this document will be hosted (sets client_id directly)", + ), + ] = None, + client_uri: Annotated[ + str | None, + cyclopts.Parameter( + name="--client-uri", + help="URL of the client's home page", + ), + ] = None, + logo_uri: Annotated[ + str | None, + cyclopts.Parameter( + name="--logo-uri", + help="URL of the client's logo image", + ), + ] = None, + scope: Annotated[ + str | None, + cyclopts.Parameter( + name="--scope", + help="Space-separated list of scopes the client may request", + ), + ] = None, + output: Annotated[ + str | None, + cyclopts.Parameter( + name=["--output", "-o"], + help="Output file path (default: stdout)", + ), + ] = None, + pretty: Annotated[ + bool, + cyclopts.Parameter( + help="Pretty-print JSON output", + ), + ] = True, +) -> None: + """Generate a CIMD document for hosting. + + Create a Client ID Metadata Document that you can host at an HTTPS URL. + The URL where you host this document becomes your client_id. + + Example: + fastmcp cimd create --name "My App" -r "http://localhost:*/callback" + + After creating the document, host it at an HTTPS URL with a non-root path, + for example: https://myapp.example.com/oauth/client.json + """ + # Build the document + doc = { + "client_id": client_id or "https://YOUR-DOMAIN.com/path/to/client.json", + "client_name": name, + "redirect_uris": redirect_uri, + "token_endpoint_auth_method": "none", + "grant_types": ["authorization_code"], + "response_types": ["code"], + } + + # Add optional fields + if client_uri: + doc["client_uri"] = client_uri + if logo_uri: + doc["logo_uri"] = logo_uri + if scope: + doc["scope"] = scope + + # Format output + json_output = json.dumps(doc, indent=2) if pretty else json.dumps(doc) + + # Write output + if output: + output_path = Path(output).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + f.write(json_output) + f.write("\n") + console.print(f"[green]✓[/green] CIMD document written to {output}") + if not client_id: + console.print( + "\n[yellow]Important:[/yellow] client_id is a placeholder. Update it to the URL where you will host this document, or re-run with --client-id." + ) + else: + print(json_output) + if not client_id: + # Print instructions to stderr so they don't interfere with piping + stderr_console = Console(stderr=True) + stderr_console.print( + "\n[yellow]Important:[/yellow] client_id is a placeholder." + " Update it to the URL where you will host this document," + " or re-run with --client-id." + ) + + +@cimd_app.command(name="validate") +def validate_command( + url: Annotated[ + str, + cyclopts.Parameter(help="URL of the CIMD document to validate"), + ], + *, + timeout: Annotated[ + float, + cyclopts.Parameter( + name=["--timeout", "-t"], + help="HTTP request timeout in seconds", + ), + ] = 10.0, +) -> None: + """Validate a hosted CIMD document. + + Fetches the document from the given URL and validates: + - URL is valid CIMD URL (HTTPS, non-root path) + - Document is valid JSON + - Document conforms to CIMD schema + - client_id in document matches the URL + + Example: + fastmcp cimd validate https://myapp.example.com/oauth/client.json + """ + + async def _validate() -> bool: + fetcher = CIMDFetcher(timeout=timeout) + + # Check URL format first + if not fetcher.is_cimd_client_id(url): + console.print(f"[red]✗[/red] Invalid CIMD URL: {url}") + console.print() + console.print("CIMD URLs must:") + console.print(" • Use HTTPS (not HTTP)") + console.print(" • Have a non-root path (e.g., /client.json, not just /)") + return False + + console.print(f"[blue]→[/blue] Fetching {url}...") + + try: + doc = await fetcher.fetch(url) + except CIMDFetchError as e: + console.print(f"[red]✗[/red] Failed to fetch document: {e}") + return False + except CIMDValidationError as e: + console.print(f"[red]✗[/red] Validation error: {e}") + return False + + # Success - show document details + console.print("[green]✓[/green] Valid CIMD document") + console.print() + console.print("[bold]Document details:[/bold]") + console.print(f" client_id: {doc.client_id}") + console.print(f" client_name: {doc.client_name or '(not set)'}") + console.print(f" token_endpoint_auth_method: {doc.token_endpoint_auth_method}") + + if doc.redirect_uris: + console.print(" redirect_uris:") + for uri in doc.redirect_uris: + console.print(f" • {uri}") + else: + console.print(" redirect_uris: (none)") + + if doc.scope: + console.print(f" scope: {doc.scope}") + + if doc.client_uri: + console.print(f" client_uri: {doc.client_uri}") + + return True + + success = asyncio.run(_validate()) + if not success: + sys.exit(1) diff --git a/src/fastmcp/cli/cli.py b/src/fastmcp/cli/cli.py index 147d9dc2d8..5b9c24e4a8 100644 --- a/src/fastmcp/cli/cli.py +++ b/src/fastmcp/cli/cli.py @@ -19,6 +19,7 @@ import fastmcp from fastmcp.cli import run as run_module +from fastmcp.cli.auth import auth_app from fastmcp.cli.client import call_command, discover_command, list_command from fastmcp.cli.generate import generate_cli_command from fastmcp.cli.install import install_app @@ -960,6 +961,9 @@ async def prepare( app.command(discover_command, name="discover") app.command(generate_cli_command, name="generate-cli") +# Add auth subcommand group (includes CIMD commands) +app.command(auth_app) + if __name__ == "__main__": app() diff --git a/src/fastmcp/client/auth/oauth.py b/src/fastmcp/client/auth/oauth.py index 393844d073..9fc90b4e89 100644 --- a/src/fastmcp/client/auth/oauth.py +++ b/src/fastmcp/client/auth/oauth.py @@ -143,56 +143,82 @@ class OAuth(OAuthClientProvider): a browser for user authorization and running a local callback server. """ + _bound: bool + def __init__( self, - mcp_url: str, + mcp_url: str | None = None, scopes: str | list[str] | None = None, client_name: str = "FastMCP Client", token_storage: AsyncKeyValue | None = None, additional_client_metadata: dict[str, Any] | None = None, callback_port: int | None = None, httpx_client_factory: McpHttpClientFactory | None = None, + client_metadata_url: str | None = None, ): """ Initialize OAuth client provider for an MCP server. Args: - mcp_url: Full URL to the MCP endpoint (e.g. "http://host/mcp/sse/") + mcp_url: Full URL to the MCP endpoint (e.g. "http://host/mcp/sse/"). + Optional when OAuth is passed to Client(auth=...), which provides + the URL automatically from the transport. scopes: OAuth scopes to request. Can be a space-separated string or a list of strings. client_name: Name for this client during registration token_storage: An AsyncKeyValue-compatible token store, tokens are stored in memory if not provided additional_client_metadata: Extra fields for OAuthClientMetadata callback_port: Fixed port for OAuth callback (default: random available port) + client_metadata_url: A CIMD (Client ID Metadata Document) URL. When + provided, this URL is used as the client_id instead of performing + Dynamic Client Registration. Must be an HTTPS URL with a non-root + path (e.g. "https://myapp.example.com/oauth/client.json"). + """ + # Store config for deferred binding if mcp_url not yet known + self._scopes = scopes + self._client_name = client_name + self._token_storage = token_storage + self._additional_client_metadata = additional_client_metadata + self._callback_port = callback_port + self._client_metadata_url = client_metadata_url + self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient + self._bound = False + + if mcp_url is not None: + self._bind(mcp_url) + + def _bind(self, mcp_url: str) -> None: + """Bind this OAuth provider to a specific MCP server URL. + + Called automatically when mcp_url is provided to __init__, or by the + transport when OAuth is used without an explicit URL. """ - # Normalize the MCP URL (strip trailing slashes for consistency) + if self._bound: + return + mcp_url = mcp_url.rstrip("/") - # Setup OAuth client - self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient - self.redirect_port = callback_port or find_available_port() + self.redirect_port = self._callback_port or find_available_port() redirect_uri = f"http://localhost:{self.redirect_port}/callback" scopes_str: str - if isinstance(scopes, list): - scopes_str = " ".join(scopes) - elif scopes is not None: - scopes_str = str(scopes) + if isinstance(self._scopes, list): + scopes_str = " ".join(self._scopes) + elif self._scopes is not None: + scopes_str = str(self._scopes) else: scopes_str = "" client_metadata = OAuthClientMetadata( - client_name=client_name, + client_name=self._client_name, redirect_uris=[AnyHttpUrl(redirect_uri)], grant_types=["authorization_code", "refresh_token"], response_types=["code"], - # token_endpoint_auth_method="client_secret_post", scope=scopes_str, - **(additional_client_metadata or {}), + **(self._additional_client_metadata or {}), ) - # Create server-specific token storage - token_storage = token_storage or MemoryStore() + token_storage = self._token_storage or MemoryStore() if isinstance(token_storage, MemoryStore): from warnings import warn @@ -204,23 +230,23 @@ def __init__( stacklevel=2, ) - # Use full URL for token storage to properly separate tokens per MCP endpoint self.token_storage_adapter: TokenStorageAdapter = TokenStorageAdapter( async_key_value=token_storage, server_url=mcp_url ) - # Store full MCP URL for use in callback_handler display self.mcp_url = mcp_url - # Initialize parent class with full URL for proper OAuth metadata discovery super().__init__( server_url=mcp_url, client_metadata=client_metadata, storage=self.token_storage_adapter, redirect_handler=self.redirect_handler, callback_handler=self.callback_handler, + client_metadata_url=self._client_metadata_url, ) + self._bound = True + async def _initialize(self) -> None: """Load stored tokens and client info, properly setting token expiry.""" # Call parent's _initialize to load tokens and client info @@ -298,6 +324,11 @@ async def async_auth_flow( If the OAuth flow fails due to invalid/stale client credentials, clears the cache and retries once with fresh registration. """ + if not self._bound: + raise RuntimeError( + "OAuth provider has no server URL. Either pass mcp_url to OAuth() " + "or use it with Client(auth=...) which provides the URL automatically." + ) try: # First attempt with potentially cached credentials async with aclosing(super().async_auth_flow(request)) as gen: diff --git a/src/fastmcp/client/transports/http.py b/src/fastmcp/client/transports/http.py index 89ad8fc621..83dbb7cc86 100644 --- a/src/fastmcp/client/transports/http.py +++ b/src/fastmcp/client/transports/http.py @@ -76,11 +76,17 @@ def __init__( self._get_session_id_cb: Callable[[], str | None] | None = None def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): + resolved: httpx.Auth | None if auth == "oauth": - auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) + resolved = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) + elif isinstance(auth, OAuth): + auth._bind(self.url) + resolved = auth elif isinstance(auth, str): - auth = BearerAuth(auth) - self.auth = auth + resolved = BearerAuth(auth) + else: + resolved = auth + self.auth: httpx.Auth | None = resolved @contextlib.asynccontextmanager async def connect_session( diff --git a/src/fastmcp/client/transports/sse.py b/src/fastmcp/client/transports/sse.py index ec932e6d2d..45db01beef 100644 --- a/src/fastmcp/client/transports/sse.py +++ b/src/fastmcp/client/transports/sse.py @@ -48,11 +48,17 @@ def __init__( self.sse_read_timeout = normalize_timeout_to_timedelta(sse_read_timeout) def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): + resolved: httpx.Auth | None if auth == "oauth": - auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) + resolved = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) + elif isinstance(auth, OAuth): + auth._bind(self.url) + resolved = auth elif isinstance(auth, str): - auth = BearerAuth(auth) - self.auth = auth + resolved = BearerAuth(auth) + else: + resolved = auth + self.auth: httpx.Auth | None = resolved @contextlib.asynccontextmanager async def connect_session( diff --git a/src/fastmcp/server/auth/auth.py b/src/fastmcp/server/auth/auth.py index 9a804f05d7..b8b8b1f8c8 100644 --- a/src/fastmcp/server/auth/auth.py +++ b/src/fastmcp/server/auth/auth.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse from mcp.server.auth.handlers.token import TokenErrorResponse @@ -9,7 +9,13 @@ from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend -from mcp.server.auth.middleware.client_auth import ClientAuthenticator +from mcp.server.auth.middleware.client_auth import ( + AuthenticationError, + ClientAuthenticator, +) +from mcp.server.auth.middleware.client_auth import ( + ClientAuthenticator as _SDKClientAuthenticator, +) from mcp.server.auth.provider import ( AccessToken as _SDKAccessToken, ) @@ -30,13 +36,18 @@ ClientRegistrationOptions, RevocationOptions, ) +from mcp.shared.auth import OAuthClientInformationFull from pydantic import AnyHttpUrl, Field from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request from starlette.routing import Route from fastmcp.utilities.logging import get_logger +if TYPE_CHECKING: + from fastmcp.server.auth.cimd import CIMDClientManager + logger = get_logger(__name__) @@ -108,6 +119,91 @@ async def handle(self, request: Any): return response +# Expected assertion type for private_key_jwt +JWT_BEARER_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + + +class PrivateKeyJWTClientAuthenticator(_SDKClientAuthenticator): + """Client authenticator with private_key_jwt support for CIMD clients. + + Extends the SDK's ClientAuthenticator to add support for the `private_key_jwt` + authentication method per RFC 7523. This is required for CIMD (Client ID Metadata + Document) clients that use asymmetric keys for authentication. + + The authenticator: + 1. Delegates to SDK for standard methods (client_secret_basic, client_secret_post, none) + 2. Adds private_key_jwt handling for CIMD clients + 3. Validates JWT assertions against client's JWKS + """ + + def __init__( + self, + provider: OAuthAuthorizationServerProvider[Any, Any, Any], + cimd_manager: CIMDClientManager, + token_endpoint_url: str, + ): + """Initialize the authenticator. + + Args: + provider: OAuth provider for client lookups + cimd_manager: CIMD manager for private_key_jwt validation + token_endpoint_url: Token endpoint URL for audience validation + """ + super().__init__(provider) + self._cimd_manager = cimd_manager + self._token_endpoint_url = token_endpoint_url + + async def authenticate_request( + self, request: Request + ) -> OAuthClientInformationFull: + """Authenticate a client from an HTTP request. + + Extends SDK authentication to support private_key_jwt for CIMD clients. + Delegates to SDK for client_secret_basic (Authorization header) and + client_secret_post (form body) authentication. + """ + form_data = await request.form() + client_id = form_data.get("client_id") + + # If client_id is not in form data, delegate to SDK + # This handles client_secret_basic which sends credentials in Authorization header + if not client_id: + return await super().authenticate_request(request) + + client = await self.provider.get_client(str(client_id)) + if not client: + raise AuthenticationError("Invalid client_id") + + # Handle private_key_jwt authentication for CIMD clients + if client.token_endpoint_auth_method == "private_key_jwt": + # Validate assertion parameters + assertion_type = form_data.get("client_assertion_type") + assertion = form_data.get("client_assertion") + + if assertion_type != JWT_BEARER_ASSERTION_TYPE: + raise AuthenticationError( + f"Invalid client_assertion_type: expected {JWT_BEARER_ASSERTION_TYPE}" + ) + + if not assertion or not isinstance(assertion, str): + raise AuthenticationError("Missing client_assertion") + + # Validate the JWT assertion using CIMD manager + try: + await self._cimd_manager.validate_private_key_jwt( + assertion=assertion, + client=client, + token_endpoint=self._token_endpoint_url, + ) + except ValueError as e: + raise AuthenticationError(f"Invalid client assertion: {e}") from e + + return client + + # Delegate to SDK for other authentication methods + return await super().authenticate_request(request) + + class AuthProvider(TokenVerifierProtocol): """Base class for all FastMCP authentication providers. diff --git a/src/fastmcp/server/auth/cimd.py b/src/fastmcp/server/auth/cimd.py new file mode 100644 index 0000000000..49aa6687cb --- /dev/null +++ b/src/fastmcp/server/auth/cimd.py @@ -0,0 +1,651 @@ +"""CIMD (Client ID Metadata Document) support for FastMCP. + +.. warning:: + **Beta Feature**: CIMD support is currently in beta. The API may change + in future releases. Please report any issues you encounter. + +CIMD is a simpler alternative to Dynamic Client Registration where clients +host a static JSON document at an HTTPS URL, and that URL becomes their +client_id. See the IETF draft: draft-parecki-oauth-client-id-metadata-document + +This module provides: +- CIMDDocument: Pydantic model for CIMD document validation +- CIMDFetcher: Fetch and validate CIMD documents with SSRF protection +- CIMDClientManager: Manages CIMD client operations +""" + +from __future__ import annotations + +import fnmatch +import json +import time +from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import urlparse + +from pydantic import AnyHttpUrl, BaseModel, Field, field_validator + +from fastmcp.server.auth.ssrf import ( + SSRFError, + SSRFFetchError, + ssrf_safe_fetch, + validate_url, +) +from fastmcp.utilities.logging import get_logger + +if TYPE_CHECKING: + from fastmcp.server.auth.providers.jwt import JWTVerifier + +logger = get_logger(__name__) + + +class CIMDDocument(BaseModel): + """CIMD document per draft-parecki-oauth-client-id-metadata-document. + + The client metadata document is a JSON document containing OAuth client + metadata. The client_id property MUST match the URL where this document + is hosted. + + Key constraint: token_endpoint_auth_method MUST NOT use shared secrets + (client_secret_post, client_secret_basic, client_secret_jwt). + + redirect_uris is required and must contain at least one entry. + """ + + client_id: AnyHttpUrl = Field( + ..., + description="Must match the URL where this document is hosted", + ) + client_name: str | None = Field( + default=None, + description="Human-readable name of the client", + ) + client_uri: AnyHttpUrl | None = Field( + default=None, + description="URL of the client's home page", + ) + logo_uri: AnyHttpUrl | None = Field( + default=None, + description="URL of the client's logo image", + ) + redirect_uris: list[str] = Field( + ..., + description="Array of allowed redirect URIs (may include wildcards like http://localhost:*/callback)", + ) + token_endpoint_auth_method: Literal["none", "private_key_jwt"] = Field( + default="none", + description="Authentication method for token endpoint (no shared secrets allowed)", + ) + grant_types: list[str] = Field( + default_factory=lambda: ["authorization_code"], + description="OAuth grant types the client will use", + ) + response_types: list[str] = Field( + default_factory=lambda: ["code"], + description="OAuth response types the client will use", + ) + scope: str | None = Field( + default=None, + description="Space-separated list of scopes the client may request", + ) + contacts: list[str] | None = Field( + default=None, + description="Contact information for the client developer", + ) + tos_uri: AnyHttpUrl | None = Field( + default=None, + description="URL of the client's terms of service", + ) + policy_uri: AnyHttpUrl | None = Field( + default=None, + description="URL of the client's privacy policy", + ) + jwks_uri: AnyHttpUrl | None = Field( + default=None, + description="URL of the client's JSON Web Key Set (for private_key_jwt)", + ) + jwks: dict[str, Any] | None = Field( + default=None, + description="Client's JSON Web Key Set (for private_key_jwt)", + ) + software_id: str | None = Field( + default=None, + description="Unique identifier for the client software", + ) + software_version: str | None = Field( + default=None, + description="Version of the client software", + ) + + @field_validator("token_endpoint_auth_method") + @classmethod + def validate_auth_method(cls, v: str) -> str: + """Ensure no shared-secret auth methods are used.""" + forbidden = {"client_secret_post", "client_secret_basic", "client_secret_jwt"} + if v in forbidden: + raise ValueError( + f"CIMD documents cannot use shared-secret auth methods: {v}. " + "Use 'none' or 'private_key_jwt' instead." + ) + return v + + @field_validator("redirect_uris") + @classmethod + def validate_redirect_uris(cls, v: list[str]) -> list[str]: + """Ensure redirect_uris is non-empty and each entry is a valid URI.""" + if not v: + raise ValueError("CIMD documents must include at least one redirect_uri") + for uri in v: + if not uri or not uri.strip(): + raise ValueError("CIMD redirect_uris must be non-empty strings") + parsed = urlparse(uri) + if not parsed.scheme: + raise ValueError( + f"CIMD redirect_uri must have a scheme (e.g. http:// or https://): {uri!r}" + ) + if not parsed.netloc and not uri.startswith("urn:"): + raise ValueError(f"CIMD redirect_uri must have a host: {uri!r}") + return v + + +class CIMDValidationError(Exception): + """Raised when CIMD document validation fails.""" + + +class CIMDFetchError(Exception): + """Raised when CIMD document fetching fails.""" + + +class CIMDFetcher: + """Fetch and validate CIMD documents with SSRF protection. + + Delegates HTTP fetching to ssrf_safe_fetch which provides DNS pinning, + IP validation, size limits, and timeout enforcement. Documents are cached + with a simple TTL. + """ + + # Maximum response size (bytes) + MAX_RESPONSE_SIZE = 5120 # 5KB + # Default cache TTL (seconds) + DEFAULT_CACHE_TTL_SECONDS = 3600 + + def __init__( + self, + timeout: float = 10.0, + ): + """Initialize the CIMD fetcher. + + Args: + timeout: HTTP request timeout in seconds (default 10.0) + """ + self.timeout = timeout + self._cache: dict[str, tuple[CIMDDocument, float]] = {} + + def is_cimd_client_id(self, client_id: str) -> bool: + """Check if a client_id looks like a CIMD URL. + + CIMD URLs must be HTTPS with a host and non-root path. + """ + if not client_id: + return False + try: + parsed = urlparse(client_id) + return ( + parsed.scheme == "https" + and bool(parsed.netloc) + and parsed.path not in ("", "/") + ) + except (ValueError, AttributeError): + return False + + async def fetch(self, client_id_url: str) -> CIMDDocument: + """Fetch and validate a CIMD document with SSRF protection. + + Uses ssrf_safe_fetch for the HTTP layer, which provides: + - HTTPS only, DNS resolution with IP validation + - DNS pinning (connects to validated IP directly) + - Blocks private/loopback/link-local/multicast IPs + - Response size limit and timeout enforcement + - Redirects disabled + + Args: + client_id_url: The URL to fetch (also the expected client_id) + + Returns: + Validated CIMDDocument + + Raises: + CIMDValidationError: If document is invalid or URL blocked + CIMDFetchError: If document cannot be fetched + """ + cached = self._cache.get(client_id_url) + if cached is not None: + doc, expires_at = cached + if time.time() < expires_at: + return doc + + try: + content = await ssrf_safe_fetch( + client_id_url, + require_path=True, + max_size=self.MAX_RESPONSE_SIZE, + timeout=self.timeout, + overall_timeout=30.0, + ) + except SSRFError as e: + raise CIMDValidationError(str(e)) from e + except SSRFFetchError as e: + raise CIMDFetchError(str(e)) from e + + try: + data = json.loads(content) + except json.JSONDecodeError as e: + raise CIMDValidationError(f"CIMD document is not valid JSON: {e}") from e + + try: + doc = CIMDDocument.model_validate(data) + except Exception as e: + raise CIMDValidationError(f"Invalid CIMD document: {e}") from e + + if str(doc.client_id).rstrip("/") != client_id_url.rstrip("/"): + raise CIMDValidationError( + f"CIMD client_id mismatch: document says '{doc.client_id}' " + f"but was fetched from '{client_id_url}'" + ) + + # Validate jwks_uri if present (SSRF check for JWKS endpoint) + if doc.jwks_uri: + jwks_uri_str = str(doc.jwks_uri) + try: + await validate_url(jwks_uri_str) + except SSRFError as e: + raise CIMDValidationError( + f"CIMD jwks_uri failed SSRF validation: {e}" + ) from e + + logger.info( + "CIMD document fetched and validated: %s (client_name=%s)", + client_id_url, + doc.client_name, + ) + + self._cache[client_id_url] = (doc, time.time() + self.DEFAULT_CACHE_TTL_SECONDS) + return doc + + def validate_redirect_uri(self, doc: CIMDDocument, redirect_uri: str) -> bool: + """Validate that a redirect_uri is allowed by the CIMD document. + + Args: + doc: The CIMD document + redirect_uri: The redirect URI to validate + + Returns: + True if valid, False otherwise + """ + if not doc.redirect_uris: + # No redirect_uris specified - reject all + return False + + # Normalize for comparison + redirect_uri = redirect_uri.rstrip("/") + + for allowed in doc.redirect_uris: + allowed_str = allowed.rstrip("/") + if redirect_uri == allowed_str: + return True + + # Check for wildcard port matching (http://localhost:*/callback) + if "*" in allowed_str: + if fnmatch.fnmatch(redirect_uri, allowed_str): + return True + + return False + + +class CIMDAssertionValidator: + """Validates JWT assertions for private_key_jwt CIMD clients. + + Implements RFC 7523 (JSON Web Token (JWT) Profile for OAuth 2.0 Client + Authentication and Authorization Grants) for CIMD client authentication. + + JTI replay protection uses TTL-based caching to ensure proper security: + - JTIs are cached with expiration matching the JWT's exp claim + - Expired JTIs are automatically cleaned up + - Maximum assertion lifetime is enforced (5 minutes) + """ + + # Maximum allowed assertion lifetime in seconds (RFC 7523 recommends short-lived) + MAX_ASSERTION_LIFETIME = 300 # 5 minutes + + def __init__(self): + # JTI cache: maps jti -> expiration timestamp + self._jti_cache: dict[str, float] = {} + self._jti_cache_max_size = 10000 + self._last_cleanup = time.monotonic() + self._cleanup_interval = 60 # Cleanup every 60 seconds + # Cache JWTVerifier per jwks_uri so JWKS keys are not re-fetched + # on every token exchange + self._verifier_cache: dict[str, JWTVerifier] = {} + self._verifier_cache_max_size = 100 + self.logger = get_logger(__name__) + + def _cleanup_expired_jtis(self) -> None: + """Remove expired JTIs from cache.""" + now = time.time() + expired = [jti for jti, exp in self._jti_cache.items() if exp < now] + for jti in expired: + del self._jti_cache[jti] + if expired: + self.logger.debug("Cleaned up %d expired JTIs from cache", len(expired)) + + def _maybe_cleanup(self) -> None: + """Periodically cleanup expired JTIs to prevent unbounded growth.""" + now = time.monotonic() + if now - self._last_cleanup > self._cleanup_interval: + self._cleanup_expired_jtis() + self._last_cleanup = now + + async def validate_assertion( + self, + assertion: str, + client_id: str, + token_endpoint: str, + cimd_doc: CIMDDocument, + ) -> bool: + """Validate JWT assertion from client. + + Args: + assertion: The JWT assertion string + client_id: Expected client_id (must match iss and sub claims) + token_endpoint: Token endpoint URL (must match aud claim) + cimd_doc: CIMD document containing JWKS for key verification + + Returns: + True if valid + + Raises: + ValueError: If validation fails + """ + from fastmcp.server.auth.providers.jwt import JWTVerifier as _JWTVerifier + + # Periodic cleanup of expired JTIs + self._maybe_cleanup() + + # 1. Validate CIMD document has key material and get/create verifier + if cimd_doc.jwks_uri: + jwks_uri_str = str(cimd_doc.jwks_uri) + cache_key = f"{jwks_uri_str}|{client_id}|{token_endpoint}" + verifier = self._verifier_cache.get(cache_key) + if verifier is None: + verifier = _JWTVerifier( + jwks_uri=jwks_uri_str, + issuer=client_id, + audience=token_endpoint, + ssrf_safe=True, + ) + if len(self._verifier_cache) >= self._verifier_cache_max_size: + oldest_key = next(iter(self._verifier_cache)) + del self._verifier_cache[oldest_key] + self._verifier_cache[cache_key] = verifier + elif cimd_doc.jwks: + # Inline JWKS — no caching since the key is embedded + public_key = self._extract_public_key_from_jwks(assertion, cimd_doc.jwks) + verifier = _JWTVerifier( + public_key=public_key, + issuer=client_id, + audience=token_endpoint, + ) + else: + raise ValueError( + "CIMD document must have jwks_uri or jwks for private_key_jwt" + ) + + # 2. Verify JWT using JWTVerifier (handles signature, exp, iss, aud) + access_token = await verifier.load_access_token(assertion) + if not access_token: + raise ValueError("Invalid JWT assertion") + + claims = access_token.claims + + # 3. Validate assertion lifetime (exp and iat) + now = time.time() + exp = claims.get("exp") + iat = claims.get("iat") + + if not exp: + raise ValueError("Assertion must include exp claim") + + # Validate exp is in the future (with small clock skew tolerance) + if exp < now - 30: # 30 second clock skew tolerance + raise ValueError("Assertion has expired") + + # If iat is present, validate it and check assertion lifetime + if iat: + if iat > now + 30: # 30 second clock skew tolerance + raise ValueError("Assertion iat is in the future") + if exp - iat > self.MAX_ASSERTION_LIFETIME: + raise ValueError( + f"Assertion lifetime too long: {exp - iat}s (max {self.MAX_ASSERTION_LIFETIME}s)" + ) + else: + # No iat, enforce max lifetime from now + if exp > now + self.MAX_ASSERTION_LIFETIME: + raise ValueError( + f"Assertion exp too far in future (max {self.MAX_ASSERTION_LIFETIME}s)" + ) + + # 4. Additional RFC 7523 validation: sub claim must equal client_id + if claims.get("sub") != client_id: + raise ValueError(f"Assertion sub claim must be {client_id}") + + # 5. Check jti for replay attacks (RFC 7523 requirement) + jti = claims.get("jti") + if not jti: + raise ValueError("Assertion must include jti claim") + + # Check if JTI was already used (and hasn't expired from cache) + if jti in self._jti_cache: + cached_exp = self._jti_cache[jti] + if cached_exp > now: # Still valid in cache + raise ValueError(f"Assertion replay detected: jti {jti} already used") + # Expired in cache, can be reused (clean it up) + del self._jti_cache[jti] + + # Add to cache with expiration time + # Use the assertion's exp claim so it stays cached until it would expire anyway + self._jti_cache[jti] = exp + + # Emergency size limit (shouldn't hit with proper TTL cleanup) + if len(self._jti_cache) > self._jti_cache_max_size: + self._cleanup_expired_jtis() + # If still over limit after cleanup, reject to prevent DoS + if len(self._jti_cache) > self._jti_cache_max_size: + self.logger.warning( + "JTI cache at max capacity (%d), possible attack", + self._jti_cache_max_size, + ) + raise ValueError("Server overloaded, please retry") + + self.logger.debug( + "JWT assertion validated successfully for client %s", client_id + ) + return True + + def _extract_public_key_from_jwks(self, token: str, jwks: dict) -> str: + """Extract public key from inline JWKS. + + Args: + token: JWT token to extract kid from + jwks: JWKS document containing keys + + Returns: + PEM-encoded public key + + Raises: + ValueError: If key cannot be found or extracted + """ + import base64 + import json + + from authlib.jose import JsonWebKey + + # Extract kid from token header + try: + header_b64 = token.split(".")[0] + header_b64 += "=" * (4 - len(header_b64) % 4) # Add padding + header = json.loads(base64.urlsafe_b64decode(header_b64)) + kid = header.get("kid") + except Exception as e: + raise ValueError(f"Failed to extract key ID from token: {e}") from e + + # Find matching key in JWKS + keys = jwks.get("keys", []) + if not keys: + raise ValueError("JWKS document contains no keys") + + matching_key = None + for key in keys: + if kid and key.get("kid") == kid: + matching_key = key + break + + if not matching_key: + # If no kid match, try first key as fallback + if len(keys) == 1: + matching_key = keys[0] + self.logger.warning( + "No matching kid in JWKS, using single available key" + ) + else: + raise ValueError(f"No matching key found for kid={kid} in JWKS") + + # Convert JWK to PEM + try: + jwk = JsonWebKey.import_key(matching_key) + return jwk.as_pem().decode("utf-8") + except Exception as e: + raise ValueError(f"Failed to convert JWK to PEM: {e}") from e + + +class CIMDClientManager: + """Manages all CIMD client operations for OAuth proxy. + + This class encapsulates: + - CIMD client detection + - Document fetching and validation + - Synthetic OAuth client creation + - Private key JWT assertion validation + + This allows the OAuth proxy to delegate all CIMD-specific logic to a + single, focused manager class. + """ + + def __init__( + self, + enable_cimd: bool = True, + default_scope: str = "", + allowed_redirect_uri_patterns: list[str] | None = None, + ): + """Initialize CIMD client manager. + + Args: + enable_cimd: Whether CIMD support is enabled + default_scope: Default scope for CIMD clients if not specified in document + allowed_redirect_uri_patterns: Allowed redirect URI patterns (proxy's config) + """ + self.enabled = enable_cimd + self.default_scope = default_scope + self.allowed_redirect_uri_patterns = allowed_redirect_uri_patterns + + self._fetcher = CIMDFetcher() + self._assertion_validator = CIMDAssertionValidator() + self.logger = get_logger(__name__) + + def is_cimd_client_id(self, client_id: str) -> bool: + """Check if client_id is a CIMD URL. + + Args: + client_id: Client ID to check + + Returns: + True if client_id is an HTTPS URL (CIMD format) + """ + return self.enabled and self._fetcher.is_cimd_client_id(client_id) + + async def get_client(self, client_id_url: str): + """Fetch CIMD document and create synthetic OAuth client. + + Args: + client_id_url: HTTPS URL pointing to CIMD document + + Returns: + OAuthProxyClient with CIMD document attached, or None if fetch fails + + Note: + Return type is left untyped to avoid circular import with oauth_proxy. + Returns OAuthProxyClient instance or None. + """ + if not self.enabled: + return None + + try: + cimd_doc = await self._fetcher.fetch(client_id_url) + except (CIMDFetchError, CIMDValidationError) as e: + self.logger.warning("CIMD fetch failed for %s: %s", client_id_url, e) + return None + + # Import here to avoid circular dependency + from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient + + # Create synthetic client from CIMD document. + # Keep CIMD redirect_uris as strings on the document itself so wildcard + # patterns like http://localhost:*/callback remain valid. + redirect_uris = None + client = ProxyDCRClient( + client_id=client_id_url, + client_secret=None, + redirect_uris=redirect_uris, + grant_types=cimd_doc.grant_types, + scope=cimd_doc.scope or self.default_scope, + token_endpoint_auth_method=cimd_doc.token_endpoint_auth_method, + allowed_redirect_uri_patterns=self.allowed_redirect_uri_patterns, + client_name=cimd_doc.client_name, + cimd_document=cimd_doc, + cimd_fetched_at=time.time(), + ) + + self.logger.debug( + "CIMD client resolved: %s (name=%s)", + client_id_url, + cimd_doc.client_name, + ) + return client + + async def validate_private_key_jwt( + self, + assertion: str, + client, # OAuthProxyClient, untyped to avoid circular import + token_endpoint: str, + ) -> bool: + """Validate JWT assertion for private_key_jwt auth. + + Args: + assertion: JWT assertion string from client + client: OAuth proxy client (must have cimd_document) + token_endpoint: Token endpoint URL for aud validation + + Returns: + True if assertion is valid + + Raises: + ValueError: If client doesn't have CIMD document or validation fails + """ + if not hasattr(client, "cimd_document") or not client.cimd_document: + raise ValueError("Client must have CIMD document for private_key_jwt") + + cimd_doc = client.cimd_document + if cimd_doc.token_endpoint_auth_method != "private_key_jwt": + raise ValueError("CIMD document must specify private_key_jwt auth method") + + return await self._assertion_validator.validate_assertion( + assertion, client.client_id, token_endpoint, cimd_doc + ) diff --git a/src/fastmcp/server/auth/oauth_proxy/consent.py b/src/fastmcp/server/auth/oauth_proxy/consent.py index 6f47a5da77..87b63d88f2 100644 --- a/src/fastmcp/server/auth/oauth_proxy/consent.py +++ b/src/fastmcp/server/auth/oauth_proxy/consent.py @@ -21,6 +21,7 @@ from starlette.requests import Request from starlette.responses import HTMLResponse, RedirectResponse +from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient from fastmcp.server.auth.oauth_proxy.ui import create_consent_html from fastmcp.utilities.logging import get_logger from fastmcp.utilities.ui import create_secure_html_response @@ -245,10 +246,17 @@ async def _show_consent_page( txn["csrf_token"] = csrf_token txn["csrf_expires_at"] = csrf_expires_at - # Load client to get client_name if available + # Load client to get client_name and CIMD info if available client = await self.get_client(txn["client_id"]) client_name = getattr(client, "client_name", None) if client else None + # Detect CIMD clients for verified domain badge + is_cimd_client = False + cimd_domain: str | None = None + if isinstance(client, ProxyDCRClient) and client.cimd_document is not None: + is_cimd_client = True + cimd_domain = urlparse(txn["client_id"]).hostname + # Extract server metadata from app state fastmcp = getattr(request.app.state, "fastmcp_server", None) @@ -273,6 +281,8 @@ async def _show_consent_page( server_icon_url=server_icon_url, server_website_url=server_website_url, csp_policy=self._consent_csp_policy, + is_cimd_client=is_cimd_client, + cimd_domain=cimd_domain, ) response = create_secure_html_response(html) # Store CSRF in cookie with short lifetime diff --git a/src/fastmcp/server/auth/oauth_proxy/models.py b/src/fastmcp/server/auth/oauth_proxy/models.py index 53c939f2a1..575c846baa 100644 --- a/src/fastmcp/server/auth/oauth_proxy/models.py +++ b/src/fastmcp/server/auth/oauth_proxy/models.py @@ -11,7 +11,11 @@ from mcp.shared.auth import InvalidRedirectUriError, OAuthClientInformationFull from pydantic import AnyUrl, BaseModel, Field -from fastmcp.server.auth.redirect_validation import validate_redirect_uri +from fastmcp.server.auth.cimd import CIMDDocument +from fastmcp.server.auth.redirect_validation import ( + matches_allowed_pattern, + validate_redirect_uri, +) # ------------------------------------------------------------------------- # Constants @@ -156,28 +160,77 @@ class ProxyDCRClient(OAuthClientInformationFull): allowed_redirect_uri_patterns: list[str] | None = Field(default=None) client_name: str | None = Field(default=None) + cimd_document: CIMDDocument | None = Field(default=None) + cimd_fetched_at: float | None = Field(default=None) def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: - """Validate redirect URI against allowed patterns. + """Validate redirect URI against proxy patterns and optionally CIMD redirect_uris. - Since we're acting as a proxy and clients register dynamically, - we validate their redirect URIs against configurable patterns. - This is essential for cached token scenarios where the client may - reconnect with a different port. + For CIMD clients: validates against BOTH the CIMD document's redirect_uris + AND the proxy's allowed patterns (if configured). Both must pass. + + For DCR clients: validates against proxy patterns first, falling back to + base validation (registered redirect_uris) if patterns don't match. """ + if redirect_uri is None and self.cimd_document is not None: + cimd_redirect_uris = self.cimd_document.redirect_uris + if len(cimd_redirect_uris) == 1: + candidate = cimd_redirect_uris[0] + if "*" in candidate: + raise InvalidRedirectUriError( + "redirect_uri must be specified when CIMD redirect_uris uses wildcards." + ) + try: + return AnyUrl(candidate) + except Exception as e: + raise InvalidRedirectUriError( + f"Invalid CIMD redirect_uri: {e}" + ) from e + + raise InvalidRedirectUriError( + "redirect_uri must be specified when CIMD lists multiple redirect_uris." + ) + if redirect_uri is not None: - # Validate against allowed patterns - if validate_redirect_uri( + cimd_redirect_uris = ( + self.cimd_document.redirect_uris if self.cimd_document else None + ) + + if cimd_redirect_uris: + uri_str = str(redirect_uri) + cimd_match = any( + matches_allowed_pattern(uri_str, pattern) + for pattern in cimd_redirect_uris + ) + if not cimd_match: + raise InvalidRedirectUriError( + f"Redirect URI '{redirect_uri}' does not match CIMD redirect_uris." + ) + + if self.allowed_redirect_uri_patterns: + if not validate_redirect_uri( + redirect_uri=redirect_uri, + allowed_patterns=self.allowed_redirect_uri_patterns, + ): + raise InvalidRedirectUriError( + f"Redirect URI '{redirect_uri}' does not match allowed patterns." + ) + + return redirect_uri + + pattern_matches = validate_redirect_uri( redirect_uri=redirect_uri, allowed_patterns=self.allowed_redirect_uri_patterns, - ): + ) + + if pattern_matches: return redirect_uri - # If patterns are explicitly configured then reject non-matching URIs + # Patterns configured but didn't match if self.allowed_redirect_uri_patterns: raise InvalidRedirectUriError( f"Redirect URI '{redirect_uri}' does not match allowed patterns." ) - # If no redirect_uri provided, use default behavior + # No redirect_uri provided or no patterns configured — use base validation return super().validate_redirect_uri(redirect_uri) diff --git a/src/fastmcp/server/auth/oauth_proxy/proxy.py b/src/fastmcp/server/auth/oauth_proxy/proxy.py index e1a24720f5..27773ef25c 100644 --- a/src/fastmcp/server/auth/oauth_proxy/proxy.py +++ b/src/fastmcp/server/auth/oauth_proxy/proxy.py @@ -32,6 +32,7 @@ from key_value.aio.adapters.pydantic import PydanticAdapter from key_value.aio.protocols import AsyncKeyValue from key_value.aio.wrappers.encryption import FernetEncryptionWrapper +from mcp.server.auth.handlers.metadata import MetadataHandler from mcp.server.auth.provider import ( AccessToken, AuthorizationCode, @@ -40,6 +41,7 @@ RefreshToken, TokenError, ) +from mcp.server.auth.routes import build_metadata, cors_middleware from mcp.server.auth.settings import ( ClientRegistrationOptions, RevocationOptions, @@ -52,7 +54,13 @@ from typing_extensions import override from fastmcp import settings -from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier +from fastmcp.server.auth.auth import ( + OAuthProvider, + PrivateKeyJWTClientAuthenticator, + TokenHandler, + TokenVerifier, +) +from fastmcp.server.auth.cimd import CIMDClientManager from fastmcp.server.auth.handlers.authorize import AuthorizationHandler from fastmcp.server.auth.jwt_issuer import ( JWTIssuer, @@ -248,6 +256,8 @@ def __init__( consent_csp_policy: str | None = None, # Token expiry fallback fallback_access_token_expiry_seconds: int | None = None, + # CIMD (Client ID Metadata Document) support + enable_cimd: bool = True, ): """Initialize the OAuth proxy provider. @@ -302,6 +312,9 @@ def __init__( defaults: 1 hour if a refresh token is available (since we can refresh), or 1 year if no refresh token (for API-key-style tokens like GitHub OAuth Apps). Set explicitly to override these defaults. + enable_cimd: Enable CIMD (Client ID Metadata Document) support for URL-based + client IDs. When True, clients can authenticate using HTTPS URLs as client + IDs, with metadata fetched from the URL. Supports private_key_jwt auth. """ # Always enable DCR since we implement it locally for MCP clients @@ -484,6 +497,15 @@ def __init__( # Use the provided token validator self._token_validator: TokenVerifier = token_verifier + # CIMD (Client ID Metadata Document) support + self._cimd_manager: CIMDClientManager | None = None + if enable_cimd: + self._cimd_manager = CIMDClientManager( + enable_cimd=True, + default_scope=self._default_scope_str, + allowed_redirect_uri_patterns=self._allowed_client_redirect_uris, + ) + logger.debug( "Initialized OAuth proxy provider with upstream server %s", self._upstream_authorization_endpoint, @@ -559,15 +581,43 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: provided to the DCR client during registration, not the upstream client ID. For unregistered clients, returns None (which will raise an error in the SDK). + CIMD clients (URL-based client IDs) are looked up and cached automatically. """ # Load from storage - if not (client := await self._client_store.get(key=client_id)): - return None + client = await self._client_store.get(key=client_id) - if client.allowed_redirect_uri_patterns is None: - client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris + if client is not None: + if client.allowed_redirect_uri_patterns is None: + client.allowed_redirect_uri_patterns = ( + self._allowed_client_redirect_uris + ) + + # Refresh CIMD clients using HTTP cache-aware fetcher. + if self._cimd_manager is not None and client.cimd_document is not None: + try: + refreshed = await self._cimd_manager.get_client(client_id) + if refreshed is not None: + await self._client_store.put(key=client_id, value=refreshed) + return refreshed + except Exception as e: + logger.debug( + "CIMD refresh failed for %s, using cached client: %s", + client_id, + e, + ) - return client + return client + + # Client not in storage — try CIMD lookup for URL-based client IDs + if self._cimd_manager is not None and self._cimd_manager.is_cimd_client_id( + client_id + ): + cimd_client = await self._cimd_manager.get_client(client_id) + if cimd_client is not None: + await self._client_store.put(key=client_id, value=cimd_client) + return cimd_client + + return None @override async def register_client(self, client_info: OAuthClientInformationFull) -> None: @@ -1437,6 +1487,61 @@ def get_routes( methods=["GET", "POST"], ) ) + elif ( + self._cimd_manager is not None + and isinstance(route, Route) + and route.path == "/token" + and route.methods is not None + and "POST" in route.methods + ): + # Replace the token endpoint authenticator with one that supports + # private_key_jwt for CIMD clients + token_endpoint_url = f"{self.base_url}/token" + cimd_authenticator = PrivateKeyJWTClientAuthenticator( + provider=self, + cimd_manager=self._cimd_manager, + token_endpoint_url=token_endpoint_url, + ) + token_handler = TokenHandler( + provider=self, client_authenticator=cimd_authenticator + ) + custom_routes.append( + Route( + path="/token", + endpoint=cors_middleware( + token_handler.handle, ["POST", "OPTIONS"] + ), + methods=["POST", "OPTIONS"], + ) + ) + elif ( + self._cimd_manager is not None + and isinstance(route, Route) + and route.path.startswith("/.well-known/oauth-authorization-server") + ): + client_registration_options = ( + self.client_registration_options or ClientRegistrationOptions() + ) + revocation_options = self.revocation_options or RevocationOptions() + metadata = build_metadata( + self.base_url, # ty: ignore[invalid-argument-type] + self.service_documentation_url, + client_registration_options, + revocation_options, + ) + metadata.client_id_metadata_document_supported = True + handler = MetadataHandler(metadata) + methods = route.methods or ["GET", "OPTIONS"] + + custom_routes.append( + Route( + path=route.path, + endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]), + methods=methods, + name=route.name, + include_in_schema=route.include_in_schema, + ) + ) else: # Keep all other standard OAuth routes unchanged custom_routes.append(route) diff --git a/src/fastmcp/server/auth/oauth_proxy/ui.py b/src/fastmcp/server/auth/oauth_proxy/ui.py index 3bae1a11c1..4cbb3ec2ce 100644 --- a/src/fastmcp/server/auth/oauth_proxy/ui.py +++ b/src/fastmcp/server/auth/oauth_proxy/ui.py @@ -32,6 +32,8 @@ def create_consent_html( server_website_url: str | None = None, client_website_url: str | None = None, csp_policy: str | None = None, + is_cimd_client: bool = False, + cimd_domain: str | None = None, ) -> str: """Create a styled HTML consent page for OAuth authorization requests. @@ -60,6 +62,17 @@ def create_consent_html( """ + # Build CIMD verified domain badge if applicable + cimd_badge = "" + if is_cimd_client and cimd_domain: + cimd_domain_escaped = html_module.escape(cimd_domain) + cimd_badge = f""" +
+ + Verified domain: {cimd_domain_escaped} +
+ """ + # Build redirect URI section (yellow box, centered) redirect_uri_escaped = html_module.escape(redirect_uri) redirect_section = f""" @@ -144,6 +157,7 @@ def create_consent_html( {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}

Application Access Request

{intro_box} + {cimd_badge} {redirect_section} {advanced_details} {form} @@ -152,6 +166,23 @@ def create_consent_html( """ # Additional styles needed for this page + cimd_badge_styles = """ + .cimd-badge { + background: #ecfdf5; + border: 1px solid #6ee7b7; + border-radius: 8px; + padding: 8px 16px; + margin-bottom: 16px; + font-size: 14px; + color: #065f46; + text-align: center; + } + .cimd-check { + color: #059669; + font-weight: bold; + margin-right: 4px; + } + """ additional_styles = ( INFO_BOX_STYLES + REDIRECT_SECTION_STYLES @@ -159,6 +190,7 @@ def create_consent_html( + DETAIL_BOX_STYLES + BUTTON_STYLES + TOOLTIP_STYLES + + cimd_badge_styles ) # Determine CSP policy to use diff --git a/src/fastmcp/server/auth/oidc_proxy.py b/src/fastmcp/server/auth/oidc_proxy.py index 1bcdef4e43..d89ac07561 100644 --- a/src/fastmcp/server/auth/oidc_proxy.py +++ b/src/fastmcp/server/auth/oidc_proxy.py @@ -228,6 +228,8 @@ def __init__( extra_token_params: dict[str, str] | None = None, # Token expiry fallback fallback_access_token_expiry_seconds: int | None = None, + # CIMD configuration + enable_cimd: bool = True, ) -> None: """Initialize the OIDC proxy provider. @@ -278,6 +280,9 @@ def __init__( doesn't return `expires_in` in the token response. If not set, uses smart defaults: 1 hour if a refresh token is available (since we can refresh), or 1 year if no refresh token (for API-key-style tokens like GitHub OAuth Apps). + enable_cimd: Whether to enable CIMD (Client ID Metadata Document) client support. + When True, clients can use their metadata document URL as client_id instead of + Dynamic Client Registration. Default is True. """ if not config_url: raise ValueError("Missing required config URL") @@ -351,6 +356,7 @@ def __init__( "require_authorization_consent": require_authorization_consent, "consent_csp_policy": consent_csp_policy, "fallback_access_token_expiry_seconds": fallback_access_token_expiry_seconds, + "enable_cimd": enable_cimd, } if redirect_path: diff --git a/src/fastmcp/server/auth/providers/jwt.py b/src/fastmcp/server/auth/providers/jwt.py index fd01c2f2cb..828b9238f7 100644 --- a/src/fastmcp/server/auth/providers/jwt.py +++ b/src/fastmcp/server/auth/providers/jwt.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import time from dataclasses import dataclass from typing import Any, cast @@ -15,6 +16,7 @@ from typing_extensions import TypedDict from fastmcp.server.auth import AccessToken, TokenVerifier +from fastmcp.server.auth.ssrf import SSRFError, SSRFFetchError, ssrf_safe_fetch from fastmcp.utilities.auth import decode_jwt_header, parse_scopes from fastmcp.utilities.logging import get_logger @@ -165,6 +167,7 @@ def __init__( algorithm: str | None = None, required_scopes: list[str] | None = None, base_url: AnyHttpUrl | str | None = None, + ssrf_safe: bool = False, ): """ Initialize a JWTVerifier configured to validate JWTs using either a static key or a JWKS endpoint. @@ -177,6 +180,10 @@ def __init__( algorithm: JWT signing algorithm to accept (default: "RS256"). Supported: HS256/384/512, RS256/384/512, ES256/384/512, PS256/384/512. required_scopes: Scopes that must be present in validated tokens. base_url: Base URL passed to the parent TokenVerifier. + ssrf_safe: If True, JWKS fetches use SSRF protection (HTTPS-only, + public IPs, DNS pinning). Enable when the JWKS URI comes from + untrusted input (e.g. CIMD documents). Defaults to False so + operator-configured JWKS URIs (including localhost) work normally. Raises: ValueError: If neither or both of `public_key` and `jwks_uri` are provided, or if `algorithm` is unsupported. @@ -220,6 +227,7 @@ def __init__( self.audience = audience self.public_key = public_key self.jwks_uri = jwks_uri + self.ssrf_safe = ssrf_safe self.jwt = JsonWebToken([self.algorithm]) self.logger = get_logger(__name__) @@ -239,11 +247,11 @@ async def _get_verification_key(self, token: str) -> str: kid = header.get("kid") return await self._get_jwks_key(kid) - except Exception as e: + except (ValueError, KeyError, IndexError, json.JSONDecodeError) as e: raise ValueError(f"Failed to extract key ID from token: {e}") from e async def _get_jwks_key(self, kid: str | None) -> str: - """Fetch key from JWKS with simple caching.""" + """Fetch key from JWKS with simple caching and SSRF protection.""" if not self.jwks_uri: raise ValueError("JWKS URI not configured") @@ -257,12 +265,9 @@ async def _get_jwks_key(self, kid: str | None) -> str: # If no kid but only one key cached, use it return next(iter(self._jwks_cache.values())) - # Fetch JWKS + # Fetch JWKS — with SSRF protection when enabled (untrusted URIs) try: - async with httpx.AsyncClient() as client: - response = await client.get(self.jwks_uri) - response.raise_for_status() - jwks_data = response.json() + jwks_data = await self._fetch_jwks() # Cache all keys self._jwks_cache = {} @@ -298,11 +303,35 @@ async def _get_jwks_key(self, kid: str | None) -> str: else: raise ValueError("No keys found in JWKS") - except httpx.HTTPError as e: + except (SSRFError, SSRFFetchError) as e: + self.logger.debug("JWKS fetch blocked by SSRF protection: %s", e) raise ValueError(f"Failed to fetch JWKS: {e}") from e - except Exception as e: - self.logger.debug(f"JWKS fetch failed: {e}") + except httpx.HTTPError as e: raise ValueError(f"Failed to fetch JWKS: {e}") from e + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JWKS JSON: {e}") from e + except (JoseError, TypeError, KeyError) as e: + self.logger.debug("JWKS key processing failed: %s", e) + raise ValueError(f"Failed to process JWKS: {e}") from e + + async def _fetch_jwks(self) -> dict[str, Any]: + """Fetch JWKS data, using SSRF-safe or standard fetch based on config.""" + if not self.jwks_uri: + raise ValueError("JWKS URI not configured") + + if self.ssrf_safe: + content = await ssrf_safe_fetch( + self.jwks_uri, + max_size=65536, + timeout=10.0, + overall_timeout=30.0, + ) + return json.loads(content) + else: + async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client: + response = await client.get(self.jwks_uri) + response.raise_for_status() + return response.json() def _extract_scopes(self, claims: dict[str, Any]) -> list[str]: """ @@ -435,7 +464,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: except JoseError: self.logger.debug("Token validation failed: JWT signature/format invalid") return None - except Exception as e: + except (ValueError, TypeError, KeyError, AttributeError) as e: self.logger.debug("Token validation failed: %s", str(e)) return None diff --git a/src/fastmcp/server/auth/redirect_validation.py b/src/fastmcp/server/auth/redirect_validation.py index f49958ad56..4d011416fb 100644 --- a/src/fastmcp/server/auth/redirect_validation.py +++ b/src/fastmcp/server/auth/redirect_validation.py @@ -1,19 +1,138 @@ -"""Utilities for validating client redirect URIs in OAuth flows.""" +"""Utilities for validating client redirect URIs in OAuth flows. + +This module provides secure redirect URI validation with wildcard support, +protecting against userinfo-based bypass attacks like http://localhost@evil.com. +""" import fnmatch +from urllib.parse import urlparse from pydantic import AnyUrl +def _parse_host_port(netloc: str) -> tuple[str | None, str | None]: + """Parse host and port from netloc, handling wildcards. + + Args: + netloc: The netloc component (e.g., "localhost:8080" or "localhost:*") + + Returns: + Tuple of (host, port_str) where port_str may be "*" or a number string + """ + # Handle userinfo (remove it for parsing, but we check separately) + if "@" in netloc: + netloc = netloc.split("@")[-1] + + # Handle IPv6 addresses [::1]:port + if netloc.startswith("["): + bracket_end = netloc.find("]") + if bracket_end == -1: + return netloc, None + host = netloc[1:bracket_end] + rest = netloc[bracket_end + 1 :] + if rest.startswith(":"): + return host, rest[1:] + return host, None + + # Handle regular host:port + if ":" in netloc: + host, port = netloc.rsplit(":", 1) + return host, port + + return netloc, None + + +def _match_host(uri_host: str | None, pattern_host: str | None) -> bool: + """Match host component, supporting *.example.com wildcard patterns. + + Args: + uri_host: The host from the URI being validated + pattern_host: The host pattern (may start with *.) + + Returns: + True if the host matches + """ + if not uri_host or not pattern_host: + return uri_host == pattern_host + + # Normalize to lowercase for comparison + uri_host = uri_host.lower() + pattern_host = pattern_host.lower() + + # Handle *.example.com wildcard subdomain patterns + if pattern_host.startswith("*."): + suffix = pattern_host[1:] # .example.com + # Only match actual subdomains (foo.example.com), NOT the base domain + return uri_host.endswith(suffix) and uri_host != pattern_host[2:] + + return uri_host == pattern_host + + +def _match_port( + uri_port: str | None, + pattern_port: str | None, + uri_scheme: str, +) -> bool: + """Match port component, supporting * wildcard for any port. + + Args: + uri_port: The port from the URI (None if default, string otherwise) + pattern_port: The port from the pattern (None if default, "*" for wildcard) + uri_scheme: The URI scheme (http/https) for default port handling + + Returns: + True if the port matches + """ + # Wildcard matches any port + if pattern_port == "*": + return True + + # Normalize None to default ports + default_port = "443" if uri_scheme == "https" else "80" + uri_effective = uri_port if uri_port else default_port + pattern_effective = pattern_port if pattern_port else default_port + + return uri_effective == pattern_effective + + +def _match_path(uri_path: str, pattern_path: str) -> bool: + """Match path component using fnmatch for wildcard support. + + Args: + uri_path: The path from the URI + pattern_path: The path pattern (may contain * wildcards) + + Returns: + True if the path matches + """ + # Normalize empty paths to / + uri_path = uri_path or "/" + pattern_path = pattern_path or "/" + + # Empty or root pattern path matches any path + # This makes http://localhost:* match http://localhost:3000/callback + if pattern_path == "/": + return True + + # Use fnmatch for path wildcards (e.g., /auth/*) + return fnmatch.fnmatch(uri_path, pattern_path) + + def matches_allowed_pattern(uri: str, pattern: str) -> bool: - """Check if a URI matches an allowed pattern with wildcard support. + """Securely check if a URI matches an allowed pattern with wildcard support. - Patterns support * wildcard matching: + This function parses both the URI and pattern as URLs, comparing each + component separately to prevent bypass attacks like userinfo injection. + + Patterns support wildcards: - http://localhost:* matches any localhost port - http://127.0.0.1:* matches any 127.0.0.1 port - https://*.example.com/* matches any subdomain of example.com - https://app.example.com/auth/* matches any path under /auth/ + Security: Rejects URIs with userinfo (user:pass@host) which could bypass + naive string matching (e.g., http://localhost@evil.com). + Args: uri: The redirect URI to validate pattern: The allowed pattern (may contain wildcards) @@ -21,8 +140,36 @@ def matches_allowed_pattern(uri: str, pattern: str) -> bool: Returns: True if the URI matches the pattern """ - # Use fnmatch for wildcard matching - return fnmatch.fnmatch(uri, pattern) + try: + uri_parsed = urlparse(uri) + pattern_parsed = urlparse(pattern) + except ValueError: + return False + + # SECURITY: Reject URIs with userinfo (user:pass@host) + # This prevents bypass attacks like http://localhost@evil.com/callback + # which would match http://localhost:* with naive fnmatch + if uri_parsed.username is not None or uri_parsed.password is not None: + return False + + # Scheme must match exactly + if uri_parsed.scheme.lower() != pattern_parsed.scheme.lower(): + return False + + # Parse host and port manually to handle wildcards + uri_host, uri_port = _parse_host_port(uri_parsed.netloc) + pattern_host, pattern_port = _parse_host_port(pattern_parsed.netloc) + + # Host must match (with subdomain wildcard support) + if not _match_host(uri_host, pattern_host): + return False + + # Port must match (with * wildcard support) + if not _match_port(uri_port, pattern_port, uri_parsed.scheme.lower()): + return False + + # Path must match (with fnmatch wildcards) + return _match_path(uri_parsed.path, pattern_parsed.path) def validate_redirect_uri( diff --git a/src/fastmcp/server/auth/ssrf.py b/src/fastmcp/server/auth/ssrf.py new file mode 100644 index 0000000000..8009269c6d --- /dev/null +++ b/src/fastmcp/server/auth/ssrf.py @@ -0,0 +1,307 @@ +"""SSRF-safe HTTP utilities for FastMCP. + +This module provides SSRF-protected HTTP fetching with: +- DNS resolution and IP validation before requests +- DNS pinning to prevent rebinding TOCTOU attacks +- Support for both CIMD and JWKS fetches +""" + +from __future__ import annotations + +import asyncio +import ipaddress +import socket +import time +from dataclasses import dataclass +from urllib.parse import urlparse + +import httpx + +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +def format_ip_for_url(ip_str: str) -> str: + """Format IP address for use in URL (bracket IPv6 addresses). + + IPv6 addresses must be bracketed in URLs to distinguish the address from + the port separator. For example: https://[2001:db8::1]:443/path + + Args: + ip_str: IP address string + + Returns: + IP string suitable for URL (IPv6 addresses are bracketed) + """ + try: + ip = ipaddress.ip_address(ip_str) + if isinstance(ip, ipaddress.IPv6Address): + return f"[{ip_str}]" + return ip_str + except ValueError: + return ip_str + + +class SSRFError(Exception): + """Raised when an SSRF protection check fails.""" + + +class SSRFFetchError(Exception): + """Raised when SSRF-safe fetch fails.""" + + +def is_ip_allowed(ip_str: str) -> bool: + """Check if an IP address is allowed (must be globally routable unicast). + + Uses ip.is_global which catches: + - Private (10.x, 172.16-31.x, 192.168.x) + - Loopback (127.x, ::1) + - Link-local (169.254.x, fe80::) - includes AWS metadata! + - Reserved, unspecified + - RFC6598 Carrier-Grade NAT (100.64.0.0/10) - can point to internal networks + + Additionally blocks multicast addresses (not caught by is_global). + + Args: + ip_str: IP address string to check + + Returns: + True if the IP is allowed (public unicast internet), False if blocked + """ + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + return False + + if not ip.is_global: + return False + + # Block multicast (not caught by is_global for some ranges) + if ip.is_multicast: + return False + + # IPv6-specific checks for embedded IPv4 addresses + if isinstance(ip, ipaddress.IPv6Address): + if ip.ipv4_mapped: + return is_ip_allowed(str(ip.ipv4_mapped)) + if ip.sixtofour: + return is_ip_allowed(str(ip.sixtofour)) + if ip.teredo: + server, client = ip.teredo + return is_ip_allowed(str(server)) and is_ip_allowed(str(client)) + + return True + + +async def resolve_hostname(hostname: str, port: int = 443) -> list[str]: + """Resolve hostname to IP addresses using DNS. + + Args: + hostname: Hostname to resolve + port: Port number (used for getaddrinfo) + + Returns: + List of resolved IP addresses + + Raises: + SSRFError: If resolution fails + """ + loop = asyncio.get_running_loop() + try: + infos = await loop.run_in_executor( + None, + lambda: socket.getaddrinfo( + hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ), + ) + ips = list({info[4][0] for info in infos}) + if not ips: + raise SSRFError(f"DNS resolution returned no addresses for {hostname}") + return ips + except socket.gaierror as e: + raise SSRFError(f"DNS resolution failed for {hostname}: {e}") from e + + +@dataclass +class ValidatedURL: + """A URL that has been validated for SSRF with resolved IPs.""" + + original_url: str + hostname: str + port: int + path: str + resolved_ips: list[str] + + +async def validate_url(url: str, require_path: bool = False) -> ValidatedURL: + """Validate URL for SSRF and resolve to IPs. + + Args: + url: URL to validate + require_path: If True, require non-root path (for CIMD) + + Returns: + ValidatedURL with resolved IPs + + Raises: + SSRFError: If URL is invalid or resolves to blocked IPs + """ + try: + parsed = urlparse(url) + except (ValueError, AttributeError) as e: + raise SSRFError(f"Invalid URL: {e}") from e + + if parsed.scheme != "https": + raise SSRFError(f"URL must use HTTPS, got: {parsed.scheme}") + + if not parsed.netloc: + raise SSRFError("URL must have a host") + + if require_path and parsed.path in ("", "/"): + raise SSRFError("URL must have a non-root path") + + hostname = parsed.hostname or parsed.netloc + port = parsed.port or 443 + + # Resolve and validate IPs + resolved_ips = await resolve_hostname(hostname, port) + + blocked = [ip for ip in resolved_ips if not is_ip_allowed(ip)] + if blocked: + raise SSRFError( + f"URL resolves to blocked IP address(es): {blocked}. " + f"Private, loopback, link-local, and reserved IPs are not allowed." + ) + + return ValidatedURL( + original_url=url, + hostname=hostname, + port=port, + path=parsed.path + ("?" + parsed.query if parsed.query else ""), + resolved_ips=resolved_ips, + ) + + +async def ssrf_safe_fetch( + url: str, + *, + require_path: bool = False, + max_size: int = 5120, + timeout: float = 10.0, + overall_timeout: float = 30.0, +) -> bytes: + """Fetch URL with comprehensive SSRF protection and DNS pinning. + + Security measures: + 1. HTTPS only + 2. DNS resolution with IP validation + 3. Connects to validated IP directly (DNS pinning prevents rebinding) + 4. Response size limit + 5. Redirects disabled + 6. Overall timeout + + Args: + url: URL to fetch + require_path: If True, require non-root path + max_size: Maximum response size in bytes (default 5KB) + timeout: Per-operation timeout in seconds + overall_timeout: Overall timeout for entire operation + + Returns: + Response body as bytes + + Raises: + SSRFError: If SSRF validation fails + SSRFFetchError: If fetch fails + """ + start_time = time.monotonic() + + # Validate URL and resolve DNS + validated = await validate_url(url, require_path=require_path) + + last_error: Exception | None = None + + for pinned_ip in validated.resolved_ips: + elapsed = time.monotonic() - start_time + if elapsed > overall_timeout: + raise SSRFFetchError(f"Overall timeout exceeded: {url}") + remaining = max(1.0, overall_timeout - elapsed) + + pinned_url = ( + f"https://{format_ip_for_url(pinned_ip)}:{validated.port}{validated.path}" + ) + + logger.debug( + "SSRF-safe fetch: %s -> %s (pinned to %s)", + url, + pinned_url, + pinned_ip, + ) + + try: + # Use httpx with streaming to enforce size limit during download + async with ( + httpx.AsyncClient( + timeout=httpx.Timeout( + connect=min(timeout, remaining), + read=min(timeout, remaining), + write=min(timeout, remaining), + pool=min(timeout, remaining), + ), + follow_redirects=False, + verify=True, + ) as client, + client.stream( + "GET", + pinned_url, + headers={"Host": validated.hostname}, + extensions={"sni_hostname": validated.hostname}, + ) as response, + ): + if time.monotonic() - start_time > overall_timeout: + raise SSRFFetchError(f"Overall timeout exceeded: {url}") + + if response.status_code != 200: + raise SSRFFetchError(f"HTTP {response.status_code} fetching {url}") + + # Check Content-Length header first if available + content_length = response.headers.get("content-length") + if content_length: + try: + size = int(content_length) + if size > max_size: + raise SSRFFetchError( + f"Response too large: {size} bytes (max {max_size})" + ) + except ValueError: + pass + + # Stream the response and enforce size limit during download + chunks = [] + total = 0 + async for chunk in response.aiter_bytes(): + if time.monotonic() - start_time > overall_timeout: + raise SSRFFetchError(f"Overall timeout exceeded: {url}") + total += len(chunk) + if total > max_size: + raise SSRFFetchError( + f"Response too large: exceeded {max_size} bytes" + ) + chunks.append(chunk) + + return b"".join(chunks) + + except httpx.TimeoutException as e: + last_error = e + continue + except httpx.RequestError as e: + last_error = e + continue + + if last_error is not None: + if isinstance(last_error, httpx.TimeoutException): + raise SSRFFetchError(f"Timeout fetching {url}") from last_error + raise SSRFFetchError(f"Error fetching {url}: {last_error}") from last_error + + raise SSRFFetchError(f"Error fetching {url}: no resolved IPs succeeded") diff --git a/tests/cli/test_cimd_cli.py b/tests/cli/test_cimd_cli.py new file mode 100644 index 0000000000..301c440ede --- /dev/null +++ b/tests/cli/test_cimd_cli.py @@ -0,0 +1,208 @@ +"""Tests for the CIMD CLI commands (create and validate).""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, patch + +import pytest +from pydantic import AnyHttpUrl + +from fastmcp.cli.cimd import create_command, validate_command +from fastmcp.server.auth.cimd import CIMDDocument, CIMDFetchError, CIMDValidationError + + +class TestCIMDCreateCommand: + """Tests for `fastmcp auth cimd create`.""" + + def test_minimal_output(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + ) + doc = json.loads(capsys.readouterr().out) + assert doc["client_name"] == "Test App" + assert doc["redirect_uris"] == ["http://localhost:*/callback"] + assert doc["token_endpoint_auth_method"] == "none" + assert doc["grant_types"] == ["authorization_code"] + assert doc["response_types"] == ["code"] + # Placeholder client_id + assert "YOUR-DOMAIN" in doc["client_id"] + + def test_with_client_id(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + client_id="https://myapp.example.com/client.json", + ) + doc = json.loads(capsys.readouterr().out) + assert doc["client_id"] == "https://myapp.example.com/client.json" + + def test_with_output_file(self, tmp_path): + output_file = tmp_path / "client.json" + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + client_id="https://example.com/client.json", + output=str(output_file), + ) + doc = json.loads(output_file.read_text()) + assert doc["client_id"] == "https://example.com/client.json" + assert doc["client_name"] == "Test App" + + def test_relative_path_resolved(self, tmp_path, monkeypatch): + """Relative paths should be resolved against cwd.""" + monkeypatch.chdir(tmp_path) + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + output="./subdir/client.json", + ) + resolved = tmp_path / "subdir" / "client.json" + assert resolved.exists() + doc = json.loads(resolved.read_text()) + assert doc["client_name"] == "Test App" + + def test_with_scope(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + scope="read write", + ) + doc = json.loads(capsys.readouterr().out) + assert doc["scope"] == "read write" + + def test_with_client_uri(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + client_uri="https://example.com", + ) + doc = json.loads(capsys.readouterr().out) + assert doc["client_uri"] == "https://example.com" + + def test_with_logo_uri(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + logo_uri="https://example.com/logo.png", + ) + doc = json.loads(capsys.readouterr().out) + assert doc["logo_uri"] == "https://example.com/logo.png" + + def test_multiple_redirect_uris(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=[ + "http://localhost:*/callback", + "https://myapp.example.com/callback", + ], + ) + doc = json.loads(capsys.readouterr().out) + assert len(doc["redirect_uris"]) == 2 + + def test_no_pretty(self, capsys: pytest.CaptureFixture[str]): + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + pretty=False, + ) + output = capsys.readouterr().out.strip() + # Compact JSON has no newlines within the object + assert "\n" not in output + doc = json.loads(output) + assert doc["client_name"] == "Test App" + + def test_placeholder_warning_on_stderr(self, capsys: pytest.CaptureFixture[str]): + """When outputting to stdout with no --client-id, warning goes to stderr.""" + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + ) + captured = capsys.readouterr() + # stdout has valid JSON + json.loads(captured.out) + # stderr has the warning (Rich Console writes to stderr) + assert "placeholder" in captured.err + + def test_no_warning_with_client_id(self, capsys: pytest.CaptureFixture[str]): + """No placeholder warning when --client-id is provided.""" + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + client_id="https://example.com/client.json", + ) + captured = capsys.readouterr() + assert "placeholder" not in captured.err + + def test_optional_fields_omitted_when_none( + self, capsys: pytest.CaptureFixture[str] + ): + """Optional fields like scope, client_uri, logo_uri are omitted if not given.""" + create_command( + name="Test App", + redirect_uri=["http://localhost:*/callback"], + ) + doc = json.loads(capsys.readouterr().out) + assert "scope" not in doc + assert "client_uri" not in doc + assert "logo_uri" not in doc + + +class TestCIMDValidateCommand: + """Tests for `fastmcp auth cimd validate`.""" + + def test_invalid_url_format(self, capsys: pytest.CaptureFixture[str]): + with pytest.raises(SystemExit, match="1"): + validate_command("http://insecure.com/client.json") + captured = capsys.readouterr() + assert "Invalid CIMD URL" in captured.out + + def test_root_path_rejected(self, capsys: pytest.CaptureFixture[str]): + with pytest.raises(SystemExit, match="1"): + validate_command("https://example.com/") + captured = capsys.readouterr() + assert "Invalid CIMD URL" in captured.out + + def test_success(self, capsys: pytest.CaptureFixture[str]): + mock_doc = CIMDDocument( + client_id=AnyHttpUrl("https://myapp.example.com/client.json"), + client_name="Test App", + redirect_uris=["http://localhost:*/callback"], + token_endpoint_auth_method="none", + grant_types=["authorization_code"], + response_types=["code"], + ) + with patch.object(CIMDDocument, "__init__", return_value=None): + pass + mock_fetch = AsyncMock(return_value=mock_doc) + with patch( + "fastmcp.cli.cimd.CIMDFetcher.fetch", + mock_fetch, + ): + validate_command("https://myapp.example.com/client.json") + captured = capsys.readouterr() + assert "Valid CIMD document" in captured.out + assert "Test App" in captured.out + + def test_fetch_error(self, capsys: pytest.CaptureFixture[str]): + mock_fetch = AsyncMock(side_effect=CIMDFetchError("Connection refused")) + with patch( + "fastmcp.cli.cimd.CIMDFetcher.fetch", + mock_fetch, + ): + with pytest.raises(SystemExit, match="1"): + validate_command("https://myapp.example.com/client.json") + captured = capsys.readouterr() + assert "Failed to fetch" in captured.out + + def test_validation_error(self, capsys: pytest.CaptureFixture[str]): + mock_fetch = AsyncMock(side_effect=CIMDValidationError("client_id mismatch")) + with patch( + "fastmcp.cli.cimd.CIMDFetcher.fetch", + mock_fetch, + ): + with pytest.raises(SystemExit, match="1"): + validate_command("https://myapp.example.com/client.json") + captured = capsys.readouterr() + assert "Validation error" in captured.out diff --git a/tests/client/auth/test_oauth_cimd.py b/tests/client/auth/test_oauth_cimd.py new file mode 100644 index 0000000000..04818af605 --- /dev/null +++ b/tests/client/auth/test_oauth_cimd.py @@ -0,0 +1,164 @@ +"""Tests for CIMD (Client ID Metadata Document) support in the OAuth client.""" + +from __future__ import annotations + +import warnings + +import httpx +import pytest + +from fastmcp.client.auth import OAuth +from fastmcp.client.transports import StreamableHttpTransport +from fastmcp.client.transports.sse import SSETransport + +VALID_CIMD_URL = "https://myapp.example.com/oauth/client.json" +MCP_SERVER_URL = "https://mcp-server.example.com/mcp" + + +class TestOAuthClientMetadataURL: + """Tests for the client_metadata_url parameter on OAuth.""" + + def test_stored_on_instance(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert oauth._client_metadata_url == VALID_CIMD_URL + + def test_none_by_default(self): + oauth = OAuth() + assert oauth._client_metadata_url is None + + def test_passed_to_parent_on_bind(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + oauth._bind(MCP_SERVER_URL) + assert oauth.context.client_metadata_url == VALID_CIMD_URL + + def test_none_metadata_url_on_parent(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth = OAuth(mcp_url=MCP_SERVER_URL) + assert oauth.context.client_metadata_url is None + + def test_unbound_when_no_mcp_url(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert oauth._bound is False + + def test_bound_when_mcp_url_provided(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth = OAuth( + mcp_url=MCP_SERVER_URL, + client_metadata_url=VALID_CIMD_URL, + ) + assert oauth._bound is True + + def test_invalid_cimd_url_rejected(self): + """CIMD URLs must be HTTPS with a non-root path.""" + with pytest.raises(ValueError, match="valid HTTPS URL"): + OAuth( + mcp_url=MCP_SERVER_URL, + client_metadata_url="http://insecure.com/client.json", + ) + + def test_root_path_cimd_url_rejected(self): + with pytest.raises(ValueError, match="valid HTTPS URL"): + OAuth( + mcp_url=MCP_SERVER_URL, + client_metadata_url="https://example.com/", + ) + + +class TestOAuthBind: + """Tests for the _bind() deferred initialization.""" + + def test_bind_sets_bound_true(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert oauth._bound is False + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth._bind(MCP_SERVER_URL) + assert oauth._bound is True + + def test_bind_idempotent(self): + """Second call to _bind is a no-op.""" + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth._bind(MCP_SERVER_URL) + oauth._bind("https://other-server.example.com/mcp") + # First binding wins + assert oauth.mcp_url == MCP_SERVER_URL + + def test_bind_sets_mcp_url(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth._bind(MCP_SERVER_URL + "/") + # Trailing slash stripped + assert oauth.mcp_url == MCP_SERVER_URL + + def test_bind_creates_token_storage(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert not hasattr(oauth, "token_storage_adapter") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth._bind(MCP_SERVER_URL) + assert hasattr(oauth, "token_storage_adapter") + + async def test_unbound_raises_runtime_error(self): + """async_auth_flow should fail clearly when OAuth is not bound.""" + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + request = httpx.Request("GET", MCP_SERVER_URL) + with pytest.raises(RuntimeError, match="no server URL"): + async for _ in oauth.async_auth_flow(request): + pass + + def test_scopes_forwarded_as_list(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth = OAuth( + client_metadata_url=VALID_CIMD_URL, + scopes=["read", "write"], + ) + oauth._bind(MCP_SERVER_URL) + assert oauth.context.client_metadata.scope == "read write" + + def test_scopes_forwarded_as_string(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + oauth = OAuth( + client_metadata_url=VALID_CIMD_URL, + scopes="read write", + ) + oauth._bind(MCP_SERVER_URL) + assert oauth.context.client_metadata.scope == "read write" + + +class TestOAuthBindFromTransport: + """Tests that transports call _bind() on OAuth instances.""" + + def test_http_transport_binds_oauth(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert oauth._bound is False + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + StreamableHttpTransport(MCP_SERVER_URL, auth=oauth) + assert oauth._bound is True + assert oauth.mcp_url == MCP_SERVER_URL + + def test_sse_transport_binds_oauth(self): + oauth = OAuth(client_metadata_url=VALID_CIMD_URL) + assert oauth._bound is False + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + SSETransport(MCP_SERVER_URL, auth=oauth) + assert oauth._bound is True + assert oauth.mcp_url == MCP_SERVER_URL + + def test_http_transport_oauth_string_still_works(self): + """auth="oauth" should still create a new OAuth instance.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + transport = StreamableHttpTransport(MCP_SERVER_URL, auth="oauth") + assert isinstance(transport.auth, OAuth) + assert transport.auth._bound is True diff --git a/tests/server/auth/oauth_proxy/test_oauth_proxy.py b/tests/server/auth/oauth_proxy/test_oauth_proxy.py index 27c31e1988..b605e50fd2 100644 --- a/tests/server/auth/oauth_proxy/test_oauth_proxy.py +++ b/tests/server/auth/oauth_proxy/test_oauth_proxy.py @@ -1,6 +1,8 @@ """Tests for OAuth proxy initialization and configuration.""" +import httpx from key_value.aio.stores.memory import MemoryStore +from starlette.applications import Starlette from fastmcp.server.auth.oauth_proxy import OAuthProxy @@ -72,3 +74,29 @@ def test_redirect_path_normalization(self, jwt_verifier): client_storage=MemoryStore(), ) assert proxy._redirect_path == "/auth/callback" + + async def test_metadata_advertises_cimd_support(self, jwt_verifier): + """OAuth metadata should advertise CIMD support when enabled.""" + proxy = OAuthProxy( + upstream_authorization_endpoint="https://auth.example.com/authorize", + upstream_token_endpoint="https://auth.example.com/token", + upstream_client_id="client-123", + upstream_client_secret="secret-456", + token_verifier=jwt_verifier, + base_url="https://api.example.com", + jwt_signing_key="test-secret", + client_storage=MemoryStore(), + enable_cimd=True, + ) + + app = Starlette(routes=proxy.get_routes()) + transport = httpx.ASGITransport(app=app) + + async with httpx.AsyncClient( + transport=transport, base_url="https://api.example.com" + ) as client: + response = await client.get("/.well-known/oauth-authorization-server") + + assert response.status_code == 200 + metadata = response.json() + assert metadata.get("client_id_metadata_document_supported") is True diff --git a/tests/server/auth/test_cimd.py b/tests/server/auth/test_cimd.py new file mode 100644 index 0000000000..d3c3e316e2 --- /dev/null +++ b/tests/server/auth/test_cimd.py @@ -0,0 +1,971 @@ +"""Unit tests for CIMD (Client ID Metadata Document) functionality.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from pydantic import AnyHttpUrl, ValidationError + +from fastmcp.server.auth.cimd import ( + CIMDAssertionValidator, + CIMDClientManager, + CIMDDocument, + CIMDFetcher, + CIMDFetchError, + CIMDValidationError, +) +from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient + +# Standard public IP used for DNS mocking in tests +TEST_PUBLIC_IP = "93.184.216.34" + + +class TestCIMDDocument: + """Tests for CIMDDocument model validation.""" + + def test_valid_minimal_document(self): + """Test that minimal valid document passes validation.""" + doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + ) + assert str(doc.client_id) == "https://example.com/client.json" + assert doc.token_endpoint_auth_method == "none" + assert doc.grant_types == ["authorization_code"] + assert doc.response_types == ["code"] + + def test_valid_full_document(self): + """Test that full document passes validation.""" + doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + client_name="My App", + client_uri=AnyHttpUrl("https://example.com"), + logo_uri=AnyHttpUrl("https://example.com/logo.png"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="none", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="read write", + ) + assert doc.client_name == "My App" + assert doc.scope == "read write" + + def test_private_key_jwt_auth_method_allowed(self): + """Test that private_key_jwt is allowed for CIMD.""" + doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="private_key_jwt", + jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"), + ) + assert doc.token_endpoint_auth_method == "private_key_jwt" + + def test_client_secret_basic_rejected(self): + """Test that client_secret_basic is rejected for CIMD.""" + with pytest.raises(ValidationError) as exc_info: + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="client_secret_basic", # type: ignore[arg-type] - testing invalid value + ) + # Literal type rejects invalid values before custom validator + assert "token_endpoint_auth_method" in str(exc_info.value) + + def test_client_secret_post_rejected(self): + """Test that client_secret_post is rejected for CIMD.""" + with pytest.raises(ValidationError) as exc_info: + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="client_secret_post", # type: ignore[arg-type] - testing invalid value + ) + assert "token_endpoint_auth_method" in str(exc_info.value) + + def test_client_secret_jwt_rejected(self): + """Test that client_secret_jwt is rejected for CIMD.""" + with pytest.raises(ValidationError) as exc_info: + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="client_secret_jwt", # type: ignore[arg-type] - testing invalid value + ) + assert "token_endpoint_auth_method" in str(exc_info.value) + + def test_missing_redirect_uris_rejected(self): + """Test that redirect_uris is required for CIMD.""" + with pytest.raises(ValidationError) as exc_info: + CIMDDocument(client_id=AnyHttpUrl("https://example.com/client.json")) + assert "redirect_uris" in str(exc_info.value) + + def test_empty_redirect_uris_rejected(self): + """Test that empty redirect_uris is rejected.""" + with pytest.raises(ValidationError) as exc_info: + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=[], + ) + assert "redirect_uris" in str(exc_info.value) + + def test_redirect_uri_without_scheme_rejected(self): + """Test that redirect_uris without a scheme are rejected.""" + with pytest.raises(ValidationError, match="must have a scheme"): + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["/just/a/path"], + ) + + def test_redirect_uri_without_host_rejected(self): + """Test that redirect_uris without a host are rejected.""" + with pytest.raises(ValidationError, match="must have a host"): + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://"], + ) + + def test_redirect_uri_whitespace_only_rejected(self): + """Test that whitespace-only redirect_uris are rejected.""" + with pytest.raises(ValidationError, match="non-empty"): + CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=[" "], + ) + + +class TestCIMDFetcher: + """Tests for CIMDFetcher.""" + + @pytest.fixture + def fetcher(self): + """Create a CIMDFetcher for testing.""" + return CIMDFetcher() + + def test_is_cimd_client_id_valid_urls(self, fetcher: CIMDFetcher): + """Test is_cimd_client_id accepts valid CIMD URLs.""" + assert fetcher.is_cimd_client_id("https://example.com/client.json") + assert fetcher.is_cimd_client_id("https://example.com/path/to/client") + assert fetcher.is_cimd_client_id("https://sub.example.com/cimd.json") + + def test_is_cimd_client_id_rejects_http(self, fetcher: CIMDFetcher): + """Test is_cimd_client_id rejects HTTP URLs.""" + assert not fetcher.is_cimd_client_id("http://example.com/client.json") + + def test_is_cimd_client_id_rejects_root_path(self, fetcher: CIMDFetcher): + """Test is_cimd_client_id rejects URLs with no path.""" + assert not fetcher.is_cimd_client_id("https://example.com/") + assert not fetcher.is_cimd_client_id("https://example.com") + + def test_is_cimd_client_id_rejects_non_url(self, fetcher: CIMDFetcher): + """Test is_cimd_client_id rejects non-URL strings.""" + assert not fetcher.is_cimd_client_id("client-123") + assert not fetcher.is_cimd_client_id("my-client") + assert not fetcher.is_cimd_client_id("") + assert not fetcher.is_cimd_client_id("not a url") + + def test_validate_redirect_uri_exact_match(self, fetcher: CIMDFetcher): + """Test redirect_uri validation with exact match.""" + doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + ) + assert fetcher.validate_redirect_uri(doc, "http://localhost:3000/callback") + assert not fetcher.validate_redirect_uri(doc, "http://localhost:4000/callback") + + def test_validate_redirect_uri_wildcard_match(self, fetcher: CIMDFetcher): + """Test redirect_uri validation with wildcard port.""" + doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:*/callback"], + ) + assert fetcher.validate_redirect_uri(doc, "http://localhost:3000/callback") + assert fetcher.validate_redirect_uri(doc, "http://localhost:8080/callback") + assert not fetcher.validate_redirect_uri(doc, "http://localhost:3000/other") + + +class TestCIMDFetcherHTTP: + """Tests for CIMDFetcher HTTP fetching (using httpx mock). + + Note: With SSRF protection and DNS pinning, HTTP requests go to the resolved IP + instead of the hostname. These tests mock DNS resolution to return a public IP + and configure httpx_mock to expect the IP-based URL. + """ + + @pytest.fixture + def fetcher(self): + """Create a CIMDFetcher for testing.""" + return CIMDFetcher() + + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + + async def test_fetch_success(self, fetcher: CIMDFetcher, httpx_mock, mock_dns): + """Test successful CIMD document fetch.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + + # With DNS pinning, request goes to IP. Match any URL. + httpx_mock.add_response( + json=doc_data, + headers={ + "content-type": "application/json", + "content-length": "200", + }, + ) + + doc = await fetcher.fetch(url) + assert str(doc.client_id) == url + assert doc.client_name == "Test App" + + async def test_fetch_ttl_cache(self, fetcher: CIMDFetcher, httpx_mock, mock_dns): + """Test that fetched documents are cached and served from cache within TTL.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + + assert first.client_id == second.client_id + assert len(httpx_mock.get_requests()) == 1 + + async def test_fetch_client_id_mismatch( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Test that client_id mismatch is rejected.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": "https://other.com/client.json", # Different URL + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "100"}, + ) + + with pytest.raises(CIMDValidationError) as exc_info: + await fetcher.fetch(url) + assert "mismatch" in str(exc_info.value).lower() + + async def test_fetch_http_error(self, fetcher: CIMDFetcher, httpx_mock, mock_dns): + """Test handling of HTTP errors.""" + url = "https://example.com/client.json" + httpx_mock.add_response(status_code=404) + + with pytest.raises(CIMDFetchError) as exc_info: + await fetcher.fetch(url) + assert "404" in str(exc_info.value) + + async def test_fetch_invalid_json(self, fetcher: CIMDFetcher, httpx_mock, mock_dns): + """Test handling of invalid JSON response.""" + url = "https://example.com/client.json" + httpx_mock.add_response( + content=b"not json", + headers={"content-length": "10"}, + ) + + with pytest.raises(CIMDValidationError) as exc_info: + await fetcher.fetch(url) + assert "JSON" in str(exc_info.value) + + async def test_fetch_invalid_document( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Test handling of invalid CIMD document.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "client_secret_basic", # Not allowed + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "100"}, + ) + + with pytest.raises(CIMDValidationError) as exc_info: + await fetcher.fetch(url) + assert "Invalid CIMD document" in str(exc_info.value) + + +class TestCIMDAssertionValidator: + """Tests for CIMDAssertionValidator (private_key_jwt support).""" + + @pytest.fixture + def validator(self): + """Create a CIMDAssertionValidator for testing.""" + return CIMDAssertionValidator() + + @pytest.fixture + def key_pair(self): + """Generate RSA key pair for testing.""" + from fastmcp.server.auth.providers.jwt import RSAKeyPair + + return RSAKeyPair.generate() + + @pytest.fixture + def jwks(self, key_pair): + """Create JWKS from key pair.""" + import base64 + + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + # Load public key + public_key = serialization.load_pem_public_key( + key_pair.public_key.encode(), backend=default_backend() + ) + + # Get RSA public numbers + from cryptography.hazmat.primitives.asymmetric import rsa + + if isinstance(public_key, rsa.RSAPublicKey): + numbers = public_key.public_numbers() + + # Convert to JWK format + return { + "keys": [ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": base64.urlsafe_b64encode( + numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big") + ) + .rstrip(b"=") + .decode(), + "e": base64.urlsafe_b64encode( + numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big") + ) + .rstrip(b"=") + .decode(), + } + ] + } + + @pytest.fixture + def cimd_doc_with_jwks_uri(self): + """Create CIMD document with jwks_uri.""" + return CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="private_key_jwt", + jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"), + ) + + @pytest.fixture + def cimd_doc_with_inline_jwks(self, jwks): + """Create CIMD document with inline JWKS.""" + return CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="private_key_jwt", + jwks=jwks, + ) + + async def test_valid_assertion_with_jwks_uri( + self, validator, key_pair, cimd_doc_with_jwks_uri, httpx_mock + ): + """Test that valid JWT assertion passes validation (jwks_uri).""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Mock JWKS endpoint + import base64 + + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + public_key = serialization.load_pem_public_key( + key_pair.public_key.encode(), backend=default_backend() + ) + from cryptography.hazmat.primitives.asymmetric import rsa + + assert isinstance(public_key, rsa.RSAPublicKey) + numbers = public_key.public_numbers() + + jwks = { + "keys": [ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": base64.urlsafe_b64encode( + numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big") + ) + .rstrip(b"=") + .decode(), + "e": base64.urlsafe_b64encode( + numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big") + ) + .rstrip(b"=") + .decode(), + } + ] + } + + # Mock DNS resolution for SSRF-safe fetch + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + httpx_mock.add_response(json=jwks) + + # Create valid assertion (use short lifetime for security compliance) + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience=token_endpoint, + additional_claims={"jti": "unique-jti-123"}, + expires_in_seconds=60, # 1 minute (max allowed is 300s) + kid="test-key-1", + ) + + # Should validate successfully + assert await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_jwks_uri + ) + + async def test_valid_assertion_with_inline_jwks( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that valid JWT assertion passes validation (inline JWKS).""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create valid assertion (use short lifetime for security compliance) + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience=token_endpoint, + additional_claims={"jti": "unique-jti-456"}, + expires_in_seconds=60, # 1 minute (max allowed is 300s) + kid="test-key-1", + ) + + # Should validate successfully + assert await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + + async def test_rejects_wrong_issuer( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that wrong issuer is rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create assertion with wrong issuer + assertion = key_pair.create_token( + subject=client_id, + issuer="https://attacker.com", # Wrong! + audience=token_endpoint, + additional_claims={"jti": "unique-jti-789"}, + expires_in_seconds=60, + kid="test-key-1", + ) + + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "Invalid JWT assertion" in str(exc_info.value) + + async def test_rejects_wrong_audience( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that wrong audience is rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create assertion with wrong audience + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience="https://wrong-endpoint.com/token", # Wrong! + additional_claims={"jti": "unique-jti-abc"}, + expires_in_seconds=60, + kid="test-key-1", + ) + + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "Invalid JWT assertion" in str(exc_info.value) + + async def test_rejects_wrong_subject( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that wrong subject claim is rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create assertion with wrong subject + assertion = key_pair.create_token( + subject="https://different-client.com", # Wrong! + issuer=client_id, + audience=token_endpoint, + additional_claims={"jti": "unique-jti-def"}, + expires_in_seconds=60, + kid="test-key-1", + ) + + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "sub claim must be" in str(exc_info.value) + + async def test_rejects_missing_jti( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that missing jti claim is rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create assertion without jti + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience=token_endpoint, + # No jti! + expires_in_seconds=60, + kid="test-key-1", + ) + + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "jti claim" in str(exc_info.value) + + async def test_rejects_replayed_jti( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that replayed JTI is detected and rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create assertion + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience=token_endpoint, + additional_claims={"jti": "replayed-jti"}, + expires_in_seconds=60, + kid="test-key-1", + ) + + # First use should succeed + assert await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + + # Second use with same jti should fail (replay attack) + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "replay" in str(exc_info.value).lower() + + async def test_rejects_expired_token( + self, validator, key_pair, cimd_doc_with_inline_jwks + ): + """Test that expired tokens are rejected.""" + client_id = "https://example.com/client.json" + token_endpoint = "https://oauth.example.com/token" + + # Create expired assertion (expired 1 hour ago) + assertion = key_pair.create_token( + subject=client_id, + issuer=client_id, + audience=token_endpoint, + additional_claims={"jti": "expired-jti"}, + expires_in_seconds=-3600, # Negative = expired + kid="test-key-1", + ) + + with pytest.raises(ValueError) as exc_info: + await validator.validate_assertion( + assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks + ) + assert "Invalid JWT assertion" in str(exc_info.value) + + +class TestCIMDClientManager: + """Tests for CIMDClientManager.""" + + @pytest.fixture + def manager(self): + """Create a CIMDClientManager for testing.""" + return CIMDClientManager(enable_cimd=True) + + @pytest.fixture + def disabled_manager(self): + """Create a disabled CIMDClientManager for testing.""" + return CIMDClientManager(enable_cimd=False) + + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + + def test_is_cimd_client_id_enabled(self, manager): + """Test CIMD URL detection when enabled.""" + assert manager.is_cimd_client_id("https://example.com/client.json") + assert not manager.is_cimd_client_id("regular-client-id") + + def test_is_cimd_client_id_disabled(self, disabled_manager): + """Test CIMD URL detection when disabled.""" + assert not disabled_manager.is_cimd_client_id("https://example.com/client.json") + assert not disabled_manager.is_cimd_client_id("regular-client-id") + + async def test_get_client_success(self, manager, httpx_mock, mock_dns): + """Test successful CIMD client creation.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + client = await manager.get_client(url) + assert client is not None + assert client.client_id == url + assert client.client_name == "Test App" + # Verify it uses proxy's patterns (None by default), not document's redirect_uris + assert client.allowed_redirect_uri_patterns is None + + async def test_get_client_disabled(self, disabled_manager): + """Test that get_client returns None when disabled.""" + client = await disabled_manager.get_client("https://example.com/client.json") + assert client is None + + async def test_get_client_fetch_failure(self, manager, httpx_mock, mock_dns): + """Test that get_client returns None on fetch failure.""" + url = "https://example.com/client.json" + httpx_mock.add_response(status_code=404) + + client = await manager.get_client(url) + assert client is None + + # Trust policy and consent bypass tests removed - functionality removed from CIMD + + +class TestCIMDClientManagerGetClientOptions: + """Tests for CIMDClientManager.get_client with default_scope and allowed patterns.""" + + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + + async def test_default_scope_applied_when_doc_has_no_scope( + self, httpx_mock, mock_dns + ): + """When the CIMD document omits scope, the manager's default_scope is used.""" + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + # No scope field + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + manager = CIMDClientManager( + enable_cimd=True, + default_scope="read write admin", + ) + client = await manager.get_client(url) + assert client is not None + assert client.scope == "read write admin" + + async def test_doc_scope_takes_precedence_over_default(self, httpx_mock, mock_dns): + """When the CIMD document specifies scope, it wins over the default.""" + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + "scope": "custom-scope", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + manager = CIMDClientManager( + enable_cimd=True, + default_scope="default-scope", + ) + client = await manager.get_client(url) + assert client is not None + assert client.scope == "custom-scope" + + async def test_allowed_redirect_uri_patterns_stored_on_client( + self, httpx_mock, mock_dns + ): + """Proxy's allowed_redirect_uri_patterns are forwarded to the created client.""" + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + "redirect_uris": ["http://localhost:*/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + patterns = ["http://localhost:*", "https://app.example.com/*"] + manager = CIMDClientManager( + enable_cimd=True, + allowed_redirect_uri_patterns=patterns, + ) + client = await manager.get_client(url) + assert client is not None + assert client.allowed_redirect_uri_patterns == patterns + + async def test_cimd_document_attached_to_client(self, httpx_mock, mock_dns): + """The fetched CIMDDocument is attached to the created client.""" + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Attached Doc App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + manager = CIMDClientManager(enable_cimd=True) + client = await manager.get_client(url) + assert client is not None + assert client.cimd_document is not None + assert client.cimd_document.client_name == "Attached Doc App" + assert str(client.cimd_document.client_id) == url + + +class TestCIMDClientManagerValidatePrivateKeyJwt: + """Tests for CIMDClientManager.validate_private_key_jwt wrapper.""" + + @pytest.fixture + def manager(self): + return CIMDClientManager(enable_cimd=True) + + async def test_missing_cimd_document_raises(self, manager): + """validate_private_key_jwt raises ValueError if client has no cimd_document.""" + + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=None, + ) + with pytest.raises(ValueError, match="must have CIMD document"): + await manager.validate_private_key_jwt( + "fake.jwt.token", + client, + "https://oauth.example.com/token", + ) + + async def test_wrong_auth_method_raises(self, manager): + """validate_private_key_jwt raises ValueError if auth method is not private_key_jwt.""" + + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="none", # Not private_key_jwt + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + ) + with pytest.raises(ValueError, match="private_key_jwt"): + await manager.validate_private_key_jwt( + "fake.jwt.token", + client, + "https://oauth.example.com/token", + ) + + async def test_success_delegates_to_assertion_validator(self, manager): + """On success, validate_private_key_jwt delegates to the assertion validator.""" + + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + token_endpoint_auth_method="private_key_jwt", + jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"), + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + ) + + manager._assertion_validator.validate_assertion = AsyncMock(return_value=True) + + result = await manager.validate_private_key_jwt( + "test.jwt.assertion", + client, + "https://oauth.example.com/token", + ) + assert result is True + manager._assertion_validator.validate_assertion.assert_awaited_once_with( + "test.jwt.assertion", + "https://example.com/client.json", + "https://oauth.example.com/token", + cimd_doc, + ) + + +class TestCIMDRedirectUriEnforcement: + """Tests for CIMD redirect_uri validation security. + + Verifies that CIMD clients enforce BOTH: + 1. CIMD document's redirect_uris + 2. Proxy's allowed_redirect_uri_patterns + """ + + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + + async def test_cimd_redirect_uris_enforced(self, httpx_mock, mock_dns): + """Test that CIMD document redirect_uris are enforced. + + Even if proxy patterns allow http://localhost:*, a CIMD client + should only accept URIs declared in its document. + """ + from mcp.shared.auth import InvalidRedirectUriError + from pydantic import AnyUrl + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + # CIMD only declares port 3000 + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + # Proxy allows any localhost port + manager = CIMDClientManager( + enable_cimd=True, + allowed_redirect_uri_patterns=["http://localhost:*"], + ) + client = await manager.get_client(url) + assert client is not None + + # Declared URI should work + validated = client.validate_redirect_uri( + AnyUrl("http://localhost:3000/callback") + ) + assert str(validated) == "http://localhost:3000/callback" + + # Different port should fail (not in CIMD redirect_uris) + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(AnyUrl("http://localhost:4000/callback")) + + async def test_proxy_patterns_also_checked(self, httpx_mock, mock_dns): + """Test that proxy patterns are checked even for CIMD clients. + + A CIMD client should not be able to use a redirect_uri that's + in its document but not allowed by proxy patterns. + """ + from mcp.shared.auth import InvalidRedirectUriError + from pydantic import AnyUrl + + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Test App", + # CIMD declares both localhost and external URI + "redirect_uris": [ + "http://localhost:3000/callback", + "https://evil.com/callback", + ], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + # Proxy only allows localhost + manager = CIMDClientManager( + enable_cimd=True, + allowed_redirect_uri_patterns=["http://localhost:*"], + ) + client = await manager.get_client(url) + assert client is not None + + # Localhost should work (in CIMD and matches pattern) + validated = client.validate_redirect_uri( + AnyUrl("http://localhost:3000/callback") + ) + assert str(validated) == "http://localhost:3000/callback" + + # Evil.com should fail (in CIMD but doesn't match proxy patterns) + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(AnyUrl("https://evil.com/callback")) diff --git a/tests/server/auth/test_jwt_provider.py b/tests/server/auth/test_jwt_provider.py index 14a81299e8..bced42a1f5 100644 --- a/tests/server/auth/test_jwt_provider.py +++ b/tests/server/auth/test_jwt_provider.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator from typing import Any +from unittest.mock import patch import httpx import pytest @@ -10,6 +11,9 @@ from fastmcp.server.auth.providers.jwt import JWKData, JWKSData, JWTVerifier, RSAKeyPair from fastmcp.utilities.tests import run_server_async +# Standard public IP used for DNS mocking in tests +TEST_PUBLIC_IP = "93.184.216.34" + class SymmetricKeyHelper: """Helper class for generating symmetric key JWT tokens for testing.""" @@ -378,7 +382,11 @@ async def test_symmetric_token_algorithm_mismatch( class TestBearerTokenJWKS: - """Tests for JWKS URI functionality.""" + """Tests for JWKS URI functionality. + + Note: With SSRF protection, JWKS fetches validate DNS and connect to the + resolved IP. Tests mock DNS resolution to return a public IP. + """ @pytest.fixture def jwks_provider(self, rsa_key_pair: RSAKeyPair) -> JWTVerifier: @@ -402,18 +410,25 @@ def mock_jwks_data(self, rsa_key_pair: RSAKeyPair) -> JWKSData: return {"keys": [jwk_data]} + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + async def test_jwks_token_validation( self, rsa_key_pair: RSAKeyPair, jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): """Test token validation using JWKS URI.""" - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) username = "test-user" issuer = "https://test.example.com" @@ -440,11 +455,9 @@ async def test_jwks_token_validation_with_invalid_key( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = RSAKeyPair.generate().create_token( subject="test-user", issuer="https://test.example.com", @@ -460,12 +473,10 @@ async def test_jwks_token_validation_with_kid( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): mock_jwks_data["keys"][0]["kid"] = "test-key-1" - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = rsa_key_pair.create_token( subject="test-user", issuer="https://test.example.com", @@ -483,12 +494,10 @@ async def test_jwks_token_validation_with_kid_and_no_kid_in_token( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): mock_jwks_data["keys"][0]["kid"] = "test-key-1" - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = rsa_key_pair.create_token( subject="test-user", issuer="https://test.example.com", @@ -505,12 +514,10 @@ async def test_jwks_token_validation_with_no_kid_and_kid_in_jwks( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): mock_jwks_data["keys"][0]["kid"] = "test-key-1" - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = rsa_key_pair.create_token( subject="test-user", issuer="https://test.example.com", @@ -527,12 +534,10 @@ async def test_jwks_token_validation_with_kid_mismatch( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): mock_jwks_data["keys"][0]["kid"] = "test-key-1" - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = rsa_key_pair.create_token( subject="test-user", issuer="https://test.example.com", @@ -549,6 +554,7 @@ async def test_jwks_token_validation_with_multiple_keys_and_no_kid_in_token( jwks_provider: JWTVerifier, mock_jwks_data: JWKSData, httpx_mock: HTTPXMock, + mock_dns, ): mock_jwks_data["keys"] = [ # type: ignore[typeddict-item] { @@ -561,10 +567,7 @@ async def test_jwks_token_validation_with_multiple_keys_and_no_kid_in_token( }, ] - httpx_mock.add_response( - url="https://test.example.com/.well-known/jwks.json", - json=mock_jwks_data, - ) + httpx_mock.add_response(json=mock_jwks_data) token = rsa_key_pair.create_token( subject="test-user", issuer="https://test.example.com", diff --git a/tests/server/auth/test_oauth_proxy_redirect_validation.py b/tests/server/auth/test_oauth_proxy_redirect_validation.py index 20d4afdd75..391977b886 100644 --- a/tests/server/auth/test_oauth_proxy_redirect_validation.py +++ b/tests/server/auth/test_oauth_proxy_redirect_validation.py @@ -1,14 +1,20 @@ """Tests for OAuth proxy redirect URI validation.""" +from unittest.mock import patch + import pytest from key_value.aio.stores.memory import MemoryStore from mcp.shared.auth import InvalidRedirectUriError -from pydantic import AnyUrl +from pydantic import AnyHttpUrl, AnyUrl from fastmcp.server.auth.auth import TokenVerifier +from fastmcp.server.auth.cimd import CIMDDocument from fastmcp.server.auth.oauth_proxy import OAuthProxy from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient +# Standard public IP used for DNS mocking in tests +TEST_PUBLIC_IP = "93.184.216.34" + class MockTokenVerifier(TokenVerifier): """Mock token verifier for testing.""" @@ -133,6 +139,38 @@ def test_none_redirect_uri(self): result = client.validate_redirect_uri(None) assert result == AnyUrl("http://localhost:3000") + def test_cimd_none_redirect_uri_single_exact(self): + """CIMD clients may omit redirect_uri only when a single exact URI exists.""" + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + ) + + result = client.validate_redirect_uri(None) + assert result == AnyUrl("http://localhost:3000/callback") + + def test_cimd_none_redirect_uri_wildcard_rejected(self): + """CIMD clients must specify redirect_uri when only wildcard patterns exist.""" + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:*/callback"], + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + ) + + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(None) + class TestOAuthProxyRedirectValidation: """Test OAuth proxy with redirect URI validation.""" @@ -240,3 +278,90 @@ async def test_proxy_unregistered_client_returns_none(self): # Get an unregistered client client = await proxy.get_client("unknown-client") assert client is None + + +class TestOAuthProxyCIMDClient: + """Test that CIMD clients obtained via proxy carry their document and apply dual validation.""" + + @pytest.fixture + def mock_dns(self): + """Mock DNS resolution to return test public IP.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[TEST_PUBLIC_IP], + ): + yield + + async def test_proxy_get_client_returns_cimd_client(self, httpx_mock, mock_dns): + """CIMD client obtained via proxy's get_client has cimd_document attached.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "CIMD App", + "redirect_uris": ["http://localhost:*/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + proxy = OAuthProxy( + upstream_authorization_endpoint="https://auth.example.com/authorize", + upstream_token_endpoint="https://auth.example.com/token", + upstream_client_id="test-client", + upstream_client_secret="test-secret", + token_verifier=MockTokenVerifier(), + base_url="http://localhost:8000", + jwt_signing_key="test-secret", + client_storage=MemoryStore(), + ) + + client = await proxy.get_client(url) + assert isinstance(client, ProxyDCRClient) + assert client.cimd_document is not None + assert client.cimd_document.client_name == "CIMD App" + assert client.client_id == url + + async def test_proxy_cimd_dual_redirect_validation(self, httpx_mock, mock_dns): + """CIMD client from proxy enforces both CIMD redirect_uris and proxy patterns.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Dual Validation App", + "redirect_uris": [ + "http://localhost:3000/callback", + "https://evil.com/callback", + ], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"content-length": "200"}, + ) + + proxy = OAuthProxy( + upstream_authorization_endpoint="https://auth.example.com/authorize", + upstream_token_endpoint="https://auth.example.com/token", + upstream_client_id="test-client", + upstream_client_secret="test-secret", + token_verifier=MockTokenVerifier(), + base_url="http://localhost:8000", + allowed_client_redirect_uris=["http://localhost:*"], + jwt_signing_key="test-secret", + client_storage=MemoryStore(), + ) + + client = await proxy.get_client(url) + assert client is not None + + # In CIMD AND matches proxy pattern → accepted + assert client.validate_redirect_uri(AnyUrl("http://localhost:3000/callback")) + + # In CIMD but NOT in proxy pattern → rejected + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(AnyUrl("https://evil.com/callback")) + + # NOT in CIMD but matches proxy pattern → rejected + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(AnyUrl("http://localhost:9999/other")) diff --git a/tests/server/auth/test_oauth_proxy_storage.py b/tests/server/auth/test_oauth_proxy_storage.py index cc1808c3f6..7c898823e9 100644 --- a/tests/server/auth/test_oauth_proxy_storage.py +++ b/tests/server/auth/test_oauth_proxy_storage.py @@ -112,7 +112,7 @@ async def test_nonexistent_client_returns_none( async def test_proxy_dcr_client_redirect_validation( self, jwt_verifier: TokenVerifier, temp_storage: AsyncKeyValue ): - """Test that ProxyDCRClient is created with redirect URI patterns.""" + """Test that OAuthProxyClient is created with redirect URI patterns.""" proxy = OAuthProxy( upstream_authorization_endpoint="https://github.com/login/oauth/authorize", upstream_token_endpoint="https://github.com/login/oauth/access_token", @@ -132,11 +132,11 @@ async def test_proxy_dcr_client_redirect_validation( ) await proxy.register_client(client_info) - # Get client back - should be ProxyDCRClient + # Get client back - should be OAuthProxyClient client = await proxy.get_client("test-proxy-client") assert client is not None - # ProxyDCRClient should validate dynamic localhost ports + # OAuthProxyClient should validate dynamic localhost ports validated = client.validate_redirect_uri( AnyUrl("http://localhost:12345/callback") ) @@ -205,5 +205,7 @@ async def test_storage_data_structure(self, jwt_verifier, temp_storage): "client_id_issued_at": None, "client_secret_expires_at": None, "allowed_redirect_uri_patterns": None, + "cimd_document": None, + "cimd_fetched_at": None, } ) diff --git a/tests/server/auth/test_oidc_proxy.py b/tests/server/auth/test_oidc_proxy.py index b8e373e401..e3dd95e1c1 100644 --- a/tests/server/auth/test_oidc_proxy.py +++ b/tests/server/auth/test_oidc_proxy.py @@ -15,10 +15,10 @@ TEST_AUTHORIZATION_ENDPOINT = "https://example.com/authorize" TEST_TOKEN_ENDPOINT = "https://example.com/oauth/token" -TEST_CONFIG_URL = "https://example.com/.well-known/openid-configuration" +TEST_CONFIG_URL = AnyHttpUrl("https://example.com/.well-known/openid-configuration") TEST_CLIENT_ID = "test-client-id" TEST_CLIENT_SECRET = "test-client-secret" -TEST_BASE_URL = "https://example.com:8000/" +TEST_BASE_URL = AnyHttpUrl("https://example.com:8000/") # ============================================================================= @@ -366,7 +366,7 @@ def validate_get_oidc_configuration(oidc_configuration, strict, timeout_seconds) mock_get.return_value = mock_response config = OIDCConfiguration.get_oidc_configuration( - config_url=AnyHttpUrl(TEST_CONFIG_URL), + config_url=TEST_CONFIG_URL, strict=strict, timeout_seconds=timeout_seconds, ) @@ -376,7 +376,7 @@ def validate_get_oidc_configuration(oidc_configuration, strict, timeout_seconds) mock_get.assert_called_once() call_args = mock_get.call_args - assert call_args[0][0] == TEST_CONFIG_URL + assert str(call_args[0][0]) == str(TEST_CONFIG_URL) return call_args @@ -415,7 +415,7 @@ def test_get_oidc_configuration_not_strict( mock_get.return_value = mock_response OIDCConfiguration.get_oidc_configuration( - config_url=AnyHttpUrl(TEST_CONFIG_URL), + config_url=TEST_CONFIG_URL, strict=False, timeout_seconds=10, ) @@ -423,7 +423,7 @@ def test_get_oidc_configuration_not_strict( mock_get.assert_called_once() call_args = mock_get.call_args - assert call_args[0][0] == TEST_CONFIG_URL + assert str(call_args[0][0]) == str(TEST_CONFIG_URL) def validate_proxy(mock_get, proxy, oidc_config): @@ -431,13 +431,13 @@ def validate_proxy(mock_get, proxy, oidc_config): mock_get.assert_called_once() call_args = mock_get.call_args - assert str(call_args[0][0]) == TEST_CONFIG_URL + assert str(call_args[0][0]) == str(TEST_CONFIG_URL) assert proxy._upstream_authorization_endpoint == TEST_AUTHORIZATION_ENDPOINT assert proxy._upstream_token_endpoint == TEST_TOKEN_ENDPOINT assert proxy._upstream_client_id == TEST_CLIENT_ID assert proxy._upstream_client_secret.get_secret_value() == TEST_CLIENT_SECRET - assert str(proxy.base_url) == TEST_BASE_URL + assert str(proxy.base_url) == str(TEST_BASE_URL) assert proxy.oidc_config == oidc_config diff --git a/tests/server/auth/test_redirect_validation.py b/tests/server/auth/test_redirect_validation.py index 87071a91fc..10945d2fb4 100644 --- a/tests/server/auth/test_redirect_validation.py +++ b/tests/server/auth/test_redirect_validation.py @@ -109,6 +109,65 @@ def test_anyurl_conversion(self): assert not validate_redirect_uri(uri, patterns) +class TestSecurityBypass: + """Test protection against redirect URI security bypass attacks.""" + + def test_userinfo_bypass_blocked(self): + """Test that userinfo-style bypasses are blocked. + + Attack: http://localhost@evil.com/callback would match http://localhost:* + with naive string matching, but actually points to evil.com. + """ + pattern = "http://localhost:*" + + # These should be blocked - the "host" is actually in the userinfo + assert not matches_allowed_pattern( + "http://localhost@evil.com/callback", pattern + ) + assert not matches_allowed_pattern( + "http://localhost:3000@malicious.io/callback", pattern + ) + assert not matches_allowed_pattern( + "http://user:pass@localhost:3000/callback", pattern + ) + + def test_userinfo_bypass_with_subdomain_pattern(self): + """Test userinfo bypass with subdomain wildcard patterns.""" + pattern = "https://*.example.com/callback" + + # Blocked: userinfo tricks + assert not matches_allowed_pattern( + "https://app.example.com@attacker.com/callback", pattern + ) + assert not matches_allowed_pattern( + "https://user:pass@app.example.com/callback", pattern + ) + + def test_legitimate_uris_still_work(self): + """Test that legitimate URIs work after security hardening.""" + pattern = "http://localhost:*" + assert matches_allowed_pattern("http://localhost:3000/callback", pattern) + assert matches_allowed_pattern("http://localhost:8080/auth", pattern) + + pattern = "https://*.example.com/callback" + assert matches_allowed_pattern("https://app.example.com/callback", pattern) + + def test_scheme_mismatch_blocked(self): + """Test that scheme mismatches are blocked.""" + assert not matches_allowed_pattern( + "http://localhost:3000/callback", "https://localhost:*" + ) + assert not matches_allowed_pattern( + "https://localhost:3000/callback", "http://localhost:*" + ) + + def test_host_mismatch_blocked(self): + """Test that host mismatches are blocked even with wildcards.""" + pattern = "http://localhost:*" + assert not matches_allowed_pattern("http://127.0.0.1:3000/callback", pattern) + assert not matches_allowed_pattern("http://example.com:3000/callback", pattern) + + class TestDefaultPatterns: """Test the default localhost patterns constant.""" diff --git a/tests/server/auth/test_ssrf_protection.py b/tests/server/auth/test_ssrf_protection.py new file mode 100644 index 0000000000..79cf926a09 --- /dev/null +++ b/tests/server/auth/test_ssrf_protection.py @@ -0,0 +1,447 @@ +"""Tests for SSRF-safe HTTP utilities. + +This module tests the ssrf.py module which provides SSRF-protected HTTP fetching. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from fastmcp.server.auth.ssrf import ( + SSRFError, + SSRFFetchError, + is_ip_allowed, + ssrf_safe_fetch, + validate_url, +) + + +class TestIsIPAllowed: + """Tests for is_ip_allowed function.""" + + def test_public_ipv4_allowed(self): + """Public IPv4 addresses should be allowed.""" + assert is_ip_allowed("8.8.8.8") is True + assert is_ip_allowed("1.1.1.1") is True + assert is_ip_allowed("93.184.216.34") is True + + def test_private_ipv4_blocked(self): + """Private IPv4 addresses should be blocked.""" + assert is_ip_allowed("192.168.1.1") is False + assert is_ip_allowed("10.0.0.1") is False + assert is_ip_allowed("172.16.0.1") is False + + def test_loopback_blocked(self): + """Loopback addresses should be blocked.""" + assert is_ip_allowed("127.0.0.1") is False + assert is_ip_allowed("::1") is False + + def test_link_local_blocked(self): + """Link-local addresses (AWS metadata) should be blocked.""" + assert is_ip_allowed("169.254.169.254") is False + + def test_rfc6598_cgnat_blocked(self): + """RFC6598 Carrier-Grade NAT addresses should be blocked.""" + assert is_ip_allowed("100.64.0.1") is False + assert is_ip_allowed("100.100.100.100") is False + + def test_ipv4_mapped_ipv6_blocked_if_private(self): + """IPv4-mapped IPv6 addresses should check the embedded IPv4.""" + assert is_ip_allowed("::ffff:127.0.0.1") is False + assert is_ip_allowed("::ffff:192.168.1.1") is False + + +class TestValidateURL: + """Tests for validate_url function.""" + + async def test_http_rejected(self): + """HTTP URLs should be rejected (HTTPS required).""" + with pytest.raises(SSRFError, match="must use HTTPS"): + await validate_url("http://example.com/path") + + async def test_missing_host_rejected(self): + """URLs without host should be rejected.""" + with pytest.raises(SSRFError, match="must have a host"): + await validate_url("https:///path") + + async def test_root_path_rejected_when_required(self): + """Root paths should be rejected when require_path=True.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["93.184.216.34"], + ): + with pytest.raises(SSRFError, match="non-root path"): + await validate_url("https://example.com/", require_path=True) + + async def test_private_ip_rejected(self): + """URLs resolving to private IPs should be rejected.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["192.168.1.1"], + ): + with pytest.raises(SSRFError, match="blocked IP"): + await validate_url("https://example.com/path") + + +class TestSSRFSafeFetch: + """Tests for ssrf_safe_fetch function.""" + + async def test_private_ip_blocked(self): + """Fetch to private IP should be blocked.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["192.168.1.1"], + ): + with pytest.raises(SSRFError, match="blocked IP"): + await ssrf_safe_fetch("https://internal.example.com/api") + + async def test_cgnat_blocked(self): + """Fetch to RFC6598 CGNAT IP should be blocked.""" + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["100.64.0.1"], + ): + with pytest.raises(SSRFError, match="blocked IP"): + await ssrf_safe_fetch("https://cgnat.example.com/api") + + async def test_connects_to_pinned_ip(self): + """Verify connection uses pinned IP, not re-resolved DNS.""" + resolved_ip = "93.184.216.34" + + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[resolved_ip], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {"content-length": "15"} + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def aiter_bytes(): + yield b'{"data": "test"}' + + mock_stream.aiter_bytes = aiter_bytes + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + await ssrf_safe_fetch("https://example.com/api") + + # Verify URL contains pinned IP + call_args = mock_client.stream.call_args + url_called = call_args[0][1] + assert resolved_ip in url_called + + async def test_fallback_to_second_ip(self): + """If the first IP fails, the next resolved IP should be tried.""" + resolved_ips = ["2001:4860:4860::8888", "93.184.216.34"] + + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=resolved_ips, + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + request = httpx.Request("GET", "https://example.com/api") + + first_client = AsyncMock() + first_client.stream = MagicMock( + side_effect=httpx.RequestError("boom", request=request) + ) + first_client.__aenter__.return_value = first_client + first_client.__aexit__ = AsyncMock(return_value=None) + + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {"content-length": "2"} + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def aiter_bytes(): + yield b"ok" + + mock_stream.aiter_bytes = aiter_bytes + + second_client = AsyncMock() + second_client.stream = MagicMock(return_value=mock_stream) + second_client.__aenter__.return_value = second_client + second_client.__aexit__ = AsyncMock(return_value=None) + + mock_client_class.side_effect = [first_client, second_client] + + content = await ssrf_safe_fetch("https://example.com/api") + assert content == b"ok" + + call_args = second_client.stream.call_args + url_called = call_args[0][1] + assert resolved_ips[1] in url_called + + async def test_host_header_set(self): + """Verify Host header is set to original hostname.""" + resolved_ip = "93.184.216.34" + original_host = "example.com" + + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[resolved_ip], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {"content-length": "15"} + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def aiter_bytes(): + yield b'{"data": "test"}' + + mock_stream.aiter_bytes = aiter_bytes + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + await ssrf_safe_fetch(f"https://{original_host}/api") + + # Verify Host header + call_kwargs = mock_client.stream.call_args[1] + assert call_kwargs["headers"]["Host"] == original_host + + async def test_response_size_limit(self): + """Verify response size limit is enforced via streaming.""" + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["93.184.216.34"], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + # Response larger than default 5KB (no Content-Length, so streaming enforces) + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {} # No Content-Length to force streaming check + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def aiter_bytes(): + # Yield 10KB total + for _ in range(10): + yield b"x" * 1024 + + mock_stream.aiter_bytes = aiter_bytes + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + with pytest.raises(SSRFFetchError, match="too large"): + await ssrf_safe_fetch("https://example.com/api") + + +class TestJWKSSSRFProtection: + """Tests for SSRF protection in JWTVerifier JWKS fetching.""" + + async def test_jwks_private_ip_blocked(self): + """JWKS fetch to private IP should be blocked.""" + from fastmcp.server.auth.providers.jwt import JWTVerifier + + verifier = JWTVerifier( + jwks_uri="https://internal.example.com/.well-known/jwks.json", + issuer="https://issuer.example.com", + ssrf_safe=True, + ) + + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["192.168.1.1"], + ): + with pytest.raises(ValueError, match="Failed to fetch JWKS"): + # Create a dummy token to trigger JWKS fetch + await verifier._get_jwks_key("test-kid") + + async def test_jwks_cgnat_blocked(self): + """JWKS fetch to RFC6598 CGNAT IP should be blocked.""" + from fastmcp.server.auth.providers.jwt import JWTVerifier + + verifier = JWTVerifier( + jwks_uri="https://cgnat.example.com/.well-known/jwks.json", + issuer="https://issuer.example.com", + ssrf_safe=True, + ) + + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["100.64.0.1"], + ): + with pytest.raises(ValueError, match="Failed to fetch JWKS"): + await verifier._get_jwks_key("test-kid") + + async def test_jwks_loopback_blocked(self): + """JWKS fetch to loopback should be blocked.""" + from fastmcp.server.auth.providers.jwt import JWTVerifier + + verifier = JWTVerifier( + jwks_uri="https://localhost/.well-known/jwks.json", + issuer="https://issuer.example.com", + ssrf_safe=True, + ) + + with patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["127.0.0.1"], + ): + with pytest.raises(ValueError, match="Failed to fetch JWKS"): + await verifier._get_jwks_key("test-kid") + + +class TestIPv6URLFormatting: + """Tests for proper IPv6 address bracketing in URLs.""" + + def test_format_ip_for_url_ipv4(self): + """IPv4 addresses should not be bracketed.""" + from fastmcp.server.auth.ssrf import format_ip_for_url + + assert format_ip_for_url("8.8.8.8") == "8.8.8.8" + assert format_ip_for_url("192.168.1.1") == "192.168.1.1" + + def test_format_ip_for_url_ipv6(self): + """IPv6 addresses should be bracketed for URL use.""" + from fastmcp.server.auth.ssrf import format_ip_for_url + + assert format_ip_for_url("2001:db8::1") == "[2001:db8::1]" + assert format_ip_for_url("::1") == "[::1]" + assert format_ip_for_url("fe80::1") == "[fe80::1]" + + def test_format_ip_for_url_invalid(self): + """Invalid IP strings should be returned unchanged.""" + from fastmcp.server.auth.ssrf import format_ip_for_url + + assert format_ip_for_url("not-an-ip") == "not-an-ip" + assert format_ip_for_url("") == "" + + async def test_ipv6_pinned_url_is_valid(self): + """Verify IPv6 addresses are properly bracketed in pinned URLs.""" + resolved_ipv6 = "2001:4860:4860::8888" + + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=[resolved_ipv6], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {"content-length": "10"} + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def aiter_bytes(): + yield b'{"key": 1}' + + mock_stream.aiter_bytes = aiter_bytes + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + await ssrf_safe_fetch("https://example.com/api") + + # Verify the URL contains bracketed IPv6 address + call_args = mock_client.stream.call_args + url_called = call_args[0][1] + + # IPv6 should be bracketed: https://[2001:4860:4860::8888]:443/path + assert f"[{resolved_ipv6}]" in url_called, ( + f"Expected bracketed IPv6 [{resolved_ipv6}] in URL, got {url_called}" + ) + + +class TestStreamingResponseSizeLimit: + """Tests for streaming-based response size enforcement.""" + + async def test_size_limit_enforced_during_streaming(self): + """Verify that size limit is enforced as chunks are received, not after.""" + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["93.184.216.34"], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + chunks_yielded = [] + + async def aiter_bytes(): + # Yield chunks that exceed the limit + for i in range(10): + chunk = b"x" * 1024 # 1KB per chunk + chunks_yielded.append(chunk) + yield chunk + + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {} # No content-length to force streaming check + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + mock_stream.aiter_bytes = aiter_bytes + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + with pytest.raises(SSRFFetchError, match="too large"): + await ssrf_safe_fetch("https://example.com/api", max_size=5120) + + # Verify we stopped after exceeding the limit (should be ~6 chunks for 5KB limit) + # This confirms we're enforcing during streaming, not after downloading all + assert len(chunks_yielded) <= 7, ( + f"Downloaded {len(chunks_yielded)} chunks (expected <=7 for streaming enforcement)" + ) + + async def test_content_length_header_checked_first(self): + """Verify Content-Length header is checked before streaming.""" + with ( + patch( + "fastmcp.server.auth.ssrf.resolve_hostname", + return_value=["93.184.216.34"], + ), + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_stream = MagicMock() + mock_stream.status_code = 200 + mock_stream.headers = {"content-length": "10240"} # 10KB + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + # aiter_bytes should never be called if Content-Length is checked + mock_stream.aiter_bytes = MagicMock( + side_effect=AssertionError("Should not stream") + ) + + mock_client = AsyncMock() + mock_client.stream = MagicMock(return_value=mock_stream) + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + with pytest.raises(SSRFFetchError, match="too large"): + await ssrf_safe_fetch("https://example.com/api", max_size=5120) diff --git a/tests/utilities/openapi/test_models.py b/tests/utilities/openapi/test_models.py index cc4baadb37..4361635c24 100644 --- a/tests/utilities/openapi/test_models.py +++ b/tests/utilities/openapi/test_models.py @@ -4,8 +4,10 @@ from inline_snapshot import snapshot from fastmcp.utilities.openapi.models import ( + HttpMethod, HTTPRoute, ParameterInfo, + ParameterLocation, RequestBodyInfo, ResponseInfo, ) @@ -51,7 +53,7 @@ def test_parameter_with_all_fields(self): assert param.style == "deepObject" @pytest.mark.parametrize("location", ["path", "query", "header", "cookie"]) - def test_valid_parameter_locations(self, location): + def test_valid_parameter_locations(self, location: ParameterLocation): """Test that all valid parameter locations are accepted.""" param = ParameterInfo( name="test", @@ -286,7 +288,7 @@ def test_route_pre_calculated_fields(self): @pytest.mark.parametrize( "method", ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"] ) - def test_valid_http_methods(self, method): + def test_valid_http_methods(self, method: HttpMethod): """Test that all valid HTTP methods are accepted.""" route = HTTPRoute( path="/test",