2 changes: 1 addition & 1 deletion docs/servers/auth/oauth-proxy.mdx
@@ -583,7 +583,7 @@ CIMD provides several security advantages over DCR:
- **Replay prevention**: For `private_key_jwt` clients, JTI claims are tracked to prevent assertion replay
- **Cache-aware fetching**: CIMD documents are cached according to HTTP cache headers and revalidated when required

To disable CIMD support entirely (for example, to require all clients to register via DCR):
CIMD is enabled by default. To disable it entirely (for example, to require all clients to register via DCR), set `enable_cimd=False` explicitly:

```python
auth = OAuthProxy(
    ...,  # remaining proxy configuration elided by the diff view
    enable_cimd=False,
)
```
172 changes: 160 additions & 12 deletions src/fastmcp/server/auth/cimd.py
@@ -19,6 +19,10 @@
import fnmatch
import json
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import timezone
from email.utils import parsedate_to_datetime
from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import urlparse

@@ -27,7 +31,7 @@
from fastmcp.server.auth.ssrf import (
SSRFError,
SSRFFetchError,
ssrf_safe_fetch,
ssrf_safe_fetch_response,
validate_url,
)
from fastmcp.utilities.logging import get_logger
@@ -155,12 +159,37 @@ class CIMDFetchError(Exception):
"""Raised when CIMD document fetching fails."""


@dataclass
class _CIMDCacheEntry:
"""Cached CIMD document and associated HTTP cache metadata."""

doc: CIMDDocument
etag: str | None
last_modified: str | None
expires_at: float
freshness_lifetime: float
must_revalidate: bool


@dataclass
class _CIMDCachePolicy:
"""Normalized cache directives parsed from HTTP response headers."""

etag: str | None
last_modified: str | None
expires_at: float
freshness_lifetime: float
no_store: bool
must_revalidate: bool


class CIMDFetcher:
"""Fetch and validate CIMD documents with SSRF protection.

Delegates HTTP fetching to ssrf_safe_fetch which provides DNS pinning,
IP validation, size limits, and timeout enforcement. Documents are cached
with a simple TTL.
Delegates HTTP fetching to ssrf_safe_fetch_response, which provides DNS
pinning, IP validation, size limits, and timeout enforcement. Documents are
cached using HTTP caching semantics (Cache-Control/ETag/Last-Modified), with
a TTL fallback when response headers do not define caching behavior.
"""

# Maximum response size (bytes)
@@ -178,7 +207,65 @@ def __init__(
timeout: HTTP request timeout in seconds (default 10.0)
"""
self.timeout = timeout
self._cache: dict[str, tuple[CIMDDocument, float]] = {}
self._cache: dict[str, _CIMDCacheEntry] = {}

def _parse_cache_policy(
self, headers: Mapping[str, str], now: float
) -> _CIMDCachePolicy:
"""Parse HTTP cache headers and derive cache behavior."""
normalized = {k.lower(): v for k, v in headers.items()}
cache_control = normalized.get("cache-control", "")
directives = {
part.strip().lower() for part in cache_control.split(",") if part.strip()
}

no_store = "no-store" in directives
must_revalidate = "no-cache" in directives
max_age: int | None = None

for directive in directives:
if directive.startswith("max-age="):
value = directive.removeprefix("max-age=").strip()
try:
max_age = max(0, int(value))
except ValueError:
logger.debug(
"Ignoring invalid Cache-Control max-age value: %s", value
)
break

expires_at: float | None = None
if max_age is not None:
expires_at = now + max_age
elif "expires" in normalized:
try:
dt = parsedate_to_datetime(normalized["expires"])
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
expires_at = dt.timestamp()
except (TypeError, ValueError):
logger.debug(
"Ignoring invalid Expires header on CIMD response: %s",
normalized["expires"],
)

if expires_at is None:
expires_at = now + self.DEFAULT_CACHE_TTL_SECONDS
freshness_lifetime = max(0.0, expires_at - now)

return _CIMDCachePolicy(
etag=normalized.get("etag"),
last_modified=normalized.get("last-modified"),
expires_at=expires_at,
freshness_lifetime=freshness_lifetime,
no_store=no_store,
must_revalidate=must_revalidate,
)

def _has_freshness_headers(self, headers: Mapping[str, str]) -> bool:
"""Return True when response includes cache freshness directives."""
normalized = {k.lower() for k in headers}
return "cache-control" in normalized or "expires" in normalized

def is_cimd_client_id(self, client_id: str) -> bool:
"""Check if a client_id looks like a CIMD URL.
@@ -200,7 +287,7 @@ def is_cimd_client_id(self, client_id: str) -> bool:
async def fetch(self, client_id_url: str) -> CIMDDocument:
"""Fetch and validate a CIMD document with SSRF protection.

Uses ssrf_safe_fetch for the HTTP layer, which provides:
Uses ssrf_safe_fetch_response for the HTTP layer, which provides:
- HTTPS only, DNS resolution with IP validation
- DNS pinning (connects to validated IP directly)
- Blocks private/loopback/link-local/multicast IPs
@@ -218,26 +305,76 @@ async def fetch(self, client_id_url: str) -> CIMDDocument:
CIMDFetchError: If document cannot be fetched
"""
cached = self._cache.get(client_id_url)
now = time.time()
request_headers: dict[str, str] | None = None
allowed_status_codes = {200}

if cached is not None:
doc, expires_at = cached
if time.time() < expires_at:
return doc
if not cached.must_revalidate and now < cached.expires_at:
return cached.doc

request_headers = {}
if cached.etag:
request_headers["If-None-Match"] = cached.etag
if cached.last_modified:
request_headers["If-Modified-Since"] = cached.last_modified
if request_headers:
allowed_status_codes = {200, 304}

try:
content = await ssrf_safe_fetch(
response = await ssrf_safe_fetch_response(
client_id_url,
require_path=True,
max_size=self.MAX_RESPONSE_SIZE,
timeout=self.timeout,
overall_timeout=30.0,
request_headers=request_headers,
allowed_status_codes=allowed_status_codes,
)
except SSRFError as e:
raise CIMDValidationError(str(e)) from e
except SSRFFetchError as e:
raise CIMDFetchError(str(e)) from e

if response.status_code == 304:
if cached is None:
raise CIMDFetchError(
"CIMD server returned 304 Not Modified without cached document"
)

now = time.time()
if self._has_freshness_headers(response.headers):
policy = self._parse_cache_policy(response.headers, now)
else:
# RFC allows 304 to omit unchanged headers. Preserve existing
# cache policy rather than resetting to fallback defaults.
policy = _CIMDCachePolicy(
etag=None,
last_modified=None,
expires_at=now + cached.freshness_lifetime,
freshness_lifetime=cached.freshness_lifetime,
no_store=False,
must_revalidate=cached.must_revalidate,
)

if not policy.no_store:
self._cache[client_id_url] = _CIMDCacheEntry(
doc=cached.doc,
etag=policy.etag or cached.etag,
last_modified=policy.last_modified or cached.last_modified,
expires_at=policy.expires_at,
freshness_lifetime=policy.freshness_lifetime,
must_revalidate=policy.must_revalidate,
)
else:
self._cache.pop(client_id_url, None)
return cached.doc

now = time.time()
policy = self._parse_cache_policy(response.headers, now)

try:
data = json.loads(content)
data = json.loads(response.content)
except json.JSONDecodeError as e:
raise CIMDValidationError(f"CIMD document is not valid JSON: {e}") from e

@@ -268,7 +405,18 @@ async def fetch(self, client_id_url: str) -> CIMDDocument:
doc.client_name,
)

self._cache[client_id_url] = (doc, time.time() + self.DEFAULT_CACHE_TTL_SECONDS)
if not policy.no_store:
self._cache[client_id_url] = _CIMDCacheEntry(
doc=doc,
etag=policy.etag,
last_modified=policy.last_modified,
expires_at=policy.expires_at,
freshness_lifetime=policy.freshness_lifetime,
must_revalidate=policy.must_revalidate,
)
else:
self._cache.pop(client_id_url, None)

return doc

def validate_redirect_uri(self, doc: CIMDDocument, redirect_uri: str) -> bool:
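As a usage note: a minimal sketch of the revised cache flow in `CIMDFetcher`, assuming a reachable, publicly resolvable CIMD URL (the URL below is hypothetical):

```python
import asyncio

from fastmcp.server.auth.cimd import CIMDFetcher


async def main() -> None:
    fetcher = CIMDFetcher(timeout=10.0)

    # First call fetches over the network; the response's Cache-Control /
    # ETag / Last-Modified headers determine how the document is cached.
    doc = await fetcher.fetch("https://client.example.com/oauth/client.json")
    print(doc.client_name)

    # While the entry is fresh (and the response was not no-cache/no-store),
    # this returns the cached document with no network I/O. After expiry the
    # fetcher revalidates with If-None-Match / If-Modified-Since, and a 304
    # simply refreshes the cache entry.
    await fetcher.fetch("https://client.example.com/oauth/client.json")


asyncio.run(main())
```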
19 changes: 17 additions & 2 deletions src/fastmcp/server/auth/oauth_proxy/models.py
@@ -181,12 +181,27 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
"redirect_uri must be specified when CIMD redirect_uris uses wildcards."
)
try:
return AnyUrl(candidate)
resolved = AnyUrl(candidate)
except Exception as e:
raise InvalidRedirectUriError(
f"Invalid CIMD redirect_uri: {e}"
) from e

# Respect proxy-level redirect URI restrictions even when the
# client omits redirect_uri and we fall back to CIMD defaults.
if (
self.allowed_redirect_uri_patterns is not None
and not validate_redirect_uri(
redirect_uri=resolved,
allowed_patterns=self.allowed_redirect_uri_patterns,
)
):
raise InvalidRedirectUriError(
f"Redirect URI '{resolved}' does not match allowed patterns."
)

return resolved

raise InvalidRedirectUriError(
"redirect_uri must be specified when CIMD lists multiple redirect_uris."
)
@@ -207,7 +222,7 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
f"Redirect URI '{redirect_uri}' does not match CIMD redirect_uris."
)

if self.allowed_redirect_uri_patterns:
if self.allowed_redirect_uri_patterns is not None:
if not validate_redirect_uri(
redirect_uri=redirect_uri,
allowed_patterns=self.allowed_redirect_uri_patterns,
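The guard added above reuses the proxy's existing `validate_redirect_uri` helper. Purely as an illustration of the intended effect, here is a hypothetical stand-in assuming fnmatch-style wildcard patterns (not fastmcp's actual implementation):

```python
import fnmatch


def matches_allowed_patterns(redirect_uri: str, allowed_patterns: list[str]) -> bool:
    """Hypothetical stand-in for the proxy-level redirect URI check."""
    return any(fnmatch.fnmatch(redirect_uri, pattern) for pattern in allowed_patterns)


# With the fix, a CIMD-derived default redirect URI is subject to the same
# proxy-level restriction as a client-supplied one:
assert matches_allowed_patterns(
    "https://app.example.com/callback", ["https://*.example.com/*"]
)
assert not matches_allowed_patterns(
    "https://evil.test/callback", ["https://*.example.com/*"]
)
```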
55 changes: 52 additions & 3 deletions src/fastmcp/server/auth/ssrf.py
@@ -12,6 +12,7 @@
import ipaddress
import socket
import time
from collections.abc import Mapping
from dataclasses import dataclass
from urllib.parse import urlparse

@@ -134,6 +135,15 @@ class ValidatedURL:
resolved_ips: list[str]


@dataclass
class SSRFFetchResponse:
"""Response payload from an SSRF-safe fetch."""

content: bytes
status_code: int
headers: dict[str, str]


async def validate_url(url: str, require_path: bool = False) -> ValidatedURL:
"""Validate URL for SSRF and resolve to IPs.

@@ -215,12 +225,39 @@ async def ssrf_safe_fetch(
SSRFError: If SSRF validation fails
SSRFFetchError: If fetch fails
"""
response = await ssrf_safe_fetch_response(
url,
require_path=require_path,
max_size=max_size,
timeout=timeout,
overall_timeout=overall_timeout,
allowed_status_codes={200},
)
return response.content


async def ssrf_safe_fetch_response(
url: str,
*,
require_path: bool = False,
max_size: int = 5120,
timeout: float = 10.0,
overall_timeout: float = 30.0,
request_headers: Mapping[str, str] | None = None,
allowed_status_codes: set[int] | None = None,
) -> SSRFFetchResponse:
"""Fetch URL with SSRF protection and return response metadata.

This is equivalent to :func:`ssrf_safe_fetch` but returns response headers
and status code, and supports conditional request headers.
"""
start_time = time.monotonic()

# Validate URL and resolve DNS
validated = await validate_url(url, require_path=require_path)

last_error: Exception | None = None
expected_statuses = allowed_status_codes or {200}

for pinned_ip in validated.resolved_ips:
elapsed = time.monotonic() - start_time
@@ -239,6 +276,14 @@
pinned_ip,
)

headers = {"Host": validated.hostname}
if request_headers:
for key, value in request_headers.items():
# Host must remain pinned to the validated hostname.
if key.lower() == "host":
continue
headers[key] = value

try:
# Use httpx with streaming to enforce size limit during download
async with (
@@ -255,14 +300,14 @@
client.stream(
"GET",
pinned_url,
headers={"Host": validated.hostname},
headers=headers,
extensions={"sni_hostname": validated.hostname},
) as response,
):
if time.monotonic() - start_time > overall_timeout:
raise SSRFFetchError(f"Overall timeout exceeded: {url}")

if response.status_code != 200:
if response.status_code not in expected_statuses:
raise SSRFFetchError(f"HTTP {response.status_code} fetching {url}")

# Check Content-Length header first if available
@@ -290,7 +335,11 @@
)
chunks.append(chunk)

return b"".join(chunks)
return SSRFFetchResponse(
content=b"".join(chunks),
status_code=response.status_code,
headers=dict(response.headers),
)

except httpx.TimeoutException as e:
last_error = e
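A sketch of a conditional revalidation call against the new function, mirroring how `CIMDFetcher` uses it; the URL, ETag value, and size limit below are placeholder choices:

```python
import asyncio

from fastmcp.server.auth.ssrf import ssrf_safe_fetch_response


async def revalidate(url: str, etag: str) -> None:
    response = await ssrf_safe_fetch_response(
        url,
        require_path=True,
        max_size=64 * 1024,
        timeout=10.0,
        overall_timeout=30.0,
        # Conditional request headers; a caller-supplied Host header would be
        # silently dropped, since Host stays pinned to the validated hostname.
        request_headers={"If-None-Match": etag},
        allowed_status_codes={200, 304},
    )
    if response.status_code == 304:
        print("Document unchanged; reuse the cached copy")
    else:
        print(f"Fetched {len(response.content)} bytes")


asyncio.run(revalidate("https://client.example.com/oauth/client.json", '"abc123"'))
```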