/callback`) acts as the `redirect_uri` for the OAuth flow.
@@ -115,10 +120,7 @@ encrypted_storage = FernetEncryptionWrapper(
fernet=Fernet(os.environ["OAUTH_STORAGE_ENCRYPTION_KEY"])
)
-oauth = OAuth(
- mcp_url="https://your-server.fastmcp.app/mcp",
- token_storage=encrypted_storage
-)
+oauth = OAuth(token_storage=encrypted_storage)
async with Client("https://your-server.fastmcp.app/mcp", auth=oauth) as client:
await client.ping()
@@ -129,3 +131,24 @@ You can use any `AsyncKeyValue`-compatible backend from the [key-value library](
When selecting a storage backend, review the [py-key-value documentation](https://github.com/strawgate/py-key-value) to understand the maturity level and limitations of your chosen backend. Some backends may be in preview or have constraints that affect production suitability.
+
+## CIMD Authentication
+
+
+
+Client ID Metadata Documents (CIMD) provide an alternative to Dynamic Client Registration. Instead of registering with each server, your client hosts a static JSON document at an HTTPS URL. That URL becomes your client's identity, and servers can verify who you are through your domain ownership.
+
+```python
+from fastmcp import Client
+from fastmcp.client.auth import OAuth
+
+async with Client(
+ "https://mcp-server.example.com/mcp",
+ auth=OAuth(
+ client_metadata_url="https://myapp.example.com/oauth/client.json",
+ ),
+) as client:
+ await client.ping()
+```
+
+See the [CIMD Authentication](/clients/auth/cimd) page for complete documentation on creating, hosting, and validating CIMD documents.
diff --git a/docs/development/v3-notes/v3-features.mdx b/docs/development/v3-notes/v3-features.mdx
index 2fa032e6eb..05399fce69 100644
--- a/docs/development/v3-notes/v3-features.mdx
+++ b/docs/development/v3-notes/v3-features.mdx
@@ -73,6 +73,53 @@ fastmcp install stdio server.py
The command automatically detects the project directory and generates the appropriate `uv run` invocation, making it easy to integrate FastMCP servers with MCP clients.
+### CIMD (Client ID Metadata Documents)
+
+CIMD provides an alternative to Dynamic Client Registration for OAuth-authenticated MCP servers. Instead of registering with each server dynamically, clients host a static JSON document at an HTTPS URL. That URL becomes the client's `client_id`, and servers verify identity through domain ownership.
+
+**Client usage:**
+
+```python
+from fastmcp import Client
+from fastmcp.client.auth import OAuth
+
+async with Client(
+ "https://mcp-server.example.com/mcp",
+ auth=OAuth(
+ client_metadata_url="https://myapp.example.com/oauth/client.json",
+ ),
+) as client:
+ await client.ping()
+```
+
+The `OAuth` helper now supports deferred binding — `mcp_url` is optional when using `OAuth` with `Client(auth=...)`, since the transport provides the server URL automatically.
+
+**CLI tools for document management:**
+
+```bash
+# Generate a CIMD document
+fastmcp auth cimd create --name "My App" \
+ --redirect-uri "http://localhost:*/callback" \
+ --client-id "https://myapp.example.com/oauth/client.json" \
+ --output client.json
+
+# Validate a hosted document
+fastmcp auth cimd validate https://myapp.example.com/oauth/client.json
+```
+
+**Server-side support:**
+
+CIMD is enabled by default on `OAuthProxy` and its provider subclasses (GitHub, Google, etc.). The server-side implementation includes SSRF-hardened document fetching with DNS pinning, dual redirect URI validation (both CIMD document patterns and proxy patterns must match), HTTP cache-aware revalidation, and `private_key_jwt` assertion validation for clients that need stronger authentication than public client auth.
+
+Key details:
+- CIMD URLs must be HTTPS with a non-root path
+- `token_endpoint_auth_method` limited to `none` or `private_key_jwt` (no shared secrets)
+- `redirect_uris` in CIMD documents support wildcard port patterns (`http://localhost:*/callback`)
+- Servers fetch and cache documents with standard HTTP caching (ETag, Last-Modified, Cache-Control)
+- CIMD is a protocol-level feature — any auth provider implementing the spec can support it
+
+Documentation: [CIMD Authentication](/clients/auth/cimd), [OAuth Proxy CIMD config](/servers/auth/oauth-proxy#cimd-support)
+
### MCP Apps (SDK Compatibility)
Support for [MCP Apps](https://modelcontextprotocol.io/specification/2025-06-18/server/apps) — the spec extension that lets MCP servers deliver interactive UIs via sandboxed iframes. Extension negotiation, typed UI metadata on tools and resources, and the `ui://` resource scheme. No component DSL, renderer, or `FastMCPApp` class yet — those are future phases.
diff --git a/docs/docs.json b/docs/docs.json
index f20343ca1e..b9c1bc4dd6 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -198,6 +198,7 @@
"icon": "key",
"pages": [
"clients/auth/oauth",
+ "clients/auth/cimd",
"clients/auth/bearer"
]
}
diff --git a/docs/patterns/cli.mdx b/docs/patterns/cli.mdx
index 72badb8703..8ccf9d01f2 100644
--- a/docs/patterns/cli.mdx
+++ b/docs/patterns/cli.mdx
@@ -25,6 +25,7 @@ fastmcp --help
| `install` | Install a server in MCP client applications | **Supports:** Local files and fastmcp.json configs. **Deps:** Creates an isolated environment; dependencies must be explicitly specified with `--with` and/or `--with-editable`. With fastmcp.json: Uses configured dependencies |
| `inspect` | Generate a JSON report about a FastMCP server | **Supports:** Local files and fastmcp.json configs. **Deps:** Uses your current environment; you are responsible for ensuring all dependencies are available |
| `project prepare` | Create a persistent uv project from fastmcp.json environment config | **Supports:** fastmcp.json configs only. **Deps:** Creates a uv project directory with all dependencies pre-installed for reuse with `--project` flag |
+| `auth cimd` | Create and validate CIMD documents for OAuth authentication | N/A |
| `version` | Display version information | N/A |
## `fastmcp list`
@@ -750,6 +751,87 @@ The prepare command creates a uv project with:
This is useful when you want to separate environment setup from server execution, such as in deployment scenarios where dependencies are installed once and the server is run multiple times.
+## `fastmcp auth`
+
+
+
+Authentication-related utilities and configuration commands.
+
+### `fastmcp auth cimd create`
+
+Generate a CIMD (Client ID Metadata Document) for hosting. This creates a JSON document that you can host at an HTTPS URL to use as your OAuth client identity.
+
+```bash
+fastmcp auth cimd create --name "My App" --redirect-uri "http://localhost:*/callback"
+```
+
+#### Options
+
+| Option | Flag | Description |
+| ------ | ---- | ----------- |
+| Name | `--name` | **Required.** Human-readable name of the client application |
+| Redirect URI | `--redirect-uri` | **Required.** Allowed redirect URIs (can specify multiple) |
+| Client URI | `--client-uri` | URL of the client's home page |
+| Logo URI | `--logo-uri` | URL of the client's logo image |
+| Scope | `--scope` | Space-separated list of scopes the client may request |
+| Output | `--output`, `-o` | Output file path (default: stdout) |
+| Pretty | `--pretty` | Pretty-print JSON output (default: true) |
+
+#### Example
+
+```bash
+# Generate document to stdout
+fastmcp auth cimd create \
+ --name "My Production App" \
+ --redirect-uri "http://localhost:*/callback" \
+ --redirect-uri "https://myapp.example.com/callback" \
+ --client-uri "https://myapp.example.com" \
+ --scope "read write"
+
+# Save to file
+fastmcp auth cimd create \
+ --name "My App" \
+ --redirect-uri "http://localhost:*/callback" \
+ --output client.json
+```
+
+The generated document includes a placeholder `client_id` that you must update to match the URL where you'll host the document before deploying.
+
+### `fastmcp auth cimd validate`
+
+Validate a hosted CIMD document by fetching it from its URL and checking that it conforms to the CIMD specification.
+
+```bash
+fastmcp auth cimd validate https://myapp.example.com/oauth/client.json
+```
+
+#### Options
+
+| Option | Flag | Description |
+| ------ | ---- | ----------- |
+| Timeout | `--timeout`, `-t` | HTTP request timeout in seconds (default: 10) |
+
+The validator checks:
+
+- The URL is a valid CIMD URL (HTTPS with non-root path)
+- The document is valid JSON and conforms to the CIMD schema
+- The `client_id` field in the document matches the URL
+- No shared-secret authentication methods are used
+
+On success, it displays the document details:
+
+```
+→ Fetching https://myapp.example.com/oauth/client.json...
+✓ Valid CIMD document
+
+Document details:
+ client_id: https://myapp.example.com/oauth/client.json
+ client_name: My App
+ token_endpoint_auth_method: none
+ redirect_uris:
+ • http://localhost:*/callback
+```
+
## `fastmcp version`
Display version information about FastMCP and related components.
diff --git a/docs/servers/auth/oauth-proxy.mdx b/docs/servers/auth/oauth-proxy.mdx
index 3c5b328a76..86a8865f30 100644
--- a/docs/servers/auth/oauth-proxy.mdx
+++ b/docs/servers/auth/oauth-proxy.mdx
@@ -524,6 +524,74 @@ auth = OAuthProxy(
Check your server logs for "Client registered with redirect_uri" messages to identify what URLs your clients use.
+## CIMD Support
+
+
+
+The OAuth proxy supports **Client ID Metadata Documents (CIMD)**, an alternative to Dynamic Client Registration where clients host a static JSON document at an HTTPS URL. Instead of registering dynamically, clients simply provide their CIMD URL as their `client_id`, and the server fetches and validates the metadata.
+
+CIMD clients appear in the consent screen with a verified domain badge, giving users confidence about which application is requesting access. This provides stronger identity verification than DCR, where any client can claim any name.
+
+### How CIMD Works
+
+When a client presents an HTTPS URL as its `client_id` (for example, `https://myapp.example.com/oauth/client.json`), the OAuth proxy recognizes it as a CIMD client and:
+
+1. Fetches the JSON document from that URL
+2. Validates that the document's `client_id` field matches the URL
+3. Extracts client metadata (name, redirect URIs, scopes, etc.)
+4. Stores the client persistently alongside DCR clients
+5. Shows the verified domain in the consent screen
+
+This flow happens transparently. MCP clients that support CIMD simply provide their metadata URL instead of registering, and the OAuth proxy handles the rest.
+
+### CIMD Configuration
+
+CIMD support is enabled by default for `OAuthProxy`.
+
+
+
+ Whether to accept CIMD URLs as client identifiers. When enabled, clients can use HTTPS URLs pointing to metadata documents as their `client_id` instead of registering via DCR.
+
+
+
+### Private Key JWT Authentication
+
+CIMD clients can authenticate using `private_key_jwt` instead of the default `none` authentication method. This provides cryptographic proof of client identity by signing JWT assertions with a private key, while the server verifies using the client's public key from their CIMD document.
+
+To use `private_key_jwt`, the CIMD document must include either a `jwks_uri` (URL to fetch the public key set) or inline `jwks` (the key set directly in the document):
+
+```json
+{
+ "client_id": "https://myapp.example.com/oauth/client.json",
+ "client_name": "My Secure App",
+ "redirect_uris": ["http://localhost:*/callback"],
+ "token_endpoint_auth_method": "private_key_jwt",
+ "jwks_uri": "https://myapp.example.com/.well-known/jwks.json"
+}
+```
+
+The OAuth proxy validates JWT assertions according to RFC 7523, checking the signature, issuer, audience, subject claims, and preventing replay attacks via JTI tracking.
+
+### Security Considerations
+
+CIMD provides several security advantages over DCR:
+
+- **Verified identity**: The domain in the `client_id` URL is verified by HTTPS, so users know which organization is requesting access
+- **No registration required**: Clients don't need to store or manage dynamically-issued credentials
+- **Redirect URI enforcement**: CIMD documents must declare `redirect_uris`, which are enforced by the proxy (wildcard patterns supported)
+- **SSRF protection**: The OAuth proxy blocks fetches to localhost, private IPs, and reserved addresses
+- **Replay prevention**: For `private_key_jwt` clients, JTI claims are tracked to prevent assertion replay
+- **Cache-aware fetching**: CIMD documents are cached according to HTTP cache headers and revalidated when required
+
+To disable CIMD support entirely (for example, to require all clients to register via DCR):
+
+```python
+auth = OAuthProxy(
+ ...,
+ enable_cimd=False,
+)
+```
+
## Security
### Key and Storage Management
diff --git a/docs/servers/auth/oidc-proxy.mdx b/docs/servers/auth/oidc-proxy.mdx
index 6941c7f79f..86661bc16d 100644
--- a/docs/servers/auth/oidc-proxy.mdx
+++ b/docs/servers/auth/oidc-proxy.mdx
@@ -232,6 +232,22 @@ OAuth scopes are configured with `required_scopes` to automatically request the
Dynamic clients created by the proxy will automatically include these scopes in their authorization requests.
+## CIMD Support
+
+
+
+The OIDC proxy inherits full CIMD (Client ID Metadata Document) support from `OAuthProxy`. Clients can use HTTPS URLs as their `client_id` instead of registering dynamically, and the proxy will fetch and validate their metadata document.
+
+See the [OAuth Proxy CIMD documentation](/servers/auth/oauth-proxy#cimd-support) for complete details on how CIMD works, including private key JWT authentication and security considerations.
+
+The CIMD-related parameters available on `OIDCProxy` are:
+
+
+
+ Whether to accept CIMD URLs as client identifiers.
+
+
+
## Production Configuration
For production deployments, load sensitive credentials from environment variables:
diff --git a/examples/auth/github_oauth/client.py b/examples/auth/github_oauth/client.py
index 5f1f39bb2f..a7ab5c47ee 100644
--- a/examples/auth/github_oauth/client.py
+++ b/examples/auth/github_oauth/client.py
@@ -8,14 +8,20 @@
import asyncio
-from fastmcp.client import Client
+from fastmcp.client import Client, OAuth
-SERVER_URL = "http://127.0.0.1:8000/mcp"
+SERVER_URL = "http://localhost:8000/mcp"
async def main():
try:
- async with Client(SERVER_URL, auth="oauth") as client:
+ async with Client(
+ SERVER_URL,
+ auth=OAuth(
+ # Replace with your own CIMD document URL
+ client_metadata_url="https://www.jlowin.dev/mcp-client.json",
+ ),
+ ) as client:
assert await client.ping()
print("✅ Successfully authenticated!")
diff --git a/loq.toml b/loq.toml
index 4c6d7e28ba..bbfe81827b 100644
--- a/loq.toml
+++ b/loq.toml
@@ -76,7 +76,7 @@ max_lines = 1584
[[rules]]
path = "src/fastmcp/server/auth/oauth_proxy/proxy.py"
-max_lines = 1600
+max_lines = 1740
[[rules]]
path = "tests/server/test_dependencies.py"
diff --git a/src/fastmcp/cli/auth.py b/src/fastmcp/cli/auth.py
new file mode 100644
index 0000000000..4ea401b041
--- /dev/null
+++ b/src/fastmcp/cli/auth.py
@@ -0,0 +1,13 @@
+"""Authentication-related CLI commands."""
+
+import cyclopts
+
+from fastmcp.cli.cimd import cimd_app
+
+auth_app = cyclopts.App(
+ name="auth",
+ help="Authentication-related utilities and configuration.",
+)
+
+# Nest CIMD commands under auth
+auth_app.command(cimd_app)
diff --git a/src/fastmcp/cli/cimd.py b/src/fastmcp/cli/cimd.py
new file mode 100644
index 0000000000..d2def490cf
--- /dev/null
+++ b/src/fastmcp/cli/cimd.py
@@ -0,0 +1,218 @@
+"""CIMD (Client ID Metadata Document) CLI commands."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import sys
+from pathlib import Path
+from typing import Annotated
+
+import cyclopts
+from rich.console import Console
+
+from fastmcp.server.auth.cimd import (
+ CIMDFetcher,
+ CIMDFetchError,
+ CIMDValidationError,
+)
+from fastmcp.utilities.logging import get_logger
+
+logger = get_logger("cli.cimd")
+console = Console()
+
+
+cimd_app = cyclopts.App(
+ name="cimd",
+ help="CIMD (Client ID Metadata Document) utilities for OAuth authentication.",
+)
+
+
+@cimd_app.command(name="create")
+def create_command(
+ *,
+ name: Annotated[
+ str,
+ cyclopts.Parameter(help="Human-readable name of the client application"),
+ ],
+ redirect_uri: Annotated[
+ list[str],
+ cyclopts.Parameter(
+ name=["--redirect-uri", "-r"],
+ help="Allowed redirect URIs (can specify multiple)",
+ ),
+ ],
+ client_id: Annotated[
+ str | None,
+ cyclopts.Parameter(
+ name="--client-id",
+ help="The URL where this document will be hosted (sets client_id directly)",
+ ),
+ ] = None,
+ client_uri: Annotated[
+ str | None,
+ cyclopts.Parameter(
+ name="--client-uri",
+ help="URL of the client's home page",
+ ),
+ ] = None,
+ logo_uri: Annotated[
+ str | None,
+ cyclopts.Parameter(
+ name="--logo-uri",
+ help="URL of the client's logo image",
+ ),
+ ] = None,
+ scope: Annotated[
+ str | None,
+ cyclopts.Parameter(
+ name="--scope",
+ help="Space-separated list of scopes the client may request",
+ ),
+ ] = None,
+ output: Annotated[
+ str | None,
+ cyclopts.Parameter(
+ name=["--output", "-o"],
+ help="Output file path (default: stdout)",
+ ),
+ ] = None,
+ pretty: Annotated[
+ bool,
+ cyclopts.Parameter(
+ help="Pretty-print JSON output",
+ ),
+ ] = True,
+) -> None:
+ """Generate a CIMD document for hosting.
+
+ Create a Client ID Metadata Document that you can host at an HTTPS URL.
+ The URL where you host this document becomes your client_id.
+
+ Example:
+ fastmcp cimd create --name "My App" -r "http://localhost:*/callback"
+
+ After creating the document, host it at an HTTPS URL with a non-root path,
+ for example: https://myapp.example.com/oauth/client.json
+ """
+ # Build the document
+ doc = {
+ "client_id": client_id or "https://YOUR-DOMAIN.com/path/to/client.json",
+ "client_name": name,
+ "redirect_uris": redirect_uri,
+ "token_endpoint_auth_method": "none",
+ "grant_types": ["authorization_code"],
+ "response_types": ["code"],
+ }
+
+ # Add optional fields
+ if client_uri:
+ doc["client_uri"] = client_uri
+ if logo_uri:
+ doc["logo_uri"] = logo_uri
+ if scope:
+ doc["scope"] = scope
+
+ # Format output
+ json_output = json.dumps(doc, indent=2) if pretty else json.dumps(doc)
+
+ # Write output
+ if output:
+ output_path = Path(output).expanduser().resolve()
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(output_path, "w") as f:
+ f.write(json_output)
+ f.write("\n")
+ console.print(f"[green]✓[/green] CIMD document written to {output}")
+ if not client_id:
+ console.print(
+ "\n[yellow]Important:[/yellow] client_id is a placeholder. Update it to the URL where you will host this document, or re-run with --client-id."
+ )
+ else:
+ print(json_output)
+ if not client_id:
+ # Print instructions to stderr so they don't interfere with piping
+ stderr_console = Console(stderr=True)
+ stderr_console.print(
+ "\n[yellow]Important:[/yellow] client_id is a placeholder."
+ " Update it to the URL where you will host this document,"
+ " or re-run with --client-id."
+ )
+
+
+@cimd_app.command(name="validate")
+def validate_command(
+ url: Annotated[
+ str,
+ cyclopts.Parameter(help="URL of the CIMD document to validate"),
+ ],
+ *,
+ timeout: Annotated[
+ float,
+ cyclopts.Parameter(
+ name=["--timeout", "-t"],
+ help="HTTP request timeout in seconds",
+ ),
+ ] = 10.0,
+) -> None:
+ """Validate a hosted CIMD document.
+
+ Fetches the document from the given URL and validates:
+ - URL is valid CIMD URL (HTTPS, non-root path)
+ - Document is valid JSON
+ - Document conforms to CIMD schema
+ - client_id in document matches the URL
+
+ Example:
+ fastmcp cimd validate https://myapp.example.com/oauth/client.json
+ """
+
+ async def _validate() -> bool:
+ fetcher = CIMDFetcher(timeout=timeout)
+
+ # Check URL format first
+ if not fetcher.is_cimd_client_id(url):
+ console.print(f"[red]✗[/red] Invalid CIMD URL: {url}")
+ console.print()
+ console.print("CIMD URLs must:")
+ console.print(" • Use HTTPS (not HTTP)")
+ console.print(" • Have a non-root path (e.g., /client.json, not just /)")
+ return False
+
+ console.print(f"[blue]→[/blue] Fetching {url}...")
+
+ try:
+ doc = await fetcher.fetch(url)
+ except CIMDFetchError as e:
+ console.print(f"[red]✗[/red] Failed to fetch document: {e}")
+ return False
+ except CIMDValidationError as e:
+ console.print(f"[red]✗[/red] Validation error: {e}")
+ return False
+
+ # Success - show document details
+ console.print("[green]✓[/green] Valid CIMD document")
+ console.print()
+ console.print("[bold]Document details:[/bold]")
+ console.print(f" client_id: {doc.client_id}")
+ console.print(f" client_name: {doc.client_name or '(not set)'}")
+ console.print(f" token_endpoint_auth_method: {doc.token_endpoint_auth_method}")
+
+ if doc.redirect_uris:
+ console.print(" redirect_uris:")
+ for uri in doc.redirect_uris:
+ console.print(f" • {uri}")
+ else:
+ console.print(" redirect_uris: (none)")
+
+ if doc.scope:
+ console.print(f" scope: {doc.scope}")
+
+ if doc.client_uri:
+ console.print(f" client_uri: {doc.client_uri}")
+
+ return True
+
+ success = asyncio.run(_validate())
+ if not success:
+ sys.exit(1)
diff --git a/src/fastmcp/cli/cli.py b/src/fastmcp/cli/cli.py
index 147d9dc2d8..5b9c24e4a8 100644
--- a/src/fastmcp/cli/cli.py
+++ b/src/fastmcp/cli/cli.py
@@ -19,6 +19,7 @@
import fastmcp
from fastmcp.cli import run as run_module
+from fastmcp.cli.auth import auth_app
from fastmcp.cli.client import call_command, discover_command, list_command
from fastmcp.cli.generate import generate_cli_command
from fastmcp.cli.install import install_app
@@ -960,6 +961,9 @@ async def prepare(
app.command(discover_command, name="discover")
app.command(generate_cli_command, name="generate-cli")
+# Add auth subcommand group (includes CIMD commands)
+app.command(auth_app)
+
if __name__ == "__main__":
app()
diff --git a/src/fastmcp/client/auth/oauth.py b/src/fastmcp/client/auth/oauth.py
index 393844d073..9fc90b4e89 100644
--- a/src/fastmcp/client/auth/oauth.py
+++ b/src/fastmcp/client/auth/oauth.py
@@ -143,56 +143,82 @@ class OAuth(OAuthClientProvider):
a browser for user authorization and running a local callback server.
"""
+ _bound: bool
+
def __init__(
self,
- mcp_url: str,
+ mcp_url: str | None = None,
scopes: str | list[str] | None = None,
client_name: str = "FastMCP Client",
token_storage: AsyncKeyValue | None = None,
additional_client_metadata: dict[str, Any] | None = None,
callback_port: int | None = None,
httpx_client_factory: McpHttpClientFactory | None = None,
+ client_metadata_url: str | None = None,
):
"""
Initialize OAuth client provider for an MCP server.
Args:
- mcp_url: Full URL to the MCP endpoint (e.g. "http://host/mcp/sse/")
+ mcp_url: Full URL to the MCP endpoint (e.g. "http://host/mcp/sse/").
+ Optional when OAuth is passed to Client(auth=...), which provides
+ the URL automatically from the transport.
scopes: OAuth scopes to request. Can be a
space-separated string or a list of strings.
client_name: Name for this client during registration
token_storage: An AsyncKeyValue-compatible token store, tokens are stored in memory if not provided
additional_client_metadata: Extra fields for OAuthClientMetadata
callback_port: Fixed port for OAuth callback (default: random available port)
+ client_metadata_url: A CIMD (Client ID Metadata Document) URL. When
+ provided, this URL is used as the client_id instead of performing
+ Dynamic Client Registration. Must be an HTTPS URL with a non-root
+ path (e.g. "https://myapp.example.com/oauth/client.json").
+ """
+ # Store config for deferred binding if mcp_url not yet known
+ self._scopes = scopes
+ self._client_name = client_name
+ self._token_storage = token_storage
+ self._additional_client_metadata = additional_client_metadata
+ self._callback_port = callback_port
+ self._client_metadata_url = client_metadata_url
+ self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient
+ self._bound = False
+
+ if mcp_url is not None:
+ self._bind(mcp_url)
+
+ def _bind(self, mcp_url: str) -> None:
+ """Bind this OAuth provider to a specific MCP server URL.
+
+ Called automatically when mcp_url is provided to __init__, or by the
+ transport when OAuth is used without an explicit URL.
"""
- # Normalize the MCP URL (strip trailing slashes for consistency)
+ if self._bound:
+ return
+
mcp_url = mcp_url.rstrip("/")
- # Setup OAuth client
- self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient
- self.redirect_port = callback_port or find_available_port()
+ self.redirect_port = self._callback_port or find_available_port()
redirect_uri = f"http://localhost:{self.redirect_port}/callback"
scopes_str: str
- if isinstance(scopes, list):
- scopes_str = " ".join(scopes)
- elif scopes is not None:
- scopes_str = str(scopes)
+ if isinstance(self._scopes, list):
+ scopes_str = " ".join(self._scopes)
+ elif self._scopes is not None:
+ scopes_str = str(self._scopes)
else:
scopes_str = ""
client_metadata = OAuthClientMetadata(
- client_name=client_name,
+ client_name=self._client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
- # token_endpoint_auth_method="client_secret_post",
scope=scopes_str,
- **(additional_client_metadata or {}),
+ **(self._additional_client_metadata or {}),
)
- # Create server-specific token storage
- token_storage = token_storage or MemoryStore()
+ token_storage = self._token_storage or MemoryStore()
if isinstance(token_storage, MemoryStore):
from warnings import warn
@@ -204,23 +230,23 @@ def __init__(
stacklevel=2,
)
- # Use full URL for token storage to properly separate tokens per MCP endpoint
self.token_storage_adapter: TokenStorageAdapter = TokenStorageAdapter(
async_key_value=token_storage, server_url=mcp_url
)
- # Store full MCP URL for use in callback_handler display
self.mcp_url = mcp_url
- # Initialize parent class with full URL for proper OAuth metadata discovery
super().__init__(
server_url=mcp_url,
client_metadata=client_metadata,
storage=self.token_storage_adapter,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
+ client_metadata_url=self._client_metadata_url,
)
+ self._bound = True
+
async def _initialize(self) -> None:
"""Load stored tokens and client info, properly setting token expiry."""
# Call parent's _initialize to load tokens and client info
@@ -298,6 +324,11 @@ async def async_auth_flow(
If the OAuth flow fails due to invalid/stale client credentials,
clears the cache and retries once with fresh registration.
"""
+ if not self._bound:
+ raise RuntimeError(
+ "OAuth provider has no server URL. Either pass mcp_url to OAuth() "
+ "or use it with Client(auth=...) which provides the URL automatically."
+ )
try:
# First attempt with potentially cached credentials
async with aclosing(super().async_auth_flow(request)) as gen:
diff --git a/src/fastmcp/client/transports/http.py b/src/fastmcp/client/transports/http.py
index 89ad8fc621..83dbb7cc86 100644
--- a/src/fastmcp/client/transports/http.py
+++ b/src/fastmcp/client/transports/http.py
@@ -76,11 +76,17 @@ def __init__(
self._get_session_id_cb: Callable[[], str | None] | None = None
def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
+ resolved: httpx.Auth | None
if auth == "oauth":
- auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
+ resolved = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
+ elif isinstance(auth, OAuth):
+ auth._bind(self.url)
+ resolved = auth
elif isinstance(auth, str):
- auth = BearerAuth(auth)
- self.auth = auth
+ resolved = BearerAuth(auth)
+ else:
+ resolved = auth
+ self.auth: httpx.Auth | None = resolved
@contextlib.asynccontextmanager
async def connect_session(
diff --git a/src/fastmcp/client/transports/sse.py b/src/fastmcp/client/transports/sse.py
index ec932e6d2d..45db01beef 100644
--- a/src/fastmcp/client/transports/sse.py
+++ b/src/fastmcp/client/transports/sse.py
@@ -48,11 +48,17 @@ def __init__(
self.sse_read_timeout = normalize_timeout_to_timedelta(sse_read_timeout)
def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
+ resolved: httpx.Auth | None
if auth == "oauth":
- auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
+ resolved = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
+ elif isinstance(auth, OAuth):
+ auth._bind(self.url)
+ resolved = auth
elif isinstance(auth, str):
- auth = BearerAuth(auth)
- self.auth = auth
+ resolved = BearerAuth(auth)
+ else:
+ resolved = auth
+ self.auth: httpx.Auth | None = resolved
@contextlib.asynccontextmanager
async def connect_session(
diff --git a/src/fastmcp/server/auth/auth.py b/src/fastmcp/server/auth/auth.py
index 9a804f05d7..b8b8b1f8c8 100644
--- a/src/fastmcp/server/auth/auth.py
+++ b/src/fastmcp/server/auth/auth.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import json
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
from mcp.server.auth.handlers.token import TokenErrorResponse
@@ -9,7 +9,13 @@
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
-from mcp.server.auth.middleware.client_auth import ClientAuthenticator
+from mcp.server.auth.middleware.client_auth import (
+ AuthenticationError,
+ ClientAuthenticator,
+)
+from mcp.server.auth.middleware.client_auth import (
+ ClientAuthenticator as _SDKClientAuthenticator,
+)
from mcp.server.auth.provider import (
AccessToken as _SDKAccessToken,
)
@@ -30,13 +36,18 @@
ClientRegistrationOptions,
RevocationOptions,
)
+from mcp.shared.auth import OAuthClientInformationFull
from pydantic import AnyHttpUrl, Field
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import Request
from starlette.routing import Route
from fastmcp.utilities.logging import get_logger
+if TYPE_CHECKING:
+ from fastmcp.server.auth.cimd import CIMDClientManager
+
logger = get_logger(__name__)
@@ -108,6 +119,91 @@ async def handle(self, request: Any):
return response
+# Expected assertion type for private_key_jwt
+JWT_BEARER_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
+
+
+class PrivateKeyJWTClientAuthenticator(_SDKClientAuthenticator):
+ """Client authenticator with private_key_jwt support for CIMD clients.
+
+ Extends the SDK's ClientAuthenticator to add support for the `private_key_jwt`
+ authentication method per RFC 7523. This is required for CIMD (Client ID Metadata
+ Document) clients that use asymmetric keys for authentication.
+
+ The authenticator:
+ 1. Delegates to SDK for standard methods (client_secret_basic, client_secret_post, none)
+ 2. Adds private_key_jwt handling for CIMD clients
+ 3. Validates JWT assertions against client's JWKS
+ """
+
+ def __init__(
+ self,
+ provider: OAuthAuthorizationServerProvider[Any, Any, Any],
+ cimd_manager: CIMDClientManager,
+ token_endpoint_url: str,
+ ):
+ """Initialize the authenticator.
+
+ Args:
+ provider: OAuth provider for client lookups
+ cimd_manager: CIMD manager for private_key_jwt validation
+ token_endpoint_url: Token endpoint URL for audience validation
+ """
+ super().__init__(provider)
+ self._cimd_manager = cimd_manager
+ self._token_endpoint_url = token_endpoint_url
+
+ async def authenticate_request(
+ self, request: Request
+ ) -> OAuthClientInformationFull:
+ """Authenticate a client from an HTTP request.
+
+ Extends SDK authentication to support private_key_jwt for CIMD clients.
+ Delegates to SDK for client_secret_basic (Authorization header) and
+ client_secret_post (form body) authentication.
+ """
+ form_data = await request.form()
+ client_id = form_data.get("client_id")
+
+ # If client_id is not in form data, delegate to SDK
+ # This handles client_secret_basic which sends credentials in Authorization header
+ if not client_id:
+ return await super().authenticate_request(request)
+
+ client = await self.provider.get_client(str(client_id))
+ if not client:
+ raise AuthenticationError("Invalid client_id")
+
+ # Handle private_key_jwt authentication for CIMD clients
+ if client.token_endpoint_auth_method == "private_key_jwt":
+ # Validate assertion parameters
+ assertion_type = form_data.get("client_assertion_type")
+ assertion = form_data.get("client_assertion")
+
+ if assertion_type != JWT_BEARER_ASSERTION_TYPE:
+ raise AuthenticationError(
+ f"Invalid client_assertion_type: expected {JWT_BEARER_ASSERTION_TYPE}"
+ )
+
+ if not assertion or not isinstance(assertion, str):
+ raise AuthenticationError("Missing client_assertion")
+
+ # Validate the JWT assertion using CIMD manager
+ try:
+ await self._cimd_manager.validate_private_key_jwt(
+ assertion=assertion,
+ client=client,
+ token_endpoint=self._token_endpoint_url,
+ )
+ except ValueError as e:
+ raise AuthenticationError(f"Invalid client assertion: {e}") from e
+
+ return client
+
+ # Delegate to SDK for other authentication methods
+ return await super().authenticate_request(request)
+
+
class AuthProvider(TokenVerifierProtocol):
"""Base class for all FastMCP authentication providers.
diff --git a/src/fastmcp/server/auth/cimd.py b/src/fastmcp/server/auth/cimd.py
new file mode 100644
index 0000000000..49aa6687cb
--- /dev/null
+++ b/src/fastmcp/server/auth/cimd.py
@@ -0,0 +1,651 @@
+"""CIMD (Client ID Metadata Document) support for FastMCP.
+
+.. warning::
+ **Beta Feature**: CIMD support is currently in beta. The API may change
+ in future releases. Please report any issues you encounter.
+
+CIMD is a simpler alternative to Dynamic Client Registration where clients
+host a static JSON document at an HTTPS URL, and that URL becomes their
+client_id. See the IETF draft: draft-parecki-oauth-client-id-metadata-document
+
+This module provides:
+- CIMDDocument: Pydantic model for CIMD document validation
+- CIMDFetcher: Fetch and validate CIMD documents with SSRF protection
+- CIMDClientManager: Manages CIMD client operations
+"""
+
+from __future__ import annotations
+
+import fnmatch
+import json
+import time
+from typing import TYPE_CHECKING, Any, Literal
+from urllib.parse import urlparse
+
+from pydantic import AnyHttpUrl, BaseModel, Field, field_validator
+
+from fastmcp.server.auth.ssrf import (
+ SSRFError,
+ SSRFFetchError,
+ ssrf_safe_fetch,
+ validate_url,
+)
+from fastmcp.utilities.logging import get_logger
+
+if TYPE_CHECKING:
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+logger = get_logger(__name__)
+
+
+class CIMDDocument(BaseModel):
+ """CIMD document per draft-parecki-oauth-client-id-metadata-document.
+
+ The client metadata document is a JSON document containing OAuth client
+ metadata. The client_id property MUST match the URL where this document
+ is hosted.
+
+ Key constraint: token_endpoint_auth_method MUST NOT use shared secrets
+ (client_secret_post, client_secret_basic, client_secret_jwt).
+
+ redirect_uris is required and must contain at least one entry.
+ """
+
+ client_id: AnyHttpUrl = Field(
+ ...,
+ description="Must match the URL where this document is hosted",
+ )
+ client_name: str | None = Field(
+ default=None,
+ description="Human-readable name of the client",
+ )
+ client_uri: AnyHttpUrl | None = Field(
+ default=None,
+ description="URL of the client's home page",
+ )
+ logo_uri: AnyHttpUrl | None = Field(
+ default=None,
+ description="URL of the client's logo image",
+ )
+ redirect_uris: list[str] = Field(
+ ...,
+ description="Array of allowed redirect URIs (may include wildcards like http://localhost:*/callback)",
+ )
+ token_endpoint_auth_method: Literal["none", "private_key_jwt"] = Field(
+ default="none",
+ description="Authentication method for token endpoint (no shared secrets allowed)",
+ )
+ grant_types: list[str] = Field(
+ default_factory=lambda: ["authorization_code"],
+ description="OAuth grant types the client will use",
+ )
+ response_types: list[str] = Field(
+ default_factory=lambda: ["code"],
+ description="OAuth response types the client will use",
+ )
+ scope: str | None = Field(
+ default=None,
+ description="Space-separated list of scopes the client may request",
+ )
+ contacts: list[str] | None = Field(
+ default=None,
+ description="Contact information for the client developer",
+ )
+ tos_uri: AnyHttpUrl | None = Field(
+ default=None,
+ description="URL of the client's terms of service",
+ )
+ policy_uri: AnyHttpUrl | None = Field(
+ default=None,
+ description="URL of the client's privacy policy",
+ )
+ jwks_uri: AnyHttpUrl | None = Field(
+ default=None,
+ description="URL of the client's JSON Web Key Set (for private_key_jwt)",
+ )
+ jwks: dict[str, Any] | None = Field(
+ default=None,
+ description="Client's JSON Web Key Set (for private_key_jwt)",
+ )
+ software_id: str | None = Field(
+ default=None,
+ description="Unique identifier for the client software",
+ )
+ software_version: str | None = Field(
+ default=None,
+ description="Version of the client software",
+ )
+
+ @field_validator("token_endpoint_auth_method")
+ @classmethod
+ def validate_auth_method(cls, v: str) -> str:
+ """Ensure no shared-secret auth methods are used."""
+ forbidden = {"client_secret_post", "client_secret_basic", "client_secret_jwt"}
+ if v in forbidden:
+ raise ValueError(
+ f"CIMD documents cannot use shared-secret auth methods: {v}. "
+ "Use 'none' or 'private_key_jwt' instead."
+ )
+ return v
+
+ @field_validator("redirect_uris")
+ @classmethod
+ def validate_redirect_uris(cls, v: list[str]) -> list[str]:
+ """Ensure redirect_uris is non-empty and each entry is a valid URI."""
+ if not v:
+ raise ValueError("CIMD documents must include at least one redirect_uri")
+ for uri in v:
+ if not uri or not uri.strip():
+ raise ValueError("CIMD redirect_uris must be non-empty strings")
+ parsed = urlparse(uri)
+ if not parsed.scheme:
+ raise ValueError(
+ f"CIMD redirect_uri must have a scheme (e.g. http:// or https://): {uri!r}"
+ )
+ if not parsed.netloc and not uri.startswith("urn:"):
+ raise ValueError(f"CIMD redirect_uri must have a host: {uri!r}")
+ return v
+
+
+class CIMDValidationError(Exception):
+ """Raised when CIMD document validation fails."""
+
+
+class CIMDFetchError(Exception):
+ """Raised when CIMD document fetching fails."""
+
+
+class CIMDFetcher:
+ """Fetch and validate CIMD documents with SSRF protection.
+
+ Delegates HTTP fetching to ssrf_safe_fetch which provides DNS pinning,
+ IP validation, size limits, and timeout enforcement. Documents are cached
+ with a simple TTL.
+ """
+
+ # Maximum response size (bytes)
+ MAX_RESPONSE_SIZE = 5120 # 5KB
+ # Default cache TTL (seconds)
+ DEFAULT_CACHE_TTL_SECONDS = 3600
+
+ def __init__(
+ self,
+ timeout: float = 10.0,
+ ):
+ """Initialize the CIMD fetcher.
+
+ Args:
+ timeout: HTTP request timeout in seconds (default 10.0)
+ """
+ self.timeout = timeout
+ self._cache: dict[str, tuple[CIMDDocument, float]] = {}
+
+ def is_cimd_client_id(self, client_id: str) -> bool:
+ """Check if a client_id looks like a CIMD URL.
+
+ CIMD URLs must be HTTPS with a host and non-root path.
+ """
+ if not client_id:
+ return False
+ try:
+ parsed = urlparse(client_id)
+ return (
+ parsed.scheme == "https"
+ and bool(parsed.netloc)
+ and parsed.path not in ("", "/")
+ )
+ except (ValueError, AttributeError):
+ return False
+
+ async def fetch(self, client_id_url: str) -> CIMDDocument:
+ """Fetch and validate a CIMD document with SSRF protection.
+
+ Uses ssrf_safe_fetch for the HTTP layer, which provides:
+ - HTTPS only, DNS resolution with IP validation
+ - DNS pinning (connects to validated IP directly)
+ - Blocks private/loopback/link-local/multicast IPs
+ - Response size limit and timeout enforcement
+ - Redirects disabled
+
+ Args:
+ client_id_url: The URL to fetch (also the expected client_id)
+
+ Returns:
+ Validated CIMDDocument
+
+ Raises:
+ CIMDValidationError: If document is invalid or URL blocked
+ CIMDFetchError: If document cannot be fetched
+ """
+ cached = self._cache.get(client_id_url)
+ if cached is not None:
+ doc, expires_at = cached
+ if time.time() < expires_at:
+ return doc
+
+ try:
+ content = await ssrf_safe_fetch(
+ client_id_url,
+ require_path=True,
+ max_size=self.MAX_RESPONSE_SIZE,
+ timeout=self.timeout,
+ overall_timeout=30.0,
+ )
+ except SSRFError as e:
+ raise CIMDValidationError(str(e)) from e
+ except SSRFFetchError as e:
+ raise CIMDFetchError(str(e)) from e
+
+ try:
+ data = json.loads(content)
+ except json.JSONDecodeError as e:
+ raise CIMDValidationError(f"CIMD document is not valid JSON: {e}") from e
+
+ try:
+ doc = CIMDDocument.model_validate(data)
+ except Exception as e:
+ raise CIMDValidationError(f"Invalid CIMD document: {e}") from e
+
+ if str(doc.client_id).rstrip("/") != client_id_url.rstrip("/"):
+ raise CIMDValidationError(
+ f"CIMD client_id mismatch: document says '{doc.client_id}' "
+ f"but was fetched from '{client_id_url}'"
+ )
+
+ # Validate jwks_uri if present (SSRF check for JWKS endpoint)
+ if doc.jwks_uri:
+ jwks_uri_str = str(doc.jwks_uri)
+ try:
+ await validate_url(jwks_uri_str)
+ except SSRFError as e:
+ raise CIMDValidationError(
+ f"CIMD jwks_uri failed SSRF validation: {e}"
+ ) from e
+
+ logger.info(
+ "CIMD document fetched and validated: %s (client_name=%s)",
+ client_id_url,
+ doc.client_name,
+ )
+
+ self._cache[client_id_url] = (doc, time.time() + self.DEFAULT_CACHE_TTL_SECONDS)
+ return doc
+
+ def validate_redirect_uri(self, doc: CIMDDocument, redirect_uri: str) -> bool:
+ """Validate that a redirect_uri is allowed by the CIMD document.
+
+ Args:
+ doc: The CIMD document
+ redirect_uri: The redirect URI to validate
+
+ Returns:
+ True if valid, False otherwise
+ """
+ if not doc.redirect_uris:
+ # No redirect_uris specified - reject all
+ return False
+
+ # Normalize for comparison
+ redirect_uri = redirect_uri.rstrip("/")
+
+ for allowed in doc.redirect_uris:
+ allowed_str = allowed.rstrip("/")
+ if redirect_uri == allowed_str:
+ return True
+
+ # Check for wildcard port matching (http://localhost:*/callback)
+ if "*" in allowed_str:
+ if fnmatch.fnmatch(redirect_uri, allowed_str):
+ return True
+
+ return False
+
+
+class CIMDAssertionValidator:
+ """Validates JWT assertions for private_key_jwt CIMD clients.
+
+ Implements RFC 7523 (JSON Web Token (JWT) Profile for OAuth 2.0 Client
+ Authentication and Authorization Grants) for CIMD client authentication.
+
+ JTI replay protection uses TTL-based caching to ensure proper security:
+ - JTIs are cached with expiration matching the JWT's exp claim
+ - Expired JTIs are automatically cleaned up
+ - Maximum assertion lifetime is enforced (5 minutes)
+ """
+
+ # Maximum allowed assertion lifetime in seconds (RFC 7523 recommends short-lived)
+ MAX_ASSERTION_LIFETIME = 300 # 5 minutes
+
+ def __init__(self):
+ # JTI cache: maps jti -> expiration timestamp
+ self._jti_cache: dict[str, float] = {}
+ self._jti_cache_max_size = 10000
+ self._last_cleanup = time.monotonic()
+ self._cleanup_interval = 60 # Cleanup every 60 seconds
+ # Cache JWTVerifier per jwks_uri so JWKS keys are not re-fetched
+ # on every token exchange
+ self._verifier_cache: dict[str, JWTVerifier] = {}
+ self._verifier_cache_max_size = 100
+ self.logger = get_logger(__name__)
+
+ def _cleanup_expired_jtis(self) -> None:
+ """Remove expired JTIs from cache."""
+ now = time.time()
+ expired = [jti for jti, exp in self._jti_cache.items() if exp < now]
+ for jti in expired:
+ del self._jti_cache[jti]
+ if expired:
+ self.logger.debug("Cleaned up %d expired JTIs from cache", len(expired))
+
+ def _maybe_cleanup(self) -> None:
+ """Periodically cleanup expired JTIs to prevent unbounded growth."""
+ now = time.monotonic()
+ if now - self._last_cleanup > self._cleanup_interval:
+ self._cleanup_expired_jtis()
+ self._last_cleanup = now
+
+ async def validate_assertion(
+ self,
+ assertion: str,
+ client_id: str,
+ token_endpoint: str,
+ cimd_doc: CIMDDocument,
+ ) -> bool:
+ """Validate JWT assertion from client.
+
+ Args:
+ assertion: The JWT assertion string
+ client_id: Expected client_id (must match iss and sub claims)
+ token_endpoint: Token endpoint URL (must match aud claim)
+ cimd_doc: CIMD document containing JWKS for key verification
+
+ Returns:
+ True if valid
+
+ Raises:
+ ValueError: If validation fails
+ """
+ from fastmcp.server.auth.providers.jwt import JWTVerifier as _JWTVerifier
+
+ # Periodic cleanup of expired JTIs
+ self._maybe_cleanup()
+
+ # 1. Validate CIMD document has key material and get/create verifier
+ if cimd_doc.jwks_uri:
+ jwks_uri_str = str(cimd_doc.jwks_uri)
+ cache_key = f"{jwks_uri_str}|{client_id}|{token_endpoint}"
+ verifier = self._verifier_cache.get(cache_key)
+ if verifier is None:
+ verifier = _JWTVerifier(
+ jwks_uri=jwks_uri_str,
+ issuer=client_id,
+ audience=token_endpoint,
+ ssrf_safe=True,
+ )
+ if len(self._verifier_cache) >= self._verifier_cache_max_size:
+ oldest_key = next(iter(self._verifier_cache))
+ del self._verifier_cache[oldest_key]
+ self._verifier_cache[cache_key] = verifier
+ elif cimd_doc.jwks:
+ # Inline JWKS — no caching since the key is embedded
+ public_key = self._extract_public_key_from_jwks(assertion, cimd_doc.jwks)
+ verifier = _JWTVerifier(
+ public_key=public_key,
+ issuer=client_id,
+ audience=token_endpoint,
+ )
+ else:
+ raise ValueError(
+ "CIMD document must have jwks_uri or jwks for private_key_jwt"
+ )
+
+ # 2. Verify JWT using JWTVerifier (handles signature, exp, iss, aud)
+ access_token = await verifier.load_access_token(assertion)
+ if not access_token:
+ raise ValueError("Invalid JWT assertion")
+
+ claims = access_token.claims
+
+ # 3. Validate assertion lifetime (exp and iat)
+ now = time.time()
+ exp = claims.get("exp")
+ iat = claims.get("iat")
+
+ if not exp:
+ raise ValueError("Assertion must include exp claim")
+
+ # Validate exp is in the future (with small clock skew tolerance)
+ if exp < now - 30: # 30 second clock skew tolerance
+ raise ValueError("Assertion has expired")
+
+ # If iat is present, validate it and check assertion lifetime
+ if iat:
+ if iat > now + 30: # 30 second clock skew tolerance
+ raise ValueError("Assertion iat is in the future")
+ if exp - iat > self.MAX_ASSERTION_LIFETIME:
+ raise ValueError(
+ f"Assertion lifetime too long: {exp - iat}s (max {self.MAX_ASSERTION_LIFETIME}s)"
+ )
+ else:
+ # No iat, enforce max lifetime from now
+ if exp > now + self.MAX_ASSERTION_LIFETIME:
+ raise ValueError(
+ f"Assertion exp too far in future (max {self.MAX_ASSERTION_LIFETIME}s)"
+ )
+
+ # 4. Additional RFC 7523 validation: sub claim must equal client_id
+ if claims.get("sub") != client_id:
+ raise ValueError(f"Assertion sub claim must be {client_id}")
+
+ # 5. Check jti for replay attacks (RFC 7523 requirement)
+ jti = claims.get("jti")
+ if not jti:
+ raise ValueError("Assertion must include jti claim")
+
+ # Check if JTI was already used (and hasn't expired from cache)
+ if jti in self._jti_cache:
+ cached_exp = self._jti_cache[jti]
+ if cached_exp > now: # Still valid in cache
+ raise ValueError(f"Assertion replay detected: jti {jti} already used")
+ # Expired in cache, can be reused (clean it up)
+ del self._jti_cache[jti]
+
+ # Add to cache with expiration time
+ # Use the assertion's exp claim so it stays cached until it would expire anyway
+ self._jti_cache[jti] = exp
+
+ # Emergency size limit (shouldn't hit with proper TTL cleanup)
+ if len(self._jti_cache) > self._jti_cache_max_size:
+ self._cleanup_expired_jtis()
+ # If still over limit after cleanup, reject to prevent DoS
+ if len(self._jti_cache) > self._jti_cache_max_size:
+ self.logger.warning(
+ "JTI cache at max capacity (%d), possible attack",
+ self._jti_cache_max_size,
+ )
+ raise ValueError("Server overloaded, please retry")
+
+ self.logger.debug(
+ "JWT assertion validated successfully for client %s", client_id
+ )
+ return True
+
+ def _extract_public_key_from_jwks(self, token: str, jwks: dict) -> str:
+ """Extract public key from inline JWKS.
+
+ Args:
+ token: JWT token to extract kid from
+ jwks: JWKS document containing keys
+
+ Returns:
+ PEM-encoded public key
+
+ Raises:
+ ValueError: If key cannot be found or extracted
+ """
+ import base64
+ import json
+
+ from authlib.jose import JsonWebKey
+
+ # Extract kid from token header
+ try:
+ header_b64 = token.split(".")[0]
+ header_b64 += "=" * (4 - len(header_b64) % 4) # Add padding
+ header = json.loads(base64.urlsafe_b64decode(header_b64))
+ kid = header.get("kid")
+ except Exception as e:
+ raise ValueError(f"Failed to extract key ID from token: {e}") from e
+
+ # Find matching key in JWKS
+ keys = jwks.get("keys", [])
+ if not keys:
+ raise ValueError("JWKS document contains no keys")
+
+ matching_key = None
+ for key in keys:
+ if kid and key.get("kid") == kid:
+ matching_key = key
+ break
+
+ if not matching_key:
+ # If no kid match, try first key as fallback
+ if len(keys) == 1:
+ matching_key = keys[0]
+ self.logger.warning(
+ "No matching kid in JWKS, using single available key"
+ )
+ else:
+ raise ValueError(f"No matching key found for kid={kid} in JWKS")
+
+ # Convert JWK to PEM
+ try:
+ jwk = JsonWebKey.import_key(matching_key)
+ return jwk.as_pem().decode("utf-8")
+ except Exception as e:
+ raise ValueError(f"Failed to convert JWK to PEM: {e}") from e
+
+
+class CIMDClientManager:
+ """Manages all CIMD client operations for OAuth proxy.
+
+ This class encapsulates:
+ - CIMD client detection
+ - Document fetching and validation
+ - Synthetic OAuth client creation
+ - Private key JWT assertion validation
+
+ This allows the OAuth proxy to delegate all CIMD-specific logic to a
+ single, focused manager class.
+ """
+
+ def __init__(
+ self,
+ enable_cimd: bool = True,
+ default_scope: str = "",
+ allowed_redirect_uri_patterns: list[str] | None = None,
+ ):
+ """Initialize CIMD client manager.
+
+ Args:
+ enable_cimd: Whether CIMD support is enabled
+ default_scope: Default scope for CIMD clients if not specified in document
+ allowed_redirect_uri_patterns: Allowed redirect URI patterns (proxy's config)
+ """
+ self.enabled = enable_cimd
+ self.default_scope = default_scope
+ self.allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
+
+ self._fetcher = CIMDFetcher()
+ self._assertion_validator = CIMDAssertionValidator()
+ self.logger = get_logger(__name__)
+
+ def is_cimd_client_id(self, client_id: str) -> bool:
+ """Check if client_id is a CIMD URL.
+
+ Args:
+ client_id: Client ID to check
+
+ Returns:
+ True if client_id is an HTTPS URL (CIMD format)
+ """
+ return self.enabled and self._fetcher.is_cimd_client_id(client_id)
+
+ async def get_client(self, client_id_url: str):
+ """Fetch CIMD document and create synthetic OAuth client.
+
+ Args:
+ client_id_url: HTTPS URL pointing to CIMD document
+
+ Returns:
+ OAuthProxyClient with CIMD document attached, or None if fetch fails
+
+ Note:
+ Return type is left untyped to avoid circular import with oauth_proxy.
+ Returns OAuthProxyClient instance or None.
+ """
+ if not self.enabled:
+ return None
+
+ try:
+ cimd_doc = await self._fetcher.fetch(client_id_url)
+ except (CIMDFetchError, CIMDValidationError) as e:
+ self.logger.warning("CIMD fetch failed for %s: %s", client_id_url, e)
+ return None
+
+ # Import here to avoid circular dependency
+ from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient
+
+ # Create synthetic client from CIMD document.
+ # Keep CIMD redirect_uris as strings on the document itself so wildcard
+ # patterns like http://localhost:*/callback remain valid.
+ redirect_uris = None
+ client = ProxyDCRClient(
+ client_id=client_id_url,
+ client_secret=None,
+ redirect_uris=redirect_uris,
+ grant_types=cimd_doc.grant_types,
+ scope=cimd_doc.scope or self.default_scope,
+ token_endpoint_auth_method=cimd_doc.token_endpoint_auth_method,
+ allowed_redirect_uri_patterns=self.allowed_redirect_uri_patterns,
+ client_name=cimd_doc.client_name,
+ cimd_document=cimd_doc,
+ cimd_fetched_at=time.time(),
+ )
+
+ self.logger.debug(
+ "CIMD client resolved: %s (name=%s)",
+ client_id_url,
+ cimd_doc.client_name,
+ )
+ return client
+
+ async def validate_private_key_jwt(
+ self,
+ assertion: str,
+ client, # OAuthProxyClient, untyped to avoid circular import
+ token_endpoint: str,
+ ) -> bool:
+ """Validate JWT assertion for private_key_jwt auth.
+
+ Args:
+ assertion: JWT assertion string from client
+ client: OAuth proxy client (must have cimd_document)
+ token_endpoint: Token endpoint URL for aud validation
+
+ Returns:
+ True if assertion is valid
+
+ Raises:
+ ValueError: If client doesn't have CIMD document or validation fails
+ """
+ if not hasattr(client, "cimd_document") or not client.cimd_document:
+ raise ValueError("Client must have CIMD document for private_key_jwt")
+
+ cimd_doc = client.cimd_document
+ if cimd_doc.token_endpoint_auth_method != "private_key_jwt":
+ raise ValueError("CIMD document must specify private_key_jwt auth method")
+
+ return await self._assertion_validator.validate_assertion(
+ assertion, client.client_id, token_endpoint, cimd_doc
+ )
diff --git a/src/fastmcp/server/auth/oauth_proxy/consent.py b/src/fastmcp/server/auth/oauth_proxy/consent.py
index 6f47a5da77..87b63d88f2 100644
--- a/src/fastmcp/server/auth/oauth_proxy/consent.py
+++ b/src/fastmcp/server/auth/oauth_proxy/consent.py
@@ -21,6 +21,7 @@
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
+from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient
from fastmcp.server.auth.oauth_proxy.ui import create_consent_html
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.ui import create_secure_html_response
@@ -245,10 +246,17 @@ async def _show_consent_page(
txn["csrf_token"] = csrf_token
txn["csrf_expires_at"] = csrf_expires_at
- # Load client to get client_name if available
+ # Load client to get client_name and CIMD info if available
client = await self.get_client(txn["client_id"])
client_name = getattr(client, "client_name", None) if client else None
+ # Detect CIMD clients for verified domain badge
+ is_cimd_client = False
+ cimd_domain: str | None = None
+ if isinstance(client, ProxyDCRClient) and client.cimd_document is not None:
+ is_cimd_client = True
+ cimd_domain = urlparse(txn["client_id"]).hostname
+
# Extract server metadata from app state
fastmcp = getattr(request.app.state, "fastmcp_server", None)
@@ -273,6 +281,8 @@ async def _show_consent_page(
server_icon_url=server_icon_url,
server_website_url=server_website_url,
csp_policy=self._consent_csp_policy,
+ is_cimd_client=is_cimd_client,
+ cimd_domain=cimd_domain,
)
response = create_secure_html_response(html)
# Store CSRF in cookie with short lifetime
diff --git a/src/fastmcp/server/auth/oauth_proxy/models.py b/src/fastmcp/server/auth/oauth_proxy/models.py
index 53c939f2a1..575c846baa 100644
--- a/src/fastmcp/server/auth/oauth_proxy/models.py
+++ b/src/fastmcp/server/auth/oauth_proxy/models.py
@@ -11,7 +11,11 @@
from mcp.shared.auth import InvalidRedirectUriError, OAuthClientInformationFull
from pydantic import AnyUrl, BaseModel, Field
-from fastmcp.server.auth.redirect_validation import validate_redirect_uri
+from fastmcp.server.auth.cimd import CIMDDocument
+from fastmcp.server.auth.redirect_validation import (
+ matches_allowed_pattern,
+ validate_redirect_uri,
+)
# -------------------------------------------------------------------------
# Constants
@@ -156,28 +160,77 @@ class ProxyDCRClient(OAuthClientInformationFull):
allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
client_name: str | None = Field(default=None)
+ cimd_document: CIMDDocument | None = Field(default=None)
+ cimd_fetched_at: float | None = Field(default=None)
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
- """Validate redirect URI against allowed patterns.
+ """Validate redirect URI against proxy patterns and optionally CIMD redirect_uris.
- Since we're acting as a proxy and clients register dynamically,
- we validate their redirect URIs against configurable patterns.
- This is essential for cached token scenarios where the client may
- reconnect with a different port.
+ For CIMD clients: validates against BOTH the CIMD document's redirect_uris
+ AND the proxy's allowed patterns (if configured). Both must pass.
+
+ For DCR clients: validates against proxy patterns first, falling back to
+ base validation (registered redirect_uris) if patterns don't match.
"""
+ if redirect_uri is None and self.cimd_document is not None:
+ cimd_redirect_uris = self.cimd_document.redirect_uris
+ if len(cimd_redirect_uris) == 1:
+ candidate = cimd_redirect_uris[0]
+ if "*" in candidate:
+ raise InvalidRedirectUriError(
+ "redirect_uri must be specified when CIMD redirect_uris uses wildcards."
+ )
+ try:
+ return AnyUrl(candidate)
+ except Exception as e:
+ raise InvalidRedirectUriError(
+ f"Invalid CIMD redirect_uri: {e}"
+ ) from e
+
+ raise InvalidRedirectUriError(
+ "redirect_uri must be specified when CIMD lists multiple redirect_uris."
+ )
+
if redirect_uri is not None:
- # Validate against allowed patterns
- if validate_redirect_uri(
+ cimd_redirect_uris = (
+ self.cimd_document.redirect_uris if self.cimd_document else None
+ )
+
+ if cimd_redirect_uris:
+ uri_str = str(redirect_uri)
+ cimd_match = any(
+ matches_allowed_pattern(uri_str, pattern)
+ for pattern in cimd_redirect_uris
+ )
+ if not cimd_match:
+ raise InvalidRedirectUriError(
+ f"Redirect URI '{redirect_uri}' does not match CIMD redirect_uris."
+ )
+
+ if self.allowed_redirect_uri_patterns:
+ if not validate_redirect_uri(
+ redirect_uri=redirect_uri,
+ allowed_patterns=self.allowed_redirect_uri_patterns,
+ ):
+ raise InvalidRedirectUriError(
+ f"Redirect URI '{redirect_uri}' does not match allowed patterns."
+ )
+
+ return redirect_uri
+
+ pattern_matches = validate_redirect_uri(
redirect_uri=redirect_uri,
allowed_patterns=self.allowed_redirect_uri_patterns,
- ):
+ )
+
+ if pattern_matches:
return redirect_uri
- # If patterns are explicitly configured then reject non-matching URIs
+ # Patterns configured but didn't match
if self.allowed_redirect_uri_patterns:
raise InvalidRedirectUriError(
f"Redirect URI '{redirect_uri}' does not match allowed patterns."
)
- # If no redirect_uri provided, use default behavior
+ # No redirect_uri provided or no patterns configured — use base validation
return super().validate_redirect_uri(redirect_uri)
diff --git a/src/fastmcp/server/auth/oauth_proxy/proxy.py b/src/fastmcp/server/auth/oauth_proxy/proxy.py
index e1a24720f5..27773ef25c 100644
--- a/src/fastmcp/server/auth/oauth_proxy/proxy.py
+++ b/src/fastmcp/server/auth/oauth_proxy/proxy.py
@@ -32,6 +32,7 @@
from key_value.aio.adapters.pydantic import PydanticAdapter
from key_value.aio.protocols import AsyncKeyValue
from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
+from mcp.server.auth.handlers.metadata import MetadataHandler
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
@@ -40,6 +41,7 @@
RefreshToken,
TokenError,
)
+from mcp.server.auth.routes import build_metadata, cors_middleware
from mcp.server.auth.settings import (
ClientRegistrationOptions,
RevocationOptions,
@@ -52,7 +54,13 @@
from typing_extensions import override
from fastmcp import settings
-from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
+from fastmcp.server.auth.auth import (
+ OAuthProvider,
+ PrivateKeyJWTClientAuthenticator,
+ TokenHandler,
+ TokenVerifier,
+)
+from fastmcp.server.auth.cimd import CIMDClientManager
from fastmcp.server.auth.handlers.authorize import AuthorizationHandler
from fastmcp.server.auth.jwt_issuer import (
JWTIssuer,
@@ -248,6 +256,8 @@ def __init__(
consent_csp_policy: str | None = None,
# Token expiry fallback
fallback_access_token_expiry_seconds: int | None = None,
+ # CIMD (Client ID Metadata Document) support
+ enable_cimd: bool = True,
):
"""Initialize the OAuth proxy provider.
@@ -302,6 +312,9 @@ def __init__(
defaults: 1 hour if a refresh token is available (since we can refresh),
or 1 year if no refresh token (for API-key-style tokens like GitHub OAuth Apps).
Set explicitly to override these defaults.
+ enable_cimd: Enable CIMD (Client ID Metadata Document) support for URL-based
+ client IDs. When True, clients can authenticate using HTTPS URLs as client
+ IDs, with metadata fetched from the URL. Supports private_key_jwt auth.
"""
# Always enable DCR since we implement it locally for MCP clients
@@ -484,6 +497,15 @@ def __init__(
# Use the provided token validator
self._token_validator: TokenVerifier = token_verifier
+ # CIMD (Client ID Metadata Document) support
+ self._cimd_manager: CIMDClientManager | None = None
+ if enable_cimd:
+ self._cimd_manager = CIMDClientManager(
+ enable_cimd=True,
+ default_scope=self._default_scope_str,
+ allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
+ )
+
logger.debug(
"Initialized OAuth proxy provider with upstream server %s",
self._upstream_authorization_endpoint,
@@ -559,15 +581,43 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
provided to the DCR client during registration, not the upstream client ID.
For unregistered clients, returns None (which will raise an error in the SDK).
+ CIMD clients (URL-based client IDs) are looked up and cached automatically.
"""
# Load from storage
- if not (client := await self._client_store.get(key=client_id)):
- return None
+ client = await self._client_store.get(key=client_id)
- if client.allowed_redirect_uri_patterns is None:
- client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
+ if client is not None:
+ if client.allowed_redirect_uri_patterns is None:
+ client.allowed_redirect_uri_patterns = (
+ self._allowed_client_redirect_uris
+ )
+
+ # Refresh CIMD clients using HTTP cache-aware fetcher.
+ if self._cimd_manager is not None and client.cimd_document is not None:
+ try:
+ refreshed = await self._cimd_manager.get_client(client_id)
+ if refreshed is not None:
+ await self._client_store.put(key=client_id, value=refreshed)
+ return refreshed
+ except Exception as e:
+ logger.debug(
+ "CIMD refresh failed for %s, using cached client: %s",
+ client_id,
+ e,
+ )
- return client
+ return client
+
+ # Client not in storage — try CIMD lookup for URL-based client IDs
+ if self._cimd_manager is not None and self._cimd_manager.is_cimd_client_id(
+ client_id
+ ):
+ cimd_client = await self._cimd_manager.get_client(client_id)
+ if cimd_client is not None:
+ await self._client_store.put(key=client_id, value=cimd_client)
+ return cimd_client
+
+ return None
@override
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
@@ -1437,6 +1487,61 @@ def get_routes(
methods=["GET", "POST"],
)
)
+ elif (
+ self._cimd_manager is not None
+ and isinstance(route, Route)
+ and route.path == "/token"
+ and route.methods is not None
+ and "POST" in route.methods
+ ):
+ # Replace the token endpoint authenticator with one that supports
+ # private_key_jwt for CIMD clients
+ token_endpoint_url = f"{self.base_url}/token"
+ cimd_authenticator = PrivateKeyJWTClientAuthenticator(
+ provider=self,
+ cimd_manager=self._cimd_manager,
+ token_endpoint_url=token_endpoint_url,
+ )
+ token_handler = TokenHandler(
+ provider=self, client_authenticator=cimd_authenticator
+ )
+ custom_routes.append(
+ Route(
+ path="/token",
+ endpoint=cors_middleware(
+ token_handler.handle, ["POST", "OPTIONS"]
+ ),
+ methods=["POST", "OPTIONS"],
+ )
+ )
+ elif (
+ self._cimd_manager is not None
+ and isinstance(route, Route)
+ and route.path.startswith("/.well-known/oauth-authorization-server")
+ ):
+ client_registration_options = (
+ self.client_registration_options or ClientRegistrationOptions()
+ )
+ revocation_options = self.revocation_options or RevocationOptions()
+ metadata = build_metadata(
+ self.base_url, # ty: ignore[invalid-argument-type]
+ self.service_documentation_url,
+ client_registration_options,
+ revocation_options,
+ )
+ metadata.client_id_metadata_document_supported = True
+ handler = MetadataHandler(metadata)
+ methods = route.methods or ["GET", "OPTIONS"]
+
+ custom_routes.append(
+ Route(
+ path=route.path,
+ endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
+ methods=methods,
+ name=route.name,
+ include_in_schema=route.include_in_schema,
+ )
+ )
else:
# Keep all other standard OAuth routes unchanged
custom_routes.append(route)
diff --git a/src/fastmcp/server/auth/oauth_proxy/ui.py b/src/fastmcp/server/auth/oauth_proxy/ui.py
index 3bae1a11c1..4cbb3ec2ce 100644
--- a/src/fastmcp/server/auth/oauth_proxy/ui.py
+++ b/src/fastmcp/server/auth/oauth_proxy/ui.py
@@ -32,6 +32,8 @@ def create_consent_html(
server_website_url: str | None = None,
client_website_url: str | None = None,
csp_policy: str | None = None,
+ is_cimd_client: bool = False,
+ cimd_domain: str | None = None,
) -> str:
"""Create a styled HTML consent page for OAuth authorization requests.
@@ -60,6 +62,17 @@ def create_consent_html(
"""
+ # Build CIMD verified domain badge if applicable
+ cimd_badge = ""
+ if is_cimd_client and cimd_domain:
+ cimd_domain_escaped = html_module.escape(cimd_domain)
+ cimd_badge = f"""
+
+ ✓
+ Verified domain: {cimd_domain_escaped}
+
+ """
+
# Build redirect URI section (yellow box, centered)
redirect_uri_escaped = html_module.escape(redirect_uri)
redirect_section = f"""
@@ -144,6 +157,7 @@ def create_consent_html(
{create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
Application Access Request
{intro_box}
+ {cimd_badge}
{redirect_section}
{advanced_details}
{form}
@@ -152,6 +166,23 @@ def create_consent_html(
"""
# Additional styles needed for this page
+ cimd_badge_styles = """
+ .cimd-badge {
+ background: #ecfdf5;
+ border: 1px solid #6ee7b7;
+ border-radius: 8px;
+ padding: 8px 16px;
+ margin-bottom: 16px;
+ font-size: 14px;
+ color: #065f46;
+ text-align: center;
+ }
+ .cimd-check {
+ color: #059669;
+ font-weight: bold;
+ margin-right: 4px;
+ }
+ """
additional_styles = (
INFO_BOX_STYLES
+ REDIRECT_SECTION_STYLES
@@ -159,6 +190,7 @@ def create_consent_html(
+ DETAIL_BOX_STYLES
+ BUTTON_STYLES
+ TOOLTIP_STYLES
+ + cimd_badge_styles
)
# Determine CSP policy to use
diff --git a/src/fastmcp/server/auth/oidc_proxy.py b/src/fastmcp/server/auth/oidc_proxy.py
index 1bcdef4e43..d89ac07561 100644
--- a/src/fastmcp/server/auth/oidc_proxy.py
+++ b/src/fastmcp/server/auth/oidc_proxy.py
@@ -228,6 +228,8 @@ def __init__(
extra_token_params: dict[str, str] | None = None,
# Token expiry fallback
fallback_access_token_expiry_seconds: int | None = None,
+ # CIMD configuration
+ enable_cimd: bool = True,
) -> None:
"""Initialize the OIDC proxy provider.
@@ -278,6 +280,9 @@ def __init__(
doesn't return `expires_in` in the token response. If not set, uses smart
defaults: 1 hour if a refresh token is available (since we can refresh),
or 1 year if no refresh token (for API-key-style tokens like GitHub OAuth Apps).
+ enable_cimd: Whether to enable CIMD (Client ID Metadata Document) client support.
+ When True, clients can use their metadata document URL as client_id instead of
+ Dynamic Client Registration. Default is True.
"""
if not config_url:
raise ValueError("Missing required config URL")
@@ -351,6 +356,7 @@ def __init__(
"require_authorization_consent": require_authorization_consent,
"consent_csp_policy": consent_csp_policy,
"fallback_access_token_expiry_seconds": fallback_access_token_expiry_seconds,
+ "enable_cimd": enable_cimd,
}
if redirect_path:
diff --git a/src/fastmcp/server/auth/providers/jwt.py b/src/fastmcp/server/auth/providers/jwt.py
index fd01c2f2cb..828b9238f7 100644
--- a/src/fastmcp/server/auth/providers/jwt.py
+++ b/src/fastmcp/server/auth/providers/jwt.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import json
import time
from dataclasses import dataclass
from typing import Any, cast
@@ -15,6 +16,7 @@
from typing_extensions import TypedDict
from fastmcp.server.auth import AccessToken, TokenVerifier
+from fastmcp.server.auth.ssrf import SSRFError, SSRFFetchError, ssrf_safe_fetch
from fastmcp.utilities.auth import decode_jwt_header, parse_scopes
from fastmcp.utilities.logging import get_logger
@@ -165,6 +167,7 @@ def __init__(
algorithm: str | None = None,
required_scopes: list[str] | None = None,
base_url: AnyHttpUrl | str | None = None,
+ ssrf_safe: bool = False,
):
"""
Initialize a JWTVerifier configured to validate JWTs using either a static key or a JWKS endpoint.
@@ -177,6 +180,10 @@ def __init__(
algorithm: JWT signing algorithm to accept (default: "RS256"). Supported: HS256/384/512, RS256/384/512, ES256/384/512, PS256/384/512.
required_scopes: Scopes that must be present in validated tokens.
base_url: Base URL passed to the parent TokenVerifier.
+ ssrf_safe: If True, JWKS fetches use SSRF protection (HTTPS-only,
+ public IPs, DNS pinning). Enable when the JWKS URI comes from
+ untrusted input (e.g. CIMD documents). Defaults to False so
+ operator-configured JWKS URIs (including localhost) work normally.
Raises:
ValueError: If neither or both of `public_key` and `jwks_uri` are provided, or if `algorithm` is unsupported.
@@ -220,6 +227,7 @@ def __init__(
self.audience = audience
self.public_key = public_key
self.jwks_uri = jwks_uri
+ self.ssrf_safe = ssrf_safe
self.jwt = JsonWebToken([self.algorithm])
self.logger = get_logger(__name__)
@@ -239,11 +247,11 @@ async def _get_verification_key(self, token: str) -> str:
kid = header.get("kid")
return await self._get_jwks_key(kid)
- except Exception as e:
+ except (ValueError, KeyError, IndexError, json.JSONDecodeError) as e:
raise ValueError(f"Failed to extract key ID from token: {e}") from e
async def _get_jwks_key(self, kid: str | None) -> str:
- """Fetch key from JWKS with simple caching."""
+ """Fetch key from JWKS with simple caching and SSRF protection."""
if not self.jwks_uri:
raise ValueError("JWKS URI not configured")
@@ -257,12 +265,9 @@ async def _get_jwks_key(self, kid: str | None) -> str:
# If no kid but only one key cached, use it
return next(iter(self._jwks_cache.values()))
- # Fetch JWKS
+ # Fetch JWKS — with SSRF protection when enabled (untrusted URIs)
try:
- async with httpx.AsyncClient() as client:
- response = await client.get(self.jwks_uri)
- response.raise_for_status()
- jwks_data = response.json()
+ jwks_data = await self._fetch_jwks()
# Cache all keys
self._jwks_cache = {}
@@ -298,11 +303,35 @@ async def _get_jwks_key(self, kid: str | None) -> str:
else:
raise ValueError("No keys found in JWKS")
- except httpx.HTTPError as e:
+ except (SSRFError, SSRFFetchError) as e:
+ self.logger.debug("JWKS fetch blocked by SSRF protection: %s", e)
raise ValueError(f"Failed to fetch JWKS: {e}") from e
- except Exception as e:
- self.logger.debug(f"JWKS fetch failed: {e}")
+ except httpx.HTTPError as e:
raise ValueError(f"Failed to fetch JWKS: {e}") from e
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JWKS JSON: {e}") from e
+ except (JoseError, TypeError, KeyError) as e:
+ self.logger.debug("JWKS key processing failed: %s", e)
+ raise ValueError(f"Failed to process JWKS: {e}") from e
+
+ async def _fetch_jwks(self) -> dict[str, Any]:
+ """Fetch JWKS data, using SSRF-safe or standard fetch based on config."""
+ if not self.jwks_uri:
+ raise ValueError("JWKS URI not configured")
+
+ if self.ssrf_safe:
+ content = await ssrf_safe_fetch(
+ self.jwks_uri,
+ max_size=65536,
+ timeout=10.0,
+ overall_timeout=30.0,
+ )
+ return json.loads(content)
+ else:
+ async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
+ response = await client.get(self.jwks_uri)
+ response.raise_for_status()
+ return response.json()
def _extract_scopes(self, claims: dict[str, Any]) -> list[str]:
"""
@@ -435,7 +464,7 @@ async def load_access_token(self, token: str) -> AccessToken | None:
except JoseError:
self.logger.debug("Token validation failed: JWT signature/format invalid")
return None
- except Exception as e:
+ except (ValueError, TypeError, KeyError, AttributeError) as e:
self.logger.debug("Token validation failed: %s", str(e))
return None
diff --git a/src/fastmcp/server/auth/redirect_validation.py b/src/fastmcp/server/auth/redirect_validation.py
index f49958ad56..4d011416fb 100644
--- a/src/fastmcp/server/auth/redirect_validation.py
+++ b/src/fastmcp/server/auth/redirect_validation.py
@@ -1,19 +1,138 @@
-"""Utilities for validating client redirect URIs in OAuth flows."""
+"""Utilities for validating client redirect URIs in OAuth flows.
+
+This module provides secure redirect URI validation with wildcard support,
+protecting against userinfo-based bypass attacks like http://localhost@evil.com.
+"""
import fnmatch
+from urllib.parse import urlparse
from pydantic import AnyUrl
+def _parse_host_port(netloc: str) -> tuple[str | None, str | None]:
+ """Parse host and port from netloc, handling wildcards.
+
+ Args:
+ netloc: The netloc component (e.g., "localhost:8080" or "localhost:*")
+
+ Returns:
+ Tuple of (host, port_str) where port_str may be "*" or a number string
+ """
+ # Handle userinfo (remove it for parsing, but we check separately)
+ if "@" in netloc:
+ netloc = netloc.split("@")[-1]
+
+ # Handle IPv6 addresses [::1]:port
+ if netloc.startswith("["):
+ bracket_end = netloc.find("]")
+ if bracket_end == -1:
+ return netloc, None
+ host = netloc[1:bracket_end]
+ rest = netloc[bracket_end + 1 :]
+ if rest.startswith(":"):
+ return host, rest[1:]
+ return host, None
+
+ # Handle regular host:port
+ if ":" in netloc:
+ host, port = netloc.rsplit(":", 1)
+ return host, port
+
+ return netloc, None
+
+
+def _match_host(uri_host: str | None, pattern_host: str | None) -> bool:
+ """Match host component, supporting *.example.com wildcard patterns.
+
+ Args:
+ uri_host: The host from the URI being validated
+ pattern_host: The host pattern (may start with *.)
+
+ Returns:
+ True if the host matches
+ """
+ if not uri_host or not pattern_host:
+ return uri_host == pattern_host
+
+ # Normalize to lowercase for comparison
+ uri_host = uri_host.lower()
+ pattern_host = pattern_host.lower()
+
+ # Handle *.example.com wildcard subdomain patterns
+ if pattern_host.startswith("*."):
+ suffix = pattern_host[1:] # .example.com
+ # Only match actual subdomains (foo.example.com), NOT the base domain
+ return uri_host.endswith(suffix) and uri_host != pattern_host[2:]
+
+ return uri_host == pattern_host
+
+
+def _match_port(
+ uri_port: str | None,
+ pattern_port: str | None,
+ uri_scheme: str,
+) -> bool:
+ """Match port component, supporting * wildcard for any port.
+
+ Args:
+ uri_port: The port from the URI (None if default, string otherwise)
+ pattern_port: The port from the pattern (None if default, "*" for wildcard)
+ uri_scheme: The URI scheme (http/https) for default port handling
+
+ Returns:
+ True if the port matches
+ """
+ # Wildcard matches any port
+ if pattern_port == "*":
+ return True
+
+ # Normalize None to default ports
+ default_port = "443" if uri_scheme == "https" else "80"
+ uri_effective = uri_port if uri_port else default_port
+ pattern_effective = pattern_port if pattern_port else default_port
+
+ return uri_effective == pattern_effective
+
+
+def _match_path(uri_path: str, pattern_path: str) -> bool:
+ """Match path component using fnmatch for wildcard support.
+
+ Args:
+ uri_path: The path from the URI
+ pattern_path: The path pattern (may contain * wildcards)
+
+ Returns:
+ True if the path matches
+ """
+ # Normalize empty paths to /
+ uri_path = uri_path or "/"
+ pattern_path = pattern_path or "/"
+
+ # Empty or root pattern path matches any path
+ # This makes http://localhost:* match http://localhost:3000/callback
+ if pattern_path == "/":
+ return True
+
+ # Use fnmatch for path wildcards (e.g., /auth/*)
+ return fnmatch.fnmatch(uri_path, pattern_path)
+
+
def matches_allowed_pattern(uri: str, pattern: str) -> bool:
- """Check if a URI matches an allowed pattern with wildcard support.
+ """Securely check if a URI matches an allowed pattern with wildcard support.
- Patterns support * wildcard matching:
+ This function parses both the URI and pattern as URLs, comparing each
+ component separately to prevent bypass attacks like userinfo injection.
+
+ Patterns support wildcards:
- http://localhost:* matches any localhost port
- http://127.0.0.1:* matches any 127.0.0.1 port
- https://*.example.com/* matches any subdomain of example.com
- https://app.example.com/auth/* matches any path under /auth/
+ Security: Rejects URIs with userinfo (user:pass@host) which could bypass
+ naive string matching (e.g., http://localhost@evil.com).
+
Args:
uri: The redirect URI to validate
pattern: The allowed pattern (may contain wildcards)
@@ -21,8 +140,36 @@ def matches_allowed_pattern(uri: str, pattern: str) -> bool:
Returns:
True if the URI matches the pattern
"""
- # Use fnmatch for wildcard matching
- return fnmatch.fnmatch(uri, pattern)
+ try:
+ uri_parsed = urlparse(uri)
+ pattern_parsed = urlparse(pattern)
+ except ValueError:
+ return False
+
+ # SECURITY: Reject URIs with userinfo (user:pass@host)
+ # This prevents bypass attacks like http://localhost@evil.com/callback
+ # which would match http://localhost:* with naive fnmatch
+ if uri_parsed.username is not None or uri_parsed.password is not None:
+ return False
+
+ # Scheme must match exactly
+ if uri_parsed.scheme.lower() != pattern_parsed.scheme.lower():
+ return False
+
+ # Parse host and port manually to handle wildcards
+ uri_host, uri_port = _parse_host_port(uri_parsed.netloc)
+ pattern_host, pattern_port = _parse_host_port(pattern_parsed.netloc)
+
+ # Host must match (with subdomain wildcard support)
+ if not _match_host(uri_host, pattern_host):
+ return False
+
+ # Port must match (with * wildcard support)
+ if not _match_port(uri_port, pattern_port, uri_parsed.scheme.lower()):
+ return False
+
+ # Path must match (with fnmatch wildcards)
+ return _match_path(uri_parsed.path, pattern_parsed.path)
def validate_redirect_uri(
diff --git a/src/fastmcp/server/auth/ssrf.py b/src/fastmcp/server/auth/ssrf.py
new file mode 100644
index 0000000000..8009269c6d
--- /dev/null
+++ b/src/fastmcp/server/auth/ssrf.py
@@ -0,0 +1,307 @@
+"""SSRF-safe HTTP utilities for FastMCP.
+
+This module provides SSRF-protected HTTP fetching with:
+- DNS resolution and IP validation before requests
+- DNS pinning to prevent rebinding TOCTOU attacks
+- Support for both CIMD and JWKS fetches
+"""
+
+from __future__ import annotations
+
+import asyncio
+import ipaddress
+import socket
+import time
+from dataclasses import dataclass
+from urllib.parse import urlparse
+
+import httpx
+
+from fastmcp.utilities.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def format_ip_for_url(ip_str: str) -> str:
+ """Format IP address for use in URL (bracket IPv6 addresses).
+
+ IPv6 addresses must be bracketed in URLs to distinguish the address from
+ the port separator. For example: https://[2001:db8::1]:443/path
+
+ Args:
+ ip_str: IP address string
+
+ Returns:
+ IP string suitable for URL (IPv6 addresses are bracketed)
+ """
+ try:
+ ip = ipaddress.ip_address(ip_str)
+ if isinstance(ip, ipaddress.IPv6Address):
+ return f"[{ip_str}]"
+ return ip_str
+ except ValueError:
+ return ip_str
+
+
+class SSRFError(Exception):
+ """Raised when an SSRF protection check fails."""
+
+
+class SSRFFetchError(Exception):
+ """Raised when SSRF-safe fetch fails."""
+
+
+def is_ip_allowed(ip_str: str) -> bool:
+ """Check if an IP address is allowed (must be globally routable unicast).
+
+ Uses ip.is_global which catches:
+ - Private (10.x, 172.16-31.x, 192.168.x)
+ - Loopback (127.x, ::1)
+ - Link-local (169.254.x, fe80::) - includes AWS metadata!
+ - Reserved, unspecified
+ - RFC6598 Carrier-Grade NAT (100.64.0.0/10) - can point to internal networks
+
+ Additionally blocks multicast addresses (not caught by is_global).
+
+ Args:
+ ip_str: IP address string to check
+
+ Returns:
+ True if the IP is allowed (public unicast internet), False if blocked
+ """
+ try:
+ ip = ipaddress.ip_address(ip_str)
+ except ValueError:
+ return False
+
+ if not ip.is_global:
+ return False
+
+ # Block multicast (not caught by is_global for some ranges)
+ if ip.is_multicast:
+ return False
+
+ # IPv6-specific checks for embedded IPv4 addresses
+ if isinstance(ip, ipaddress.IPv6Address):
+ if ip.ipv4_mapped:
+ return is_ip_allowed(str(ip.ipv4_mapped))
+ if ip.sixtofour:
+ return is_ip_allowed(str(ip.sixtofour))
+ if ip.teredo:
+ server, client = ip.teredo
+ return is_ip_allowed(str(server)) and is_ip_allowed(str(client))
+
+ return True
+
+
+async def resolve_hostname(hostname: str, port: int = 443) -> list[str]:
+ """Resolve hostname to IP addresses using DNS.
+
+ Args:
+ hostname: Hostname to resolve
+ port: Port number (used for getaddrinfo)
+
+ Returns:
+ List of resolved IP addresses
+
+ Raises:
+ SSRFError: If resolution fails
+ """
+ loop = asyncio.get_running_loop()
+ try:
+ infos = await loop.run_in_executor(
+ None,
+ lambda: socket.getaddrinfo(
+ hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ ),
+ )
+ ips = list({info[4][0] for info in infos})
+ if not ips:
+ raise SSRFError(f"DNS resolution returned no addresses for {hostname}")
+ return ips
+ except socket.gaierror as e:
+ raise SSRFError(f"DNS resolution failed for {hostname}: {e}") from e
+
+
+@dataclass
+class ValidatedURL:
+ """A URL that has been validated for SSRF with resolved IPs."""
+
+ original_url: str
+ hostname: str
+ port: int
+ path: str
+ resolved_ips: list[str]
+
+
+async def validate_url(url: str, require_path: bool = False) -> ValidatedURL:
+ """Validate URL for SSRF and resolve to IPs.
+
+ Args:
+ url: URL to validate
+ require_path: If True, require non-root path (for CIMD)
+
+ Returns:
+ ValidatedURL with resolved IPs
+
+ Raises:
+ SSRFError: If URL is invalid or resolves to blocked IPs
+ """
+ try:
+ parsed = urlparse(url)
+ except (ValueError, AttributeError) as e:
+ raise SSRFError(f"Invalid URL: {e}") from e
+
+ if parsed.scheme != "https":
+ raise SSRFError(f"URL must use HTTPS, got: {parsed.scheme}")
+
+ if not parsed.netloc:
+ raise SSRFError("URL must have a host")
+
+ if require_path and parsed.path in ("", "/"):
+ raise SSRFError("URL must have a non-root path")
+
+ hostname = parsed.hostname or parsed.netloc
+ port = parsed.port or 443
+
+ # Resolve and validate IPs
+ resolved_ips = await resolve_hostname(hostname, port)
+
+ blocked = [ip for ip in resolved_ips if not is_ip_allowed(ip)]
+ if blocked:
+ raise SSRFError(
+ f"URL resolves to blocked IP address(es): {blocked}. "
+ f"Private, loopback, link-local, and reserved IPs are not allowed."
+ )
+
+ return ValidatedURL(
+ original_url=url,
+ hostname=hostname,
+ port=port,
+ path=parsed.path + ("?" + parsed.query if parsed.query else ""),
+ resolved_ips=resolved_ips,
+ )
+
+
+async def ssrf_safe_fetch(
+ url: str,
+ *,
+ require_path: bool = False,
+ max_size: int = 5120,
+ timeout: float = 10.0,
+ overall_timeout: float = 30.0,
+) -> bytes:
+ """Fetch URL with comprehensive SSRF protection and DNS pinning.
+
+ Security measures:
+ 1. HTTPS only
+ 2. DNS resolution with IP validation
+ 3. Connects to validated IP directly (DNS pinning prevents rebinding)
+ 4. Response size limit
+ 5. Redirects disabled
+ 6. Overall timeout
+
+ Args:
+ url: URL to fetch
+ require_path: If True, require non-root path
+ max_size: Maximum response size in bytes (default 5KB)
+ timeout: Per-operation timeout in seconds
+ overall_timeout: Overall timeout for entire operation
+
+ Returns:
+ Response body as bytes
+
+ Raises:
+ SSRFError: If SSRF validation fails
+ SSRFFetchError: If fetch fails
+ """
+ start_time = time.monotonic()
+
+ # Validate URL and resolve DNS
+ validated = await validate_url(url, require_path=require_path)
+
+ last_error: Exception | None = None
+
+ for pinned_ip in validated.resolved_ips:
+ elapsed = time.monotonic() - start_time
+ if elapsed > overall_timeout:
+ raise SSRFFetchError(f"Overall timeout exceeded: {url}")
+ remaining = max(1.0, overall_timeout - elapsed)
+
+ pinned_url = (
+ f"https://{format_ip_for_url(pinned_ip)}:{validated.port}{validated.path}"
+ )
+
+ logger.debug(
+ "SSRF-safe fetch: %s -> %s (pinned to %s)",
+ url,
+ pinned_url,
+ pinned_ip,
+ )
+
+ try:
+ # Use httpx with streaming to enforce size limit during download
+ async with (
+ httpx.AsyncClient(
+ timeout=httpx.Timeout(
+ connect=min(timeout, remaining),
+ read=min(timeout, remaining),
+ write=min(timeout, remaining),
+ pool=min(timeout, remaining),
+ ),
+ follow_redirects=False,
+ verify=True,
+ ) as client,
+ client.stream(
+ "GET",
+ pinned_url,
+ headers={"Host": validated.hostname},
+ extensions={"sni_hostname": validated.hostname},
+ ) as response,
+ ):
+ if time.monotonic() - start_time > overall_timeout:
+ raise SSRFFetchError(f"Overall timeout exceeded: {url}")
+
+ if response.status_code != 200:
+ raise SSRFFetchError(f"HTTP {response.status_code} fetching {url}")
+
+ # Check Content-Length header first if available
+ content_length = response.headers.get("content-length")
+ if content_length:
+ try:
+ size = int(content_length)
+ if size > max_size:
+ raise SSRFFetchError(
+ f"Response too large: {size} bytes (max {max_size})"
+ )
+ except ValueError:
+ pass
+
+ # Stream the response and enforce size limit during download
+ chunks = []
+ total = 0
+ async for chunk in response.aiter_bytes():
+ if time.monotonic() - start_time > overall_timeout:
+ raise SSRFFetchError(f"Overall timeout exceeded: {url}")
+ total += len(chunk)
+ if total > max_size:
+ raise SSRFFetchError(
+ f"Response too large: exceeded {max_size} bytes"
+ )
+ chunks.append(chunk)
+
+ return b"".join(chunks)
+
+ except httpx.TimeoutException as e:
+ last_error = e
+ continue
+ except httpx.RequestError as e:
+ last_error = e
+ continue
+
+ if last_error is not None:
+ if isinstance(last_error, httpx.TimeoutException):
+ raise SSRFFetchError(f"Timeout fetching {url}") from last_error
+ raise SSRFFetchError(f"Error fetching {url}: {last_error}") from last_error
+
+ raise SSRFFetchError(f"Error fetching {url}: no resolved IPs succeeded")
diff --git a/tests/cli/test_cimd_cli.py b/tests/cli/test_cimd_cli.py
new file mode 100644
index 0000000000..301c440ede
--- /dev/null
+++ b/tests/cli/test_cimd_cli.py
@@ -0,0 +1,208 @@
+"""Tests for the CIMD CLI commands (create and validate)."""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from pydantic import AnyHttpUrl
+
+from fastmcp.cli.cimd import create_command, validate_command
+from fastmcp.server.auth.cimd import CIMDDocument, CIMDFetchError, CIMDValidationError
+
+
+class TestCIMDCreateCommand:
+ """Tests for `fastmcp auth cimd create`."""
+
+ def test_minimal_output(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert doc["client_name"] == "Test App"
+ assert doc["redirect_uris"] == ["http://localhost:*/callback"]
+ assert doc["token_endpoint_auth_method"] == "none"
+ assert doc["grant_types"] == ["authorization_code"]
+ assert doc["response_types"] == ["code"]
+ # Placeholder client_id
+ assert "YOUR-DOMAIN" in doc["client_id"]
+
+ def test_with_client_id(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ client_id="https://myapp.example.com/client.json",
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert doc["client_id"] == "https://myapp.example.com/client.json"
+
+ def test_with_output_file(self, tmp_path):
+ output_file = tmp_path / "client.json"
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ client_id="https://example.com/client.json",
+ output=str(output_file),
+ )
+ doc = json.loads(output_file.read_text())
+ assert doc["client_id"] == "https://example.com/client.json"
+ assert doc["client_name"] == "Test App"
+
+ def test_relative_path_resolved(self, tmp_path, monkeypatch):
+ """Relative paths should be resolved against cwd."""
+ monkeypatch.chdir(tmp_path)
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ output="./subdir/client.json",
+ )
+ resolved = tmp_path / "subdir" / "client.json"
+ assert resolved.exists()
+ doc = json.loads(resolved.read_text())
+ assert doc["client_name"] == "Test App"
+
+ def test_with_scope(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ scope="read write",
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert doc["scope"] == "read write"
+
+ def test_with_client_uri(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ client_uri="https://example.com",
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert doc["client_uri"] == "https://example.com"
+
+ def test_with_logo_uri(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ logo_uri="https://example.com/logo.png",
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert doc["logo_uri"] == "https://example.com/logo.png"
+
+ def test_multiple_redirect_uris(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=[
+ "http://localhost:*/callback",
+ "https://myapp.example.com/callback",
+ ],
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert len(doc["redirect_uris"]) == 2
+
+ def test_no_pretty(self, capsys: pytest.CaptureFixture[str]):
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ pretty=False,
+ )
+ output = capsys.readouterr().out.strip()
+ # Compact JSON has no newlines within the object
+ assert "\n" not in output
+ doc = json.loads(output)
+ assert doc["client_name"] == "Test App"
+
+ def test_placeholder_warning_on_stderr(self, capsys: pytest.CaptureFixture[str]):
+ """When outputting to stdout with no --client-id, warning goes to stderr."""
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ )
+ captured = capsys.readouterr()
+ # stdout has valid JSON
+ json.loads(captured.out)
+ # stderr has the warning (Rich Console writes to stderr)
+ assert "placeholder" in captured.err
+
+ def test_no_warning_with_client_id(self, capsys: pytest.CaptureFixture[str]):
+ """No placeholder warning when --client-id is provided."""
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ client_id="https://example.com/client.json",
+ )
+ captured = capsys.readouterr()
+ assert "placeholder" not in captured.err
+
+ def test_optional_fields_omitted_when_none(
+ self, capsys: pytest.CaptureFixture[str]
+ ):
+ """Optional fields like scope, client_uri, logo_uri are omitted if not given."""
+ create_command(
+ name="Test App",
+ redirect_uri=["http://localhost:*/callback"],
+ )
+ doc = json.loads(capsys.readouterr().out)
+ assert "scope" not in doc
+ assert "client_uri" not in doc
+ assert "logo_uri" not in doc
+
+
+class TestCIMDValidateCommand:
+ """Tests for `fastmcp auth cimd validate`."""
+
+ def test_invalid_url_format(self, capsys: pytest.CaptureFixture[str]):
+ with pytest.raises(SystemExit, match="1"):
+ validate_command("http://insecure.com/client.json")
+ captured = capsys.readouterr()
+ assert "Invalid CIMD URL" in captured.out
+
+ def test_root_path_rejected(self, capsys: pytest.CaptureFixture[str]):
+ with pytest.raises(SystemExit, match="1"):
+ validate_command("https://example.com/")
+ captured = capsys.readouterr()
+ assert "Invalid CIMD URL" in captured.out
+
+ def test_success(self, capsys: pytest.CaptureFixture[str]):
+ mock_doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://myapp.example.com/client.json"),
+ client_name="Test App",
+ redirect_uris=["http://localhost:*/callback"],
+ token_endpoint_auth_method="none",
+ grant_types=["authorization_code"],
+ response_types=["code"],
+ )
+ with patch.object(CIMDDocument, "__init__", return_value=None):
+ pass
+ mock_fetch = AsyncMock(return_value=mock_doc)
+ with patch(
+ "fastmcp.cli.cimd.CIMDFetcher.fetch",
+ mock_fetch,
+ ):
+ validate_command("https://myapp.example.com/client.json")
+ captured = capsys.readouterr()
+ assert "Valid CIMD document" in captured.out
+ assert "Test App" in captured.out
+
+ def test_fetch_error(self, capsys: pytest.CaptureFixture[str]):
+ mock_fetch = AsyncMock(side_effect=CIMDFetchError("Connection refused"))
+ with patch(
+ "fastmcp.cli.cimd.CIMDFetcher.fetch",
+ mock_fetch,
+ ):
+ with pytest.raises(SystemExit, match="1"):
+ validate_command("https://myapp.example.com/client.json")
+ captured = capsys.readouterr()
+ assert "Failed to fetch" in captured.out
+
+ def test_validation_error(self, capsys: pytest.CaptureFixture[str]):
+ mock_fetch = AsyncMock(side_effect=CIMDValidationError("client_id mismatch"))
+ with patch(
+ "fastmcp.cli.cimd.CIMDFetcher.fetch",
+ mock_fetch,
+ ):
+ with pytest.raises(SystemExit, match="1"):
+ validate_command("https://myapp.example.com/client.json")
+ captured = capsys.readouterr()
+ assert "Validation error" in captured.out
diff --git a/tests/client/auth/test_oauth_cimd.py b/tests/client/auth/test_oauth_cimd.py
new file mode 100644
index 0000000000..04818af605
--- /dev/null
+++ b/tests/client/auth/test_oauth_cimd.py
@@ -0,0 +1,164 @@
+"""Tests for CIMD (Client ID Metadata Document) support in the OAuth client."""
+
+from __future__ import annotations
+
+import warnings
+
+import httpx
+import pytest
+
+from fastmcp.client.auth import OAuth
+from fastmcp.client.transports import StreamableHttpTransport
+from fastmcp.client.transports.sse import SSETransport
+
+VALID_CIMD_URL = "https://myapp.example.com/oauth/client.json"
+MCP_SERVER_URL = "https://mcp-server.example.com/mcp"
+
+
+class TestOAuthClientMetadataURL:
+ """Tests for the client_metadata_url parameter on OAuth."""
+
+ def test_stored_on_instance(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert oauth._client_metadata_url == VALID_CIMD_URL
+
+ def test_none_by_default(self):
+ oauth = OAuth()
+ assert oauth._client_metadata_url is None
+
+ def test_passed_to_parent_on_bind(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ oauth._bind(MCP_SERVER_URL)
+ assert oauth.context.client_metadata_url == VALID_CIMD_URL
+
+ def test_none_metadata_url_on_parent(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth = OAuth(mcp_url=MCP_SERVER_URL)
+ assert oauth.context.client_metadata_url is None
+
+ def test_unbound_when_no_mcp_url(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert oauth._bound is False
+
+ def test_bound_when_mcp_url_provided(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth = OAuth(
+ mcp_url=MCP_SERVER_URL,
+ client_metadata_url=VALID_CIMD_URL,
+ )
+ assert oauth._bound is True
+
+ def test_invalid_cimd_url_rejected(self):
+ """CIMD URLs must be HTTPS with a non-root path."""
+ with pytest.raises(ValueError, match="valid HTTPS URL"):
+ OAuth(
+ mcp_url=MCP_SERVER_URL,
+ client_metadata_url="http://insecure.com/client.json",
+ )
+
+ def test_root_path_cimd_url_rejected(self):
+ with pytest.raises(ValueError, match="valid HTTPS URL"):
+ OAuth(
+ mcp_url=MCP_SERVER_URL,
+ client_metadata_url="https://example.com/",
+ )
+
+
+class TestOAuthBind:
+ """Tests for the _bind() deferred initialization."""
+
+ def test_bind_sets_bound_true(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert oauth._bound is False
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth._bind(MCP_SERVER_URL)
+ assert oauth._bound is True
+
+ def test_bind_idempotent(self):
+ """Second call to _bind is a no-op."""
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth._bind(MCP_SERVER_URL)
+ oauth._bind("https://other-server.example.com/mcp")
+ # First binding wins
+ assert oauth.mcp_url == MCP_SERVER_URL
+
+ def test_bind_sets_mcp_url(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth._bind(MCP_SERVER_URL + "/")
+ # Trailing slash stripped
+ assert oauth.mcp_url == MCP_SERVER_URL
+
+ def test_bind_creates_token_storage(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert not hasattr(oauth, "token_storage_adapter")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth._bind(MCP_SERVER_URL)
+ assert hasattr(oauth, "token_storage_adapter")
+
+ async def test_unbound_raises_runtime_error(self):
+ """async_auth_flow should fail clearly when OAuth is not bound."""
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ request = httpx.Request("GET", MCP_SERVER_URL)
+ with pytest.raises(RuntimeError, match="no server URL"):
+ async for _ in oauth.async_auth_flow(request):
+ pass
+
+ def test_scopes_forwarded_as_list(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth = OAuth(
+ client_metadata_url=VALID_CIMD_URL,
+ scopes=["read", "write"],
+ )
+ oauth._bind(MCP_SERVER_URL)
+ assert oauth.context.client_metadata.scope == "read write"
+
+ def test_scopes_forwarded_as_string(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ oauth = OAuth(
+ client_metadata_url=VALID_CIMD_URL,
+ scopes="read write",
+ )
+ oauth._bind(MCP_SERVER_URL)
+ assert oauth.context.client_metadata.scope == "read write"
+
+
+class TestOAuthBindFromTransport:
+ """Tests that transports call _bind() on OAuth instances."""
+
+ def test_http_transport_binds_oauth(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert oauth._bound is False
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ StreamableHttpTransport(MCP_SERVER_URL, auth=oauth)
+ assert oauth._bound is True
+ assert oauth.mcp_url == MCP_SERVER_URL
+
+ def test_sse_transport_binds_oauth(self):
+ oauth = OAuth(client_metadata_url=VALID_CIMD_URL)
+ assert oauth._bound is False
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ SSETransport(MCP_SERVER_URL, auth=oauth)
+ assert oauth._bound is True
+ assert oauth.mcp_url == MCP_SERVER_URL
+
+ def test_http_transport_oauth_string_still_works(self):
+ """auth="oauth" should still create a new OAuth instance."""
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ transport = StreamableHttpTransport(MCP_SERVER_URL, auth="oauth")
+ assert isinstance(transport.auth, OAuth)
+ assert transport.auth._bound is True
diff --git a/tests/server/auth/oauth_proxy/test_oauth_proxy.py b/tests/server/auth/oauth_proxy/test_oauth_proxy.py
index 27c31e1988..b605e50fd2 100644
--- a/tests/server/auth/oauth_proxy/test_oauth_proxy.py
+++ b/tests/server/auth/oauth_proxy/test_oauth_proxy.py
@@ -1,6 +1,8 @@
"""Tests for OAuth proxy initialization and configuration."""
+import httpx
from key_value.aio.stores.memory import MemoryStore
+from starlette.applications import Starlette
from fastmcp.server.auth.oauth_proxy import OAuthProxy
@@ -72,3 +74,29 @@ def test_redirect_path_normalization(self, jwt_verifier):
client_storage=MemoryStore(),
)
assert proxy._redirect_path == "/auth/callback"
+
+ async def test_metadata_advertises_cimd_support(self, jwt_verifier):
+ """OAuth metadata should advertise CIMD support when enabled."""
+ proxy = OAuthProxy(
+ upstream_authorization_endpoint="https://auth.example.com/authorize",
+ upstream_token_endpoint="https://auth.example.com/token",
+ upstream_client_id="client-123",
+ upstream_client_secret="secret-456",
+ token_verifier=jwt_verifier,
+ base_url="https://api.example.com",
+ jwt_signing_key="test-secret",
+ client_storage=MemoryStore(),
+ enable_cimd=True,
+ )
+
+ app = Starlette(routes=proxy.get_routes())
+ transport = httpx.ASGITransport(app=app)
+
+ async with httpx.AsyncClient(
+ transport=transport, base_url="https://api.example.com"
+ ) as client:
+ response = await client.get("/.well-known/oauth-authorization-server")
+
+ assert response.status_code == 200
+ metadata = response.json()
+ assert metadata.get("client_id_metadata_document_supported") is True
diff --git a/tests/server/auth/test_cimd.py b/tests/server/auth/test_cimd.py
new file mode 100644
index 0000000000..d3c3e316e2
--- /dev/null
+++ b/tests/server/auth/test_cimd.py
@@ -0,0 +1,971 @@
+"""Unit tests for CIMD (Client ID Metadata Document) functionality."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from pydantic import AnyHttpUrl, ValidationError
+
+from fastmcp.server.auth.cimd import (
+ CIMDAssertionValidator,
+ CIMDClientManager,
+ CIMDDocument,
+ CIMDFetcher,
+ CIMDFetchError,
+ CIMDValidationError,
+)
+from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient
+
+# Standard public IP used for DNS mocking in tests
+TEST_PUBLIC_IP = "93.184.216.34"
+
+
+class TestCIMDDocument:
+ """Tests for CIMDDocument model validation."""
+
+ def test_valid_minimal_document(self):
+ """Test that minimal valid document passes validation."""
+ doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ )
+ assert str(doc.client_id) == "https://example.com/client.json"
+ assert doc.token_endpoint_auth_method == "none"
+ assert doc.grant_types == ["authorization_code"]
+ assert doc.response_types == ["code"]
+
+ def test_valid_full_document(self):
+ """Test that full document passes validation."""
+ doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ client_name="My App",
+ client_uri=AnyHttpUrl("https://example.com"),
+ logo_uri=AnyHttpUrl("https://example.com/logo.png"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="none",
+ grant_types=["authorization_code", "refresh_token"],
+ response_types=["code"],
+ scope="read write",
+ )
+ assert doc.client_name == "My App"
+ assert doc.scope == "read write"
+
+ def test_private_key_jwt_auth_method_allowed(self):
+ """Test that private_key_jwt is allowed for CIMD."""
+ doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="private_key_jwt",
+ jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"),
+ )
+ assert doc.token_endpoint_auth_method == "private_key_jwt"
+
+ def test_client_secret_basic_rejected(self):
+ """Test that client_secret_basic is rejected for CIMD."""
+ with pytest.raises(ValidationError) as exc_info:
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="client_secret_basic", # type: ignore[arg-type] - testing invalid value
+ )
+ # Literal type rejects invalid values before custom validator
+ assert "token_endpoint_auth_method" in str(exc_info.value)
+
+ def test_client_secret_post_rejected(self):
+ """Test that client_secret_post is rejected for CIMD."""
+ with pytest.raises(ValidationError) as exc_info:
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="client_secret_post", # type: ignore[arg-type] - testing invalid value
+ )
+ assert "token_endpoint_auth_method" in str(exc_info.value)
+
+ def test_client_secret_jwt_rejected(self):
+ """Test that client_secret_jwt is rejected for CIMD."""
+ with pytest.raises(ValidationError) as exc_info:
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="client_secret_jwt", # type: ignore[arg-type] - testing invalid value
+ )
+ assert "token_endpoint_auth_method" in str(exc_info.value)
+
+ def test_missing_redirect_uris_rejected(self):
+ """Test that redirect_uris is required for CIMD."""
+ with pytest.raises(ValidationError) as exc_info:
+ CIMDDocument(client_id=AnyHttpUrl("https://example.com/client.json"))
+ assert "redirect_uris" in str(exc_info.value)
+
+ def test_empty_redirect_uris_rejected(self):
+ """Test that empty redirect_uris is rejected."""
+ with pytest.raises(ValidationError) as exc_info:
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=[],
+ )
+ assert "redirect_uris" in str(exc_info.value)
+
+ def test_redirect_uri_without_scheme_rejected(self):
+ """Test that redirect_uris without a scheme are rejected."""
+ with pytest.raises(ValidationError, match="must have a scheme"):
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["/just/a/path"],
+ )
+
+ def test_redirect_uri_without_host_rejected(self):
+ """Test that redirect_uris without a host are rejected."""
+ with pytest.raises(ValidationError, match="must have a host"):
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://"],
+ )
+
+ def test_redirect_uri_whitespace_only_rejected(self):
+ """Test that whitespace-only redirect_uris are rejected."""
+ with pytest.raises(ValidationError, match="non-empty"):
+ CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=[" "],
+ )
+
+
+class TestCIMDFetcher:
+ """Tests for CIMDFetcher."""
+
+ @pytest.fixture
+ def fetcher(self):
+ """Create a CIMDFetcher for testing."""
+ return CIMDFetcher()
+
+ def test_is_cimd_client_id_valid_urls(self, fetcher: CIMDFetcher):
+ """Test is_cimd_client_id accepts valid CIMD URLs."""
+ assert fetcher.is_cimd_client_id("https://example.com/client.json")
+ assert fetcher.is_cimd_client_id("https://example.com/path/to/client")
+ assert fetcher.is_cimd_client_id("https://sub.example.com/cimd.json")
+
+ def test_is_cimd_client_id_rejects_http(self, fetcher: CIMDFetcher):
+ """Test is_cimd_client_id rejects HTTP URLs."""
+ assert not fetcher.is_cimd_client_id("http://example.com/client.json")
+
+ def test_is_cimd_client_id_rejects_root_path(self, fetcher: CIMDFetcher):
+ """Test is_cimd_client_id rejects URLs with no path."""
+ assert not fetcher.is_cimd_client_id("https://example.com/")
+ assert not fetcher.is_cimd_client_id("https://example.com")
+
+ def test_is_cimd_client_id_rejects_non_url(self, fetcher: CIMDFetcher):
+ """Test is_cimd_client_id rejects non-URL strings."""
+ assert not fetcher.is_cimd_client_id("client-123")
+ assert not fetcher.is_cimd_client_id("my-client")
+ assert not fetcher.is_cimd_client_id("")
+ assert not fetcher.is_cimd_client_id("not a url")
+
+ def test_validate_redirect_uri_exact_match(self, fetcher: CIMDFetcher):
+ """Test redirect_uri validation with exact match."""
+ doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ )
+ assert fetcher.validate_redirect_uri(doc, "http://localhost:3000/callback")
+ assert not fetcher.validate_redirect_uri(doc, "http://localhost:4000/callback")
+
+ def test_validate_redirect_uri_wildcard_match(self, fetcher: CIMDFetcher):
+ """Test redirect_uri validation with wildcard port."""
+ doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:*/callback"],
+ )
+ assert fetcher.validate_redirect_uri(doc, "http://localhost:3000/callback")
+ assert fetcher.validate_redirect_uri(doc, "http://localhost:8080/callback")
+ assert not fetcher.validate_redirect_uri(doc, "http://localhost:3000/other")
+
+
+class TestCIMDFetcherHTTP:
+ """Tests for CIMDFetcher HTTP fetching (using httpx mock).
+
+ Note: With SSRF protection and DNS pinning, HTTP requests go to the resolved IP
+ instead of the hostname. These tests mock DNS resolution to return a public IP
+ and configure httpx_mock to expect the IP-based URL.
+ """
+
+ @pytest.fixture
+ def fetcher(self):
+ """Create a CIMDFetcher for testing."""
+ return CIMDFetcher()
+
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
+ async def test_fetch_success(self, fetcher: CIMDFetcher, httpx_mock, mock_dns):
+ """Test successful CIMD document fetch."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+
+ # With DNS pinning, request goes to IP. Match any URL.
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={
+ "content-type": "application/json",
+ "content-length": "200",
+ },
+ )
+
+ doc = await fetcher.fetch(url)
+ assert str(doc.client_id) == url
+ assert doc.client_name == "Test App"
+
+ async def test_fetch_ttl_cache(self, fetcher: CIMDFetcher, httpx_mock, mock_dns):
+ """Test that fetched documents are cached and served from cache within TTL."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ first = await fetcher.fetch(url)
+ second = await fetcher.fetch(url)
+
+ assert first.client_id == second.client_id
+ assert len(httpx_mock.get_requests()) == 1
+
+ async def test_fetch_client_id_mismatch(
+ self, fetcher: CIMDFetcher, httpx_mock, mock_dns
+ ):
+ """Test that client_id mismatch is rejected."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": "https://other.com/client.json", # Different URL
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "100"},
+ )
+
+ with pytest.raises(CIMDValidationError) as exc_info:
+ await fetcher.fetch(url)
+ assert "mismatch" in str(exc_info.value).lower()
+
+ async def test_fetch_http_error(self, fetcher: CIMDFetcher, httpx_mock, mock_dns):
+ """Test handling of HTTP errors."""
+ url = "https://example.com/client.json"
+ httpx_mock.add_response(status_code=404)
+
+ with pytest.raises(CIMDFetchError) as exc_info:
+ await fetcher.fetch(url)
+ assert "404" in str(exc_info.value)
+
+ async def test_fetch_invalid_json(self, fetcher: CIMDFetcher, httpx_mock, mock_dns):
+ """Test handling of invalid JSON response."""
+ url = "https://example.com/client.json"
+ httpx_mock.add_response(
+ content=b"not json",
+ headers={"content-length": "10"},
+ )
+
+ with pytest.raises(CIMDValidationError) as exc_info:
+ await fetcher.fetch(url)
+ assert "JSON" in str(exc_info.value)
+
+ async def test_fetch_invalid_document(
+ self, fetcher: CIMDFetcher, httpx_mock, mock_dns
+ ):
+ """Test handling of invalid CIMD document."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "client_secret_basic", # Not allowed
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "100"},
+ )
+
+ with pytest.raises(CIMDValidationError) as exc_info:
+ await fetcher.fetch(url)
+ assert "Invalid CIMD document" in str(exc_info.value)
+
+
+class TestCIMDAssertionValidator:
+ """Tests for CIMDAssertionValidator (private_key_jwt support)."""
+
+ @pytest.fixture
+ def validator(self):
+ """Create a CIMDAssertionValidator for testing."""
+ return CIMDAssertionValidator()
+
+ @pytest.fixture
+ def key_pair(self):
+ """Generate RSA key pair for testing."""
+ from fastmcp.server.auth.providers.jwt import RSAKeyPair
+
+ return RSAKeyPair.generate()
+
+ @pytest.fixture
+ def jwks(self, key_pair):
+ """Create JWKS from key pair."""
+ import base64
+
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives import serialization
+
+ # Load public key
+ public_key = serialization.load_pem_public_key(
+ key_pair.public_key.encode(), backend=default_backend()
+ )
+
+ # Get RSA public numbers
+ from cryptography.hazmat.primitives.asymmetric import rsa
+
+ if isinstance(public_key, rsa.RSAPublicKey):
+ numbers = public_key.public_numbers()
+
+ # Convert to JWK format
+ return {
+ "keys": [
+ {
+ "kty": "RSA",
+ "kid": "test-key-1",
+ "use": "sig",
+ "alg": "RS256",
+ "n": base64.urlsafe_b64encode(
+ numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big")
+ )
+ .rstrip(b"=")
+ .decode(),
+ "e": base64.urlsafe_b64encode(
+ numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big")
+ )
+ .rstrip(b"=")
+ .decode(),
+ }
+ ]
+ }
+
+ @pytest.fixture
+ def cimd_doc_with_jwks_uri(self):
+ """Create CIMD document with jwks_uri."""
+ return CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="private_key_jwt",
+ jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"),
+ )
+
+ @pytest.fixture
+ def cimd_doc_with_inline_jwks(self, jwks):
+ """Create CIMD document with inline JWKS."""
+ return CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="private_key_jwt",
+ jwks=jwks,
+ )
+
+ async def test_valid_assertion_with_jwks_uri(
+ self, validator, key_pair, cimd_doc_with_jwks_uri, httpx_mock
+ ):
+ """Test that valid JWT assertion passes validation (jwks_uri)."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Mock JWKS endpoint
+ import base64
+
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives import serialization
+
+ public_key = serialization.load_pem_public_key(
+ key_pair.public_key.encode(), backend=default_backend()
+ )
+ from cryptography.hazmat.primitives.asymmetric import rsa
+
+ assert isinstance(public_key, rsa.RSAPublicKey)
+ numbers = public_key.public_numbers()
+
+ jwks = {
+ "keys": [
+ {
+ "kty": "RSA",
+ "kid": "test-key-1",
+ "use": "sig",
+ "alg": "RS256",
+ "n": base64.urlsafe_b64encode(
+ numbers.n.to_bytes((numbers.n.bit_length() + 7) // 8, "big")
+ )
+ .rstrip(b"=")
+ .decode(),
+ "e": base64.urlsafe_b64encode(
+ numbers.e.to_bytes((numbers.e.bit_length() + 7) // 8, "big")
+ )
+ .rstrip(b"=")
+ .decode(),
+ }
+ ]
+ }
+
+ # Mock DNS resolution for SSRF-safe fetch
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ httpx_mock.add_response(json=jwks)
+
+ # Create valid assertion (use short lifetime for security compliance)
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience=token_endpoint,
+ additional_claims={"jti": "unique-jti-123"},
+ expires_in_seconds=60, # 1 minute (max allowed is 300s)
+ kid="test-key-1",
+ )
+
+ # Should validate successfully
+ assert await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_jwks_uri
+ )
+
+ async def test_valid_assertion_with_inline_jwks(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that valid JWT assertion passes validation (inline JWKS)."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create valid assertion (use short lifetime for security compliance)
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience=token_endpoint,
+ additional_claims={"jti": "unique-jti-456"},
+ expires_in_seconds=60, # 1 minute (max allowed is 300s)
+ kid="test-key-1",
+ )
+
+ # Should validate successfully
+ assert await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+
+ async def test_rejects_wrong_issuer(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that wrong issuer is rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create assertion with wrong issuer
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer="https://attacker.com", # Wrong!
+ audience=token_endpoint,
+ additional_claims={"jti": "unique-jti-789"},
+ expires_in_seconds=60,
+ kid="test-key-1",
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "Invalid JWT assertion" in str(exc_info.value)
+
+ async def test_rejects_wrong_audience(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that wrong audience is rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create assertion with wrong audience
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience="https://wrong-endpoint.com/token", # Wrong!
+ additional_claims={"jti": "unique-jti-abc"},
+ expires_in_seconds=60,
+ kid="test-key-1",
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "Invalid JWT assertion" in str(exc_info.value)
+
+ async def test_rejects_wrong_subject(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that wrong subject claim is rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create assertion with wrong subject
+ assertion = key_pair.create_token(
+ subject="https://different-client.com", # Wrong!
+ issuer=client_id,
+ audience=token_endpoint,
+ additional_claims={"jti": "unique-jti-def"},
+ expires_in_seconds=60,
+ kid="test-key-1",
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "sub claim must be" in str(exc_info.value)
+
+ async def test_rejects_missing_jti(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that missing jti claim is rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create assertion without jti
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience=token_endpoint,
+ # No jti!
+ expires_in_seconds=60,
+ kid="test-key-1",
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "jti claim" in str(exc_info.value)
+
+ async def test_rejects_replayed_jti(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that replayed JTI is detected and rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create assertion
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience=token_endpoint,
+ additional_claims={"jti": "replayed-jti"},
+ expires_in_seconds=60,
+ kid="test-key-1",
+ )
+
+ # First use should succeed
+ assert await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+
+ # Second use with same jti should fail (replay attack)
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "replay" in str(exc_info.value).lower()
+
+ async def test_rejects_expired_token(
+ self, validator, key_pair, cimd_doc_with_inline_jwks
+ ):
+ """Test that expired tokens are rejected."""
+ client_id = "https://example.com/client.json"
+ token_endpoint = "https://oauth.example.com/token"
+
+ # Create expired assertion (expired 1 hour ago)
+ assertion = key_pair.create_token(
+ subject=client_id,
+ issuer=client_id,
+ audience=token_endpoint,
+ additional_claims={"jti": "expired-jti"},
+ expires_in_seconds=-3600, # Negative = expired
+ kid="test-key-1",
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ await validator.validate_assertion(
+ assertion, client_id, token_endpoint, cimd_doc_with_inline_jwks
+ )
+ assert "Invalid JWT assertion" in str(exc_info.value)
+
+
+class TestCIMDClientManager:
+ """Tests for CIMDClientManager."""
+
+ @pytest.fixture
+ def manager(self):
+ """Create a CIMDClientManager for testing."""
+ return CIMDClientManager(enable_cimd=True)
+
+ @pytest.fixture
+ def disabled_manager(self):
+ """Create a disabled CIMDClientManager for testing."""
+ return CIMDClientManager(enable_cimd=False)
+
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
+ def test_is_cimd_client_id_enabled(self, manager):
+ """Test CIMD URL detection when enabled."""
+ assert manager.is_cimd_client_id("https://example.com/client.json")
+ assert not manager.is_cimd_client_id("regular-client-id")
+
+ def test_is_cimd_client_id_disabled(self, disabled_manager):
+ """Test CIMD URL detection when disabled."""
+ assert not disabled_manager.is_cimd_client_id("https://example.com/client.json")
+ assert not disabled_manager.is_cimd_client_id("regular-client-id")
+
+ async def test_get_client_success(self, manager, httpx_mock, mock_dns):
+ """Test successful CIMD client creation."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ client = await manager.get_client(url)
+ assert client is not None
+ assert client.client_id == url
+ assert client.client_name == "Test App"
+ # Verify it uses proxy's patterns (None by default), not document's redirect_uris
+ assert client.allowed_redirect_uri_patterns is None
+
+ async def test_get_client_disabled(self, disabled_manager):
+ """Test that get_client returns None when disabled."""
+ client = await disabled_manager.get_client("https://example.com/client.json")
+ assert client is None
+
+ async def test_get_client_fetch_failure(self, manager, httpx_mock, mock_dns):
+ """Test that get_client returns None on fetch failure."""
+ url = "https://example.com/client.json"
+ httpx_mock.add_response(status_code=404)
+
+ client = await manager.get_client(url)
+ assert client is None
+
+ # Trust policy and consent bypass tests removed - functionality removed from CIMD
+
+
+class TestCIMDClientManagerGetClientOptions:
+ """Tests for CIMDClientManager.get_client with default_scope and allowed patterns."""
+
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
+ async def test_default_scope_applied_when_doc_has_no_scope(
+ self, httpx_mock, mock_dns
+ ):
+ """When the CIMD document omits scope, the manager's default_scope is used."""
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ # No scope field
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ manager = CIMDClientManager(
+ enable_cimd=True,
+ default_scope="read write admin",
+ )
+ client = await manager.get_client(url)
+ assert client is not None
+ assert client.scope == "read write admin"
+
+ async def test_doc_scope_takes_precedence_over_default(self, httpx_mock, mock_dns):
+ """When the CIMD document specifies scope, it wins over the default."""
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ "scope": "custom-scope",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ manager = CIMDClientManager(
+ enable_cimd=True,
+ default_scope="default-scope",
+ )
+ client = await manager.get_client(url)
+ assert client is not None
+ assert client.scope == "custom-scope"
+
+ async def test_allowed_redirect_uri_patterns_stored_on_client(
+ self, httpx_mock, mock_dns
+ ):
+ """Proxy's allowed_redirect_uri_patterns are forwarded to the created client."""
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ "redirect_uris": ["http://localhost:*/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ patterns = ["http://localhost:*", "https://app.example.com/*"]
+ manager = CIMDClientManager(
+ enable_cimd=True,
+ allowed_redirect_uri_patterns=patterns,
+ )
+ client = await manager.get_client(url)
+ assert client is not None
+ assert client.allowed_redirect_uri_patterns == patterns
+
+ async def test_cimd_document_attached_to_client(self, httpx_mock, mock_dns):
+ """The fetched CIMDDocument is attached to the created client."""
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Attached Doc App",
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ manager = CIMDClientManager(enable_cimd=True)
+ client = await manager.get_client(url)
+ assert client is not None
+ assert client.cimd_document is not None
+ assert client.cimd_document.client_name == "Attached Doc App"
+ assert str(client.cimd_document.client_id) == url
+
+
+class TestCIMDClientManagerValidatePrivateKeyJwt:
+ """Tests for CIMDClientManager.validate_private_key_jwt wrapper."""
+
+ @pytest.fixture
+ def manager(self):
+ return CIMDClientManager(enable_cimd=True)
+
+ async def test_missing_cimd_document_raises(self, manager):
+ """validate_private_key_jwt raises ValueError if client has no cimd_document."""
+
+ client = ProxyDCRClient(
+ client_id="https://example.com/client.json",
+ client_secret=None,
+ redirect_uris=None,
+ cimd_document=None,
+ )
+ with pytest.raises(ValueError, match="must have CIMD document"):
+ await manager.validate_private_key_jwt(
+ "fake.jwt.token",
+ client,
+ "https://oauth.example.com/token",
+ )
+
+ async def test_wrong_auth_method_raises(self, manager):
+ """validate_private_key_jwt raises ValueError if auth method is not private_key_jwt."""
+
+ cimd_doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="none", # Not private_key_jwt
+ )
+ client = ProxyDCRClient(
+ client_id="https://example.com/client.json",
+ client_secret=None,
+ redirect_uris=None,
+ cimd_document=cimd_doc,
+ )
+ with pytest.raises(ValueError, match="private_key_jwt"):
+ await manager.validate_private_key_jwt(
+ "fake.jwt.token",
+ client,
+ "https://oauth.example.com/token",
+ )
+
+ async def test_success_delegates_to_assertion_validator(self, manager):
+ """On success, validate_private_key_jwt delegates to the assertion validator."""
+
+ cimd_doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ token_endpoint_auth_method="private_key_jwt",
+ jwks_uri=AnyHttpUrl("https://example.com/.well-known/jwks.json"),
+ )
+ client = ProxyDCRClient(
+ client_id="https://example.com/client.json",
+ client_secret=None,
+ redirect_uris=None,
+ cimd_document=cimd_doc,
+ )
+
+ manager._assertion_validator.validate_assertion = AsyncMock(return_value=True)
+
+ result = await manager.validate_private_key_jwt(
+ "test.jwt.assertion",
+ client,
+ "https://oauth.example.com/token",
+ )
+ assert result is True
+ manager._assertion_validator.validate_assertion.assert_awaited_once_with(
+ "test.jwt.assertion",
+ "https://example.com/client.json",
+ "https://oauth.example.com/token",
+ cimd_doc,
+ )
+
+
+class TestCIMDRedirectUriEnforcement:
+ """Tests for CIMD redirect_uri validation security.
+
+ Verifies that CIMD clients enforce BOTH:
+ 1. CIMD document's redirect_uris
+ 2. Proxy's allowed_redirect_uri_patterns
+ """
+
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
+ async def test_cimd_redirect_uris_enforced(self, httpx_mock, mock_dns):
+ """Test that CIMD document redirect_uris are enforced.
+
+ Even if proxy patterns allow http://localhost:*, a CIMD client
+ should only accept URIs declared in its document.
+ """
+ from mcp.shared.auth import InvalidRedirectUriError
+ from pydantic import AnyUrl
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ # CIMD only declares port 3000
+ "redirect_uris": ["http://localhost:3000/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ # Proxy allows any localhost port
+ manager = CIMDClientManager(
+ enable_cimd=True,
+ allowed_redirect_uri_patterns=["http://localhost:*"],
+ )
+ client = await manager.get_client(url)
+ assert client is not None
+
+ # Declared URI should work
+ validated = client.validate_redirect_uri(
+ AnyUrl("http://localhost:3000/callback")
+ )
+ assert str(validated) == "http://localhost:3000/callback"
+
+ # Different port should fail (not in CIMD redirect_uris)
+ with pytest.raises(InvalidRedirectUriError):
+ client.validate_redirect_uri(AnyUrl("http://localhost:4000/callback"))
+
+ async def test_proxy_patterns_also_checked(self, httpx_mock, mock_dns):
+ """Test that proxy patterns are checked even for CIMD clients.
+
+ A CIMD client should not be able to use a redirect_uri that's
+ in its document but not allowed by proxy patterns.
+ """
+ from mcp.shared.auth import InvalidRedirectUriError
+ from pydantic import AnyUrl
+
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Test App",
+ # CIMD declares both localhost and external URI
+ "redirect_uris": [
+ "http://localhost:3000/callback",
+ "https://evil.com/callback",
+ ],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ # Proxy only allows localhost
+ manager = CIMDClientManager(
+ enable_cimd=True,
+ allowed_redirect_uri_patterns=["http://localhost:*"],
+ )
+ client = await manager.get_client(url)
+ assert client is not None
+
+ # Localhost should work (in CIMD and matches pattern)
+ validated = client.validate_redirect_uri(
+ AnyUrl("http://localhost:3000/callback")
+ )
+ assert str(validated) == "http://localhost:3000/callback"
+
+ # Evil.com should fail (in CIMD but doesn't match proxy patterns)
+ with pytest.raises(InvalidRedirectUriError):
+ client.validate_redirect_uri(AnyUrl("https://evil.com/callback"))
diff --git a/tests/server/auth/test_jwt_provider.py b/tests/server/auth/test_jwt_provider.py
index 14a81299e8..bced42a1f5 100644
--- a/tests/server/auth/test_jwt_provider.py
+++ b/tests/server/auth/test_jwt_provider.py
@@ -1,5 +1,6 @@
from collections.abc import AsyncGenerator
from typing import Any
+from unittest.mock import patch
import httpx
import pytest
@@ -10,6 +11,9 @@
from fastmcp.server.auth.providers.jwt import JWKData, JWKSData, JWTVerifier, RSAKeyPair
from fastmcp.utilities.tests import run_server_async
+# Standard public IP used for DNS mocking in tests
+TEST_PUBLIC_IP = "93.184.216.34"
+
class SymmetricKeyHelper:
"""Helper class for generating symmetric key JWT tokens for testing."""
@@ -378,7 +382,11 @@ async def test_symmetric_token_algorithm_mismatch(
class TestBearerTokenJWKS:
- """Tests for JWKS URI functionality."""
+ """Tests for JWKS URI functionality.
+
+ Note: With SSRF protection, JWKS fetches validate DNS and connect to the
+ resolved IP. Tests mock DNS resolution to return a public IP.
+ """
@pytest.fixture
def jwks_provider(self, rsa_key_pair: RSAKeyPair) -> JWTVerifier:
@@ -402,18 +410,25 @@ def mock_jwks_data(self, rsa_key_pair: RSAKeyPair) -> JWKSData:
return {"keys": [jwk_data]}
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
async def test_jwks_token_validation(
self,
rsa_key_pair: RSAKeyPair,
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
"""Test token validation using JWKS URI."""
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
username = "test-user"
issuer = "https://test.example.com"
@@ -440,11 +455,9 @@ async def test_jwks_token_validation_with_invalid_key(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = RSAKeyPair.generate().create_token(
subject="test-user",
issuer="https://test.example.com",
@@ -460,12 +473,10 @@ async def test_jwks_token_validation_with_kid(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
mock_jwks_data["keys"][0]["kid"] = "test-key-1"
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = rsa_key_pair.create_token(
subject="test-user",
issuer="https://test.example.com",
@@ -483,12 +494,10 @@ async def test_jwks_token_validation_with_kid_and_no_kid_in_token(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
mock_jwks_data["keys"][0]["kid"] = "test-key-1"
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = rsa_key_pair.create_token(
subject="test-user",
issuer="https://test.example.com",
@@ -505,12 +514,10 @@ async def test_jwks_token_validation_with_no_kid_and_kid_in_jwks(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
mock_jwks_data["keys"][0]["kid"] = "test-key-1"
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = rsa_key_pair.create_token(
subject="test-user",
issuer="https://test.example.com",
@@ -527,12 +534,10 @@ async def test_jwks_token_validation_with_kid_mismatch(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
mock_jwks_data["keys"][0]["kid"] = "test-key-1"
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = rsa_key_pair.create_token(
subject="test-user",
issuer="https://test.example.com",
@@ -549,6 +554,7 @@ async def test_jwks_token_validation_with_multiple_keys_and_no_kid_in_token(
jwks_provider: JWTVerifier,
mock_jwks_data: JWKSData,
httpx_mock: HTTPXMock,
+ mock_dns,
):
mock_jwks_data["keys"] = [ # type: ignore[typeddict-item]
{
@@ -561,10 +567,7 @@ async def test_jwks_token_validation_with_multiple_keys_and_no_kid_in_token(
},
]
- httpx_mock.add_response(
- url="https://test.example.com/.well-known/jwks.json",
- json=mock_jwks_data,
- )
+ httpx_mock.add_response(json=mock_jwks_data)
token = rsa_key_pair.create_token(
subject="test-user",
issuer="https://test.example.com",
diff --git a/tests/server/auth/test_oauth_proxy_redirect_validation.py b/tests/server/auth/test_oauth_proxy_redirect_validation.py
index 20d4afdd75..391977b886 100644
--- a/tests/server/auth/test_oauth_proxy_redirect_validation.py
+++ b/tests/server/auth/test_oauth_proxy_redirect_validation.py
@@ -1,14 +1,20 @@
"""Tests for OAuth proxy redirect URI validation."""
+from unittest.mock import patch
+
import pytest
from key_value.aio.stores.memory import MemoryStore
from mcp.shared.auth import InvalidRedirectUriError
-from pydantic import AnyUrl
+from pydantic import AnyHttpUrl, AnyUrl
from fastmcp.server.auth.auth import TokenVerifier
+from fastmcp.server.auth.cimd import CIMDDocument
from fastmcp.server.auth.oauth_proxy import OAuthProxy
from fastmcp.server.auth.oauth_proxy.models import ProxyDCRClient
+# Standard public IP used for DNS mocking in tests
+TEST_PUBLIC_IP = "93.184.216.34"
+
class MockTokenVerifier(TokenVerifier):
"""Mock token verifier for testing."""
@@ -133,6 +139,38 @@ def test_none_redirect_uri(self):
result = client.validate_redirect_uri(None)
assert result == AnyUrl("http://localhost:3000")
+ def test_cimd_none_redirect_uri_single_exact(self):
+ """CIMD clients may omit redirect_uri only when a single exact URI exists."""
+ cimd_doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:3000/callback"],
+ )
+ client = ProxyDCRClient(
+ client_id="https://example.com/client.json",
+ client_secret=None,
+ redirect_uris=None,
+ cimd_document=cimd_doc,
+ )
+
+ result = client.validate_redirect_uri(None)
+ assert result == AnyUrl("http://localhost:3000/callback")
+
+ def test_cimd_none_redirect_uri_wildcard_rejected(self):
+ """CIMD clients must specify redirect_uri when only wildcard patterns exist."""
+ cimd_doc = CIMDDocument(
+ client_id=AnyHttpUrl("https://example.com/client.json"),
+ redirect_uris=["http://localhost:*/callback"],
+ )
+ client = ProxyDCRClient(
+ client_id="https://example.com/client.json",
+ client_secret=None,
+ redirect_uris=None,
+ cimd_document=cimd_doc,
+ )
+
+ with pytest.raises(InvalidRedirectUriError):
+ client.validate_redirect_uri(None)
+
class TestOAuthProxyRedirectValidation:
"""Test OAuth proxy with redirect URI validation."""
@@ -240,3 +278,90 @@ async def test_proxy_unregistered_client_returns_none(self):
# Get an unregistered client
client = await proxy.get_client("unknown-client")
assert client is None
+
+
+class TestOAuthProxyCIMDClient:
+ """Test that CIMD clients obtained via proxy carry their document and apply dual validation."""
+
+ @pytest.fixture
+ def mock_dns(self):
+ """Mock DNS resolution to return test public IP."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[TEST_PUBLIC_IP],
+ ):
+ yield
+
+ async def test_proxy_get_client_returns_cimd_client(self, httpx_mock, mock_dns):
+ """CIMD client obtained via proxy's get_client has cimd_document attached."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "CIMD App",
+ "redirect_uris": ["http://localhost:*/callback"],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ proxy = OAuthProxy(
+ upstream_authorization_endpoint="https://auth.example.com/authorize",
+ upstream_token_endpoint="https://auth.example.com/token",
+ upstream_client_id="test-client",
+ upstream_client_secret="test-secret",
+ token_verifier=MockTokenVerifier(),
+ base_url="http://localhost:8000",
+ jwt_signing_key="test-secret",
+ client_storage=MemoryStore(),
+ )
+
+ client = await proxy.get_client(url)
+ assert isinstance(client, ProxyDCRClient)
+ assert client.cimd_document is not None
+ assert client.cimd_document.client_name == "CIMD App"
+ assert client.client_id == url
+
+ async def test_proxy_cimd_dual_redirect_validation(self, httpx_mock, mock_dns):
+ """CIMD client from proxy enforces both CIMD redirect_uris and proxy patterns."""
+ url = "https://example.com/client.json"
+ doc_data = {
+ "client_id": url,
+ "client_name": "Dual Validation App",
+ "redirect_uris": [
+ "http://localhost:3000/callback",
+ "https://evil.com/callback",
+ ],
+ "token_endpoint_auth_method": "none",
+ }
+ httpx_mock.add_response(
+ json=doc_data,
+ headers={"content-length": "200"},
+ )
+
+ proxy = OAuthProxy(
+ upstream_authorization_endpoint="https://auth.example.com/authorize",
+ upstream_token_endpoint="https://auth.example.com/token",
+ upstream_client_id="test-client",
+ upstream_client_secret="test-secret",
+ token_verifier=MockTokenVerifier(),
+ base_url="http://localhost:8000",
+ allowed_client_redirect_uris=["http://localhost:*"],
+ jwt_signing_key="test-secret",
+ client_storage=MemoryStore(),
+ )
+
+ client = await proxy.get_client(url)
+ assert client is not None
+
+ # In CIMD AND matches proxy pattern → accepted
+ assert client.validate_redirect_uri(AnyUrl("http://localhost:3000/callback"))
+
+ # In CIMD but NOT in proxy pattern → rejected
+ with pytest.raises(InvalidRedirectUriError):
+ client.validate_redirect_uri(AnyUrl("https://evil.com/callback"))
+
+ # NOT in CIMD but matches proxy pattern → rejected
+ with pytest.raises(InvalidRedirectUriError):
+ client.validate_redirect_uri(AnyUrl("http://localhost:9999/other"))
diff --git a/tests/server/auth/test_oauth_proxy_storage.py b/tests/server/auth/test_oauth_proxy_storage.py
index cc1808c3f6..7c898823e9 100644
--- a/tests/server/auth/test_oauth_proxy_storage.py
+++ b/tests/server/auth/test_oauth_proxy_storage.py
@@ -112,7 +112,7 @@ async def test_nonexistent_client_returns_none(
async def test_proxy_dcr_client_redirect_validation(
self, jwt_verifier: TokenVerifier, temp_storage: AsyncKeyValue
):
- """Test that ProxyDCRClient is created with redirect URI patterns."""
+ """Test that OAuthProxyClient is created with redirect URI patterns."""
proxy = OAuthProxy(
upstream_authorization_endpoint="https://github.com/login/oauth/authorize",
upstream_token_endpoint="https://github.com/login/oauth/access_token",
@@ -132,11 +132,11 @@ async def test_proxy_dcr_client_redirect_validation(
)
await proxy.register_client(client_info)
- # Get client back - should be ProxyDCRClient
+ # Get client back - should be OAuthProxyClient
client = await proxy.get_client("test-proxy-client")
assert client is not None
- # ProxyDCRClient should validate dynamic localhost ports
+ # OAuthProxyClient should validate dynamic localhost ports
validated = client.validate_redirect_uri(
AnyUrl("http://localhost:12345/callback")
)
@@ -205,5 +205,7 @@ async def test_storage_data_structure(self, jwt_verifier, temp_storage):
"client_id_issued_at": None,
"client_secret_expires_at": None,
"allowed_redirect_uri_patterns": None,
+ "cimd_document": None,
+ "cimd_fetched_at": None,
}
)
diff --git a/tests/server/auth/test_oidc_proxy.py b/tests/server/auth/test_oidc_proxy.py
index b8e373e401..e3dd95e1c1 100644
--- a/tests/server/auth/test_oidc_proxy.py
+++ b/tests/server/auth/test_oidc_proxy.py
@@ -15,10 +15,10 @@
TEST_AUTHORIZATION_ENDPOINT = "https://example.com/authorize"
TEST_TOKEN_ENDPOINT = "https://example.com/oauth/token"
-TEST_CONFIG_URL = "https://example.com/.well-known/openid-configuration"
+TEST_CONFIG_URL = AnyHttpUrl("https://example.com/.well-known/openid-configuration")
TEST_CLIENT_ID = "test-client-id"
TEST_CLIENT_SECRET = "test-client-secret"
-TEST_BASE_URL = "https://example.com:8000/"
+TEST_BASE_URL = AnyHttpUrl("https://example.com:8000/")
# =============================================================================
@@ -366,7 +366,7 @@ def validate_get_oidc_configuration(oidc_configuration, strict, timeout_seconds)
mock_get.return_value = mock_response
config = OIDCConfiguration.get_oidc_configuration(
- config_url=AnyHttpUrl(TEST_CONFIG_URL),
+ config_url=TEST_CONFIG_URL,
strict=strict,
timeout_seconds=timeout_seconds,
)
@@ -376,7 +376,7 @@ def validate_get_oidc_configuration(oidc_configuration, strict, timeout_seconds)
mock_get.assert_called_once()
call_args = mock_get.call_args
- assert call_args[0][0] == TEST_CONFIG_URL
+ assert str(call_args[0][0]) == str(TEST_CONFIG_URL)
return call_args
@@ -415,7 +415,7 @@ def test_get_oidc_configuration_not_strict(
mock_get.return_value = mock_response
OIDCConfiguration.get_oidc_configuration(
- config_url=AnyHttpUrl(TEST_CONFIG_URL),
+ config_url=TEST_CONFIG_URL,
strict=False,
timeout_seconds=10,
)
@@ -423,7 +423,7 @@ def test_get_oidc_configuration_not_strict(
mock_get.assert_called_once()
call_args = mock_get.call_args
- assert call_args[0][0] == TEST_CONFIG_URL
+ assert str(call_args[0][0]) == str(TEST_CONFIG_URL)
def validate_proxy(mock_get, proxy, oidc_config):
@@ -431,13 +431,13 @@ def validate_proxy(mock_get, proxy, oidc_config):
mock_get.assert_called_once()
call_args = mock_get.call_args
- assert str(call_args[0][0]) == TEST_CONFIG_URL
+ assert str(call_args[0][0]) == str(TEST_CONFIG_URL)
assert proxy._upstream_authorization_endpoint == TEST_AUTHORIZATION_ENDPOINT
assert proxy._upstream_token_endpoint == TEST_TOKEN_ENDPOINT
assert proxy._upstream_client_id == TEST_CLIENT_ID
assert proxy._upstream_client_secret.get_secret_value() == TEST_CLIENT_SECRET
- assert str(proxy.base_url) == TEST_BASE_URL
+ assert str(proxy.base_url) == str(TEST_BASE_URL)
assert proxy.oidc_config == oidc_config
diff --git a/tests/server/auth/test_redirect_validation.py b/tests/server/auth/test_redirect_validation.py
index 87071a91fc..10945d2fb4 100644
--- a/tests/server/auth/test_redirect_validation.py
+++ b/tests/server/auth/test_redirect_validation.py
@@ -109,6 +109,65 @@ def test_anyurl_conversion(self):
assert not validate_redirect_uri(uri, patterns)
+class TestSecurityBypass:
+ """Test protection against redirect URI security bypass attacks."""
+
+ def test_userinfo_bypass_blocked(self):
+ """Test that userinfo-style bypasses are blocked.
+
+ Attack: http://localhost@evil.com/callback would match http://localhost:*
+ with naive string matching, but actually points to evil.com.
+ """
+ pattern = "http://localhost:*"
+
+ # These should be blocked - the "host" is actually in the userinfo
+ assert not matches_allowed_pattern(
+ "http://localhost@evil.com/callback", pattern
+ )
+ assert not matches_allowed_pattern(
+ "http://localhost:3000@malicious.io/callback", pattern
+ )
+ assert not matches_allowed_pattern(
+ "http://user:pass@localhost:3000/callback", pattern
+ )
+
+ def test_userinfo_bypass_with_subdomain_pattern(self):
+ """Test userinfo bypass with subdomain wildcard patterns."""
+ pattern = "https://*.example.com/callback"
+
+ # Blocked: userinfo tricks
+ assert not matches_allowed_pattern(
+ "https://app.example.com@attacker.com/callback", pattern
+ )
+ assert not matches_allowed_pattern(
+ "https://user:pass@app.example.com/callback", pattern
+ )
+
+ def test_legitimate_uris_still_work(self):
+ """Test that legitimate URIs work after security hardening."""
+ pattern = "http://localhost:*"
+ assert matches_allowed_pattern("http://localhost:3000/callback", pattern)
+ assert matches_allowed_pattern("http://localhost:8080/auth", pattern)
+
+ pattern = "https://*.example.com/callback"
+ assert matches_allowed_pattern("https://app.example.com/callback", pattern)
+
+ def test_scheme_mismatch_blocked(self):
+ """Test that scheme mismatches are blocked."""
+ assert not matches_allowed_pattern(
+ "http://localhost:3000/callback", "https://localhost:*"
+ )
+ assert not matches_allowed_pattern(
+ "https://localhost:3000/callback", "http://localhost:*"
+ )
+
+ def test_host_mismatch_blocked(self):
+ """Test that host mismatches are blocked even with wildcards."""
+ pattern = "http://localhost:*"
+ assert not matches_allowed_pattern("http://127.0.0.1:3000/callback", pattern)
+ assert not matches_allowed_pattern("http://example.com:3000/callback", pattern)
+
+
class TestDefaultPatterns:
"""Test the default localhost patterns constant."""
diff --git a/tests/server/auth/test_ssrf_protection.py b/tests/server/auth/test_ssrf_protection.py
new file mode 100644
index 0000000000..79cf926a09
--- /dev/null
+++ b/tests/server/auth/test_ssrf_protection.py
@@ -0,0 +1,447 @@
+"""Tests for SSRF-safe HTTP utilities.
+
+This module tests the ssrf.py module which provides SSRF-protected HTTP fetching.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+from fastmcp.server.auth.ssrf import (
+ SSRFError,
+ SSRFFetchError,
+ is_ip_allowed,
+ ssrf_safe_fetch,
+ validate_url,
+)
+
+
+class TestIsIPAllowed:
+ """Tests for is_ip_allowed function."""
+
+ def test_public_ipv4_allowed(self):
+ """Public IPv4 addresses should be allowed."""
+ assert is_ip_allowed("8.8.8.8") is True
+ assert is_ip_allowed("1.1.1.1") is True
+ assert is_ip_allowed("93.184.216.34") is True
+
+ def test_private_ipv4_blocked(self):
+ """Private IPv4 addresses should be blocked."""
+ assert is_ip_allowed("192.168.1.1") is False
+ assert is_ip_allowed("10.0.0.1") is False
+ assert is_ip_allowed("172.16.0.1") is False
+
+ def test_loopback_blocked(self):
+ """Loopback addresses should be blocked."""
+ assert is_ip_allowed("127.0.0.1") is False
+ assert is_ip_allowed("::1") is False
+
+ def test_link_local_blocked(self):
+ """Link-local addresses (AWS metadata) should be blocked."""
+ assert is_ip_allowed("169.254.169.254") is False
+
+ def test_rfc6598_cgnat_blocked(self):
+ """RFC6598 Carrier-Grade NAT addresses should be blocked."""
+ assert is_ip_allowed("100.64.0.1") is False
+ assert is_ip_allowed("100.100.100.100") is False
+
+ def test_ipv4_mapped_ipv6_blocked_if_private(self):
+ """IPv4-mapped IPv6 addresses should check the embedded IPv4."""
+ assert is_ip_allowed("::ffff:127.0.0.1") is False
+ assert is_ip_allowed("::ffff:192.168.1.1") is False
+
+
+class TestValidateURL:
+ """Tests for validate_url function."""
+
+ async def test_http_rejected(self):
+ """HTTP URLs should be rejected (HTTPS required)."""
+ with pytest.raises(SSRFError, match="must use HTTPS"):
+ await validate_url("http://example.com/path")
+
+ async def test_missing_host_rejected(self):
+ """URLs without host should be rejected."""
+ with pytest.raises(SSRFError, match="must have a host"):
+ await validate_url("https:///path")
+
+ async def test_root_path_rejected_when_required(self):
+ """Root paths should be rejected when require_path=True."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["93.184.216.34"],
+ ):
+ with pytest.raises(SSRFError, match="non-root path"):
+ await validate_url("https://example.com/", require_path=True)
+
+ async def test_private_ip_rejected(self):
+ """URLs resolving to private IPs should be rejected."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["192.168.1.1"],
+ ):
+ with pytest.raises(SSRFError, match="blocked IP"):
+ await validate_url("https://example.com/path")
+
+
+class TestSSRFSafeFetch:
+ """Tests for ssrf_safe_fetch function."""
+
+ async def test_private_ip_blocked(self):
+ """Fetch to private IP should be blocked."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["192.168.1.1"],
+ ):
+ with pytest.raises(SSRFError, match="blocked IP"):
+ await ssrf_safe_fetch("https://internal.example.com/api")
+
+ async def test_cgnat_blocked(self):
+ """Fetch to RFC6598 CGNAT IP should be blocked."""
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["100.64.0.1"],
+ ):
+ with pytest.raises(SSRFError, match="blocked IP"):
+ await ssrf_safe_fetch("https://cgnat.example.com/api")
+
+ async def test_connects_to_pinned_ip(self):
+ """Verify connection uses pinned IP, not re-resolved DNS."""
+ resolved_ip = "93.184.216.34"
+
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[resolved_ip],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {"content-length": "15"}
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ async def aiter_bytes():
+ yield b'{"data": "test"}'
+
+ mock_stream.aiter_bytes = aiter_bytes
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ await ssrf_safe_fetch("https://example.com/api")
+
+ # Verify URL contains pinned IP
+ call_args = mock_client.stream.call_args
+ url_called = call_args[0][1]
+ assert resolved_ip in url_called
+
+ async def test_fallback_to_second_ip(self):
+ """If the first IP fails, the next resolved IP should be tried."""
+ resolved_ips = ["2001:4860:4860::8888", "93.184.216.34"]
+
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=resolved_ips,
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ request = httpx.Request("GET", "https://example.com/api")
+
+ first_client = AsyncMock()
+ first_client.stream = MagicMock(
+ side_effect=httpx.RequestError("boom", request=request)
+ )
+ first_client.__aenter__.return_value = first_client
+ first_client.__aexit__ = AsyncMock(return_value=None)
+
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {"content-length": "2"}
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ async def aiter_bytes():
+ yield b"ok"
+
+ mock_stream.aiter_bytes = aiter_bytes
+
+ second_client = AsyncMock()
+ second_client.stream = MagicMock(return_value=mock_stream)
+ second_client.__aenter__.return_value = second_client
+ second_client.__aexit__ = AsyncMock(return_value=None)
+
+ mock_client_class.side_effect = [first_client, second_client]
+
+ content = await ssrf_safe_fetch("https://example.com/api")
+ assert content == b"ok"
+
+ call_args = second_client.stream.call_args
+ url_called = call_args[0][1]
+ assert resolved_ips[1] in url_called
+
+ async def test_host_header_set(self):
+ """Verify Host header is set to original hostname."""
+ resolved_ip = "93.184.216.34"
+ original_host = "example.com"
+
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[resolved_ip],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {"content-length": "15"}
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ async def aiter_bytes():
+ yield b'{"data": "test"}'
+
+ mock_stream.aiter_bytes = aiter_bytes
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ await ssrf_safe_fetch(f"https://{original_host}/api")
+
+ # Verify Host header
+ call_kwargs = mock_client.stream.call_args[1]
+ assert call_kwargs["headers"]["Host"] == original_host
+
+ async def test_response_size_limit(self):
+ """Verify response size limit is enforced via streaming."""
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["93.184.216.34"],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ # Response larger than default 5KB (no Content-Length, so streaming enforces)
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {} # No Content-Length to force streaming check
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ async def aiter_bytes():
+ # Yield 10KB total
+ for _ in range(10):
+ yield b"x" * 1024
+
+ mock_stream.aiter_bytes = aiter_bytes
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ with pytest.raises(SSRFFetchError, match="too large"):
+ await ssrf_safe_fetch("https://example.com/api")
+
+
+class TestJWKSSSRFProtection:
+ """Tests for SSRF protection in JWTVerifier JWKS fetching."""
+
+ async def test_jwks_private_ip_blocked(self):
+ """JWKS fetch to private IP should be blocked."""
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+ verifier = JWTVerifier(
+ jwks_uri="https://internal.example.com/.well-known/jwks.json",
+ issuer="https://issuer.example.com",
+ ssrf_safe=True,
+ )
+
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["192.168.1.1"],
+ ):
+ with pytest.raises(ValueError, match="Failed to fetch JWKS"):
+ # Create a dummy token to trigger JWKS fetch
+ await verifier._get_jwks_key("test-kid")
+
+ async def test_jwks_cgnat_blocked(self):
+ """JWKS fetch to RFC6598 CGNAT IP should be blocked."""
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+ verifier = JWTVerifier(
+ jwks_uri="https://cgnat.example.com/.well-known/jwks.json",
+ issuer="https://issuer.example.com",
+ ssrf_safe=True,
+ )
+
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["100.64.0.1"],
+ ):
+ with pytest.raises(ValueError, match="Failed to fetch JWKS"):
+ await verifier._get_jwks_key("test-kid")
+
+ async def test_jwks_loopback_blocked(self):
+ """JWKS fetch to loopback should be blocked."""
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+ verifier = JWTVerifier(
+ jwks_uri="https://localhost/.well-known/jwks.json",
+ issuer="https://issuer.example.com",
+ ssrf_safe=True,
+ )
+
+ with patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["127.0.0.1"],
+ ):
+ with pytest.raises(ValueError, match="Failed to fetch JWKS"):
+ await verifier._get_jwks_key("test-kid")
+
+
+class TestIPv6URLFormatting:
+ """Tests for proper IPv6 address bracketing in URLs."""
+
+ def test_format_ip_for_url_ipv4(self):
+ """IPv4 addresses should not be bracketed."""
+ from fastmcp.server.auth.ssrf import format_ip_for_url
+
+ assert format_ip_for_url("8.8.8.8") == "8.8.8.8"
+ assert format_ip_for_url("192.168.1.1") == "192.168.1.1"
+
+ def test_format_ip_for_url_ipv6(self):
+ """IPv6 addresses should be bracketed for URL use."""
+ from fastmcp.server.auth.ssrf import format_ip_for_url
+
+ assert format_ip_for_url("2001:db8::1") == "[2001:db8::1]"
+ assert format_ip_for_url("::1") == "[::1]"
+ assert format_ip_for_url("fe80::1") == "[fe80::1]"
+
+ def test_format_ip_for_url_invalid(self):
+ """Invalid IP strings should be returned unchanged."""
+ from fastmcp.server.auth.ssrf import format_ip_for_url
+
+ assert format_ip_for_url("not-an-ip") == "not-an-ip"
+ assert format_ip_for_url("") == ""
+
+ async def test_ipv6_pinned_url_is_valid(self):
+ """Verify IPv6 addresses are properly bracketed in pinned URLs."""
+ resolved_ipv6 = "2001:4860:4860::8888"
+
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=[resolved_ipv6],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {"content-length": "10"}
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ async def aiter_bytes():
+ yield b'{"key": 1}'
+
+ mock_stream.aiter_bytes = aiter_bytes
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ await ssrf_safe_fetch("https://example.com/api")
+
+ # Verify the URL contains bracketed IPv6 address
+ call_args = mock_client.stream.call_args
+ url_called = call_args[0][1]
+
+ # IPv6 should be bracketed: https://[2001:4860:4860::8888]:443/path
+ assert f"[{resolved_ipv6}]" in url_called, (
+ f"Expected bracketed IPv6 [{resolved_ipv6}] in URL, got {url_called}"
+ )
+
+
+class TestStreamingResponseSizeLimit:
+ """Tests for streaming-based response size enforcement."""
+
+ async def test_size_limit_enforced_during_streaming(self):
+ """Verify that size limit is enforced as chunks are received, not after."""
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["93.184.216.34"],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ chunks_yielded = []
+
+ async def aiter_bytes():
+ # Yield chunks that exceed the limit
+ for i in range(10):
+ chunk = b"x" * 1024 # 1KB per chunk
+ chunks_yielded.append(chunk)
+ yield chunk
+
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {} # No content-length to force streaming check
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+ mock_stream.aiter_bytes = aiter_bytes
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ with pytest.raises(SSRFFetchError, match="too large"):
+ await ssrf_safe_fetch("https://example.com/api", max_size=5120)
+
+ # Verify we stopped after exceeding the limit (should be ~6 chunks for 5KB limit)
+ # This confirms we're enforcing during streaming, not after downloading all
+ assert len(chunks_yielded) <= 7, (
+ f"Downloaded {len(chunks_yielded)} chunks (expected <=7 for streaming enforcement)"
+ )
+
+ async def test_content_length_header_checked_first(self):
+ """Verify Content-Length header is checked before streaming."""
+ with (
+ patch(
+ "fastmcp.server.auth.ssrf.resolve_hostname",
+ return_value=["93.184.216.34"],
+ ),
+ patch("httpx.AsyncClient") as mock_client_class,
+ ):
+ mock_stream = MagicMock()
+ mock_stream.status_code = 200
+ mock_stream.headers = {"content-length": "10240"} # 10KB
+ mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+ mock_stream.__aexit__ = AsyncMock(return_value=None)
+
+ # aiter_bytes should never be called if Content-Length is checked
+ mock_stream.aiter_bytes = MagicMock(
+ side_effect=AssertionError("Should not stream")
+ )
+
+ mock_client = AsyncMock()
+ mock_client.stream = MagicMock(return_value=mock_stream)
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_class.return_value = mock_client
+
+ with pytest.raises(SSRFFetchError, match="too large"):
+ await ssrf_safe_fetch("https://example.com/api", max_size=5120)
diff --git a/tests/utilities/openapi/test_models.py b/tests/utilities/openapi/test_models.py
index cc4baadb37..4361635c24 100644
--- a/tests/utilities/openapi/test_models.py
+++ b/tests/utilities/openapi/test_models.py
@@ -4,8 +4,10 @@
from inline_snapshot import snapshot
from fastmcp.utilities.openapi.models import (
+ HttpMethod,
HTTPRoute,
ParameterInfo,
+ ParameterLocation,
RequestBodyInfo,
ResponseInfo,
)
@@ -51,7 +53,7 @@ def test_parameter_with_all_fields(self):
assert param.style == "deepObject"
@pytest.mark.parametrize("location", ["path", "query", "header", "cookie"])
- def test_valid_parameter_locations(self, location):
+ def test_valid_parameter_locations(self, location: ParameterLocation):
"""Test that all valid parameter locations are accepted."""
param = ParameterInfo(
name="test",
@@ -286,7 +288,7 @@ def test_route_pre_calculated_fields(self):
@pytest.mark.parametrize(
"method", ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
)
- def test_valid_http_methods(self, method):
+ def test_valid_http_methods(self, method: HttpMethod):
"""Test that all valid HTTP methods are accepted."""
route = HTTPRoute(
path="/test",