diff --git a/docs/servers/auth/oauth-proxy.mdx b/docs/servers/auth/oauth-proxy.mdx index 86a8865f30..fea240fada 100644 --- a/docs/servers/auth/oauth-proxy.mdx +++ b/docs/servers/auth/oauth-proxy.mdx @@ -583,7 +583,7 @@ CIMD provides several security advantages over DCR: - **Replay prevention**: For `private_key_jwt` clients, JTI claims are tracked to prevent assertion replay - **Cache-aware fetching**: CIMD documents are cached according to HTTP cache headers and revalidated when required -To disable CIMD support entirely (for example, to require all clients to register via DCR): +CIMD is enabled by default. To disable it entirely (for example, to require all clients to register via DCR), set `enable_cimd=False` explicitly: ```python auth = OAuthProxy( diff --git a/src/fastmcp/server/auth/cimd.py b/src/fastmcp/server/auth/cimd.py index 49aa6687cb..caef56f966 100644 --- a/src/fastmcp/server/auth/cimd.py +++ b/src/fastmcp/server/auth/cimd.py @@ -19,6 +19,10 @@ import fnmatch import json import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import timezone +from email.utils import parsedate_to_datetime from typing import TYPE_CHECKING, Any, Literal from urllib.parse import urlparse @@ -27,7 +31,7 @@ from fastmcp.server.auth.ssrf import ( SSRFError, SSRFFetchError, - ssrf_safe_fetch, + ssrf_safe_fetch_response, validate_url, ) from fastmcp.utilities.logging import get_logger @@ -155,12 +159,37 @@ class CIMDFetchError(Exception): """Raised when CIMD document fetching fails.""" +@dataclass +class _CIMDCacheEntry: + """Cached CIMD document and associated HTTP cache metadata.""" + + doc: CIMDDocument + etag: str | None + last_modified: str | None + expires_at: float + freshness_lifetime: float + must_revalidate: bool + + +@dataclass +class _CIMDCachePolicy: + """Normalized cache directives parsed from HTTP response headers.""" + + etag: str | None + last_modified: str | None + expires_at: float + freshness_lifetime: float + no_store: bool + must_revalidate: bool + + class CIMDFetcher: """Fetch and validate CIMD documents with SSRF protection. - Delegates HTTP fetching to ssrf_safe_fetch which provides DNS pinning, - IP validation, size limits, and timeout enforcement. Documents are cached - with a simple TTL. + Delegates HTTP fetching to ssrf_safe_fetch_response, which provides DNS + pinning, IP validation, size limits, and timeout enforcement. Documents are + cached using HTTP caching semantics (Cache-Control/ETag/Last-Modified), with + a TTL fallback when response headers do not define caching behavior. """ # Maximum response size (bytes) @@ -178,7 +207,65 @@ def __init__( timeout: HTTP request timeout in seconds (default 10.0) """ self.timeout = timeout - self._cache: dict[str, tuple[CIMDDocument, float]] = {} + self._cache: dict[str, _CIMDCacheEntry] = {} + + def _parse_cache_policy( + self, headers: Mapping[str, str], now: float + ) -> _CIMDCachePolicy: + """Parse HTTP cache headers and derive cache behavior.""" + normalized = {k.lower(): v for k, v in headers.items()} + cache_control = normalized.get("cache-control", "") + directives = { + part.strip().lower() for part in cache_control.split(",") if part.strip() + } + + no_store = "no-store" in directives + must_revalidate = "no-cache" in directives + max_age: int | None = None + + for directive in directives: + if directive.startswith("max-age="): + value = directive.removeprefix("max-age=").strip() + try: + max_age = max(0, int(value)) + except ValueError: + logger.debug( + "Ignoring invalid Cache-Control max-age value: %s", value + ) + break + + expires_at: float | None = None + if max_age is not None: + expires_at = now + max_age + elif "expires" in normalized: + try: + dt = parsedate_to_datetime(normalized["expires"]) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + expires_at = dt.timestamp() + except (TypeError, ValueError): + logger.debug( + "Ignoring invalid Expires header on CIMD response: %s", + normalized["expires"], + ) + + if expires_at is None: + expires_at = now + self.DEFAULT_CACHE_TTL_SECONDS + freshness_lifetime = max(0.0, expires_at - now) + + return _CIMDCachePolicy( + etag=normalized.get("etag"), + last_modified=normalized.get("last-modified"), + expires_at=expires_at, + freshness_lifetime=freshness_lifetime, + no_store=no_store, + must_revalidate=must_revalidate, + ) + + def _has_freshness_headers(self, headers: Mapping[str, str]) -> bool: + """Return True when response includes cache freshness directives.""" + normalized = {k.lower() for k in headers} + return "cache-control" in normalized or "expires" in normalized def is_cimd_client_id(self, client_id: str) -> bool: """Check if a client_id looks like a CIMD URL. @@ -200,7 +287,7 @@ def is_cimd_client_id(self, client_id: str) -> bool: async def fetch(self, client_id_url: str) -> CIMDDocument: """Fetch and validate a CIMD document with SSRF protection. - Uses ssrf_safe_fetch for the HTTP layer, which provides: + Uses ssrf_safe_fetch_response for the HTTP layer, which provides: - HTTPS only, DNS resolution with IP validation - DNS pinning (connects to validated IP directly) - Blocks private/loopback/link-local/multicast IPs @@ -218,26 +305,76 @@ async def fetch(self, client_id_url: str) -> CIMDDocument: CIMDFetchError: If document cannot be fetched """ cached = self._cache.get(client_id_url) + now = time.time() + request_headers: dict[str, str] | None = None + allowed_status_codes = {200} + if cached is not None: - doc, expires_at = cached - if time.time() < expires_at: - return doc + if not cached.must_revalidate and now < cached.expires_at: + return cached.doc + + request_headers = {} + if cached.etag: + request_headers["If-None-Match"] = cached.etag + if cached.last_modified: + request_headers["If-Modified-Since"] = cached.last_modified + if request_headers: + allowed_status_codes = {200, 304} try: - content = await ssrf_safe_fetch( + response = await ssrf_safe_fetch_response( client_id_url, require_path=True, max_size=self.MAX_RESPONSE_SIZE, timeout=self.timeout, overall_timeout=30.0, + request_headers=request_headers, + allowed_status_codes=allowed_status_codes, ) except SSRFError as e: raise CIMDValidationError(str(e)) from e except SSRFFetchError as e: raise CIMDFetchError(str(e)) from e + if response.status_code == 304: + if cached is None: + raise CIMDFetchError( + "CIMD server returned 304 Not Modified without cached document" + ) + + now = time.time() + if self._has_freshness_headers(response.headers): + policy = self._parse_cache_policy(response.headers, now) + else: + # RFC allows 304 to omit unchanged headers. Preserve existing + # cache policy rather than resetting to fallback defaults. + policy = _CIMDCachePolicy( + etag=None, + last_modified=None, + expires_at=now + cached.freshness_lifetime, + freshness_lifetime=cached.freshness_lifetime, + no_store=False, + must_revalidate=cached.must_revalidate, + ) + + if not policy.no_store: + self._cache[client_id_url] = _CIMDCacheEntry( + doc=cached.doc, + etag=policy.etag or cached.etag, + last_modified=policy.last_modified or cached.last_modified, + expires_at=policy.expires_at, + freshness_lifetime=policy.freshness_lifetime, + must_revalidate=policy.must_revalidate, + ) + else: + self._cache.pop(client_id_url, None) + return cached.doc + + now = time.time() + policy = self._parse_cache_policy(response.headers, now) + try: - data = json.loads(content) + data = json.loads(response.content) except json.JSONDecodeError as e: raise CIMDValidationError(f"CIMD document is not valid JSON: {e}") from e @@ -268,7 +405,18 @@ async def fetch(self, client_id_url: str) -> CIMDDocument: doc.client_name, ) - self._cache[client_id_url] = (doc, time.time() + self.DEFAULT_CACHE_TTL_SECONDS) + if not policy.no_store: + self._cache[client_id_url] = _CIMDCacheEntry( + doc=doc, + etag=policy.etag, + last_modified=policy.last_modified, + expires_at=policy.expires_at, + freshness_lifetime=policy.freshness_lifetime, + must_revalidate=policy.must_revalidate, + ) + else: + self._cache.pop(client_id_url, None) + return doc def validate_redirect_uri(self, doc: CIMDDocument, redirect_uri: str) -> bool: diff --git a/src/fastmcp/server/auth/oauth_proxy/models.py b/src/fastmcp/server/auth/oauth_proxy/models.py index 575c846baa..7525b6a0bf 100644 --- a/src/fastmcp/server/auth/oauth_proxy/models.py +++ b/src/fastmcp/server/auth/oauth_proxy/models.py @@ -181,12 +181,27 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: "redirect_uri must be specified when CIMD redirect_uris uses wildcards." ) try: - return AnyUrl(candidate) + resolved = AnyUrl(candidate) except Exception as e: raise InvalidRedirectUriError( f"Invalid CIMD redirect_uri: {e}" ) from e + # Respect proxy-level redirect URI restrictions even when the + # client omits redirect_uri and we fall back to CIMD defaults. + if ( + self.allowed_redirect_uri_patterns is not None + and not validate_redirect_uri( + redirect_uri=resolved, + allowed_patterns=self.allowed_redirect_uri_patterns, + ) + ): + raise InvalidRedirectUriError( + f"Redirect URI '{resolved}' does not match allowed patterns." + ) + + return resolved + raise InvalidRedirectUriError( "redirect_uri must be specified when CIMD lists multiple redirect_uris." ) @@ -207,7 +222,7 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: f"Redirect URI '{redirect_uri}' does not match CIMD redirect_uris." ) - if self.allowed_redirect_uri_patterns: + if self.allowed_redirect_uri_patterns is not None: if not validate_redirect_uri( redirect_uri=redirect_uri, allowed_patterns=self.allowed_redirect_uri_patterns, diff --git a/src/fastmcp/server/auth/ssrf.py b/src/fastmcp/server/auth/ssrf.py index 8009269c6d..39c28e9591 100644 --- a/src/fastmcp/server/auth/ssrf.py +++ b/src/fastmcp/server/auth/ssrf.py @@ -12,6 +12,7 @@ import ipaddress import socket import time +from collections.abc import Mapping from dataclasses import dataclass from urllib.parse import urlparse @@ -134,6 +135,15 @@ class ValidatedURL: resolved_ips: list[str] +@dataclass +class SSRFFetchResponse: + """Response payload from an SSRF-safe fetch.""" + + content: bytes + status_code: int + headers: dict[str, str] + + async def validate_url(url: str, require_path: bool = False) -> ValidatedURL: """Validate URL for SSRF and resolve to IPs. @@ -215,12 +225,39 @@ async def ssrf_safe_fetch( SSRFError: If SSRF validation fails SSRFFetchError: If fetch fails """ + response = await ssrf_safe_fetch_response( + url, + require_path=require_path, + max_size=max_size, + timeout=timeout, + overall_timeout=overall_timeout, + allowed_status_codes={200}, + ) + return response.content + + +async def ssrf_safe_fetch_response( + url: str, + *, + require_path: bool = False, + max_size: int = 5120, + timeout: float = 10.0, + overall_timeout: float = 30.0, + request_headers: Mapping[str, str] | None = None, + allowed_status_codes: set[int] | None = None, +) -> SSRFFetchResponse: + """Fetch URL with SSRF protection and return response metadata. + + This is equivalent to :func:`ssrf_safe_fetch` but returns response headers + and status code, and supports conditional request headers. + """ start_time = time.monotonic() # Validate URL and resolve DNS validated = await validate_url(url, require_path=require_path) last_error: Exception | None = None + expected_statuses = allowed_status_codes or {200} for pinned_ip in validated.resolved_ips: elapsed = time.monotonic() - start_time @@ -239,6 +276,14 @@ async def ssrf_safe_fetch( pinned_ip, ) + headers = {"Host": validated.hostname} + if request_headers: + for key, value in request_headers.items(): + # Host must remain pinned to the validated hostname. + if key.lower() == "host": + continue + headers[key] = value + try: # Use httpx with streaming to enforce size limit during download async with ( @@ -255,14 +300,14 @@ async def ssrf_safe_fetch( client.stream( "GET", pinned_url, - headers={"Host": validated.hostname}, + headers=headers, extensions={"sni_hostname": validated.hostname}, ) as response, ): if time.monotonic() - start_time > overall_timeout: raise SSRFFetchError(f"Overall timeout exceeded: {url}") - if response.status_code != 200: + if response.status_code not in expected_statuses: raise SSRFFetchError(f"HTTP {response.status_code} fetching {url}") # Check Content-Length header first if available @@ -290,7 +335,11 @@ async def ssrf_safe_fetch( ) chunks.append(chunk) - return b"".join(chunks) + return SSRFFetchResponse( + content=b"".join(chunks), + status_code=response.status_code, + headers=dict(response.headers), + ) except httpx.TimeoutException as e: last_error = e diff --git a/tests/server/auth/test_cimd.py b/tests/server/auth/test_cimd.py index d3c3e316e2..111d863c7c 100644 --- a/tests/server/auth/test_cimd.py +++ b/tests/server/auth/test_cimd.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from unittest.mock import AsyncMock, patch import pytest @@ -247,6 +248,243 @@ async def test_fetch_ttl_cache(self, fetcher: CIMDFetcher, httpx_mock, mock_dns) assert first.client_id == second.client_id assert len(httpx_mock.get_requests()) == 1 + async def test_fetch_cache_control_max_age( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Cache-Control max-age should prevent refetch before expiry.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Max-Age App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"cache-control": "max-age=60", "content-length": "200"}, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + + assert first.client_name == second.client_name + assert len(httpx_mock.get_requests()) == 1 + + async def test_fetch_etag_revalidation_304( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Expired cache should revalidate with ETag and accept 304.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "ETag App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={ + "cache-control": "max-age=0", + "etag": '"v1"', + "content-length": "200", + }, + ) + httpx_mock.add_response( + status_code=304, + headers={ + "cache-control": "max-age=120", + "etag": '"v1"', + "content-length": "0", + }, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + requests = httpx_mock.get_requests() + + assert first.client_name == "ETag App" + assert second.client_name == "ETag App" + assert len(requests) == 2 + assert requests[1].headers.get("if-none-match") == '"v1"' + + async def test_fetch_last_modified_revalidation_304( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Expired cache should revalidate with Last-Modified and accept 304.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Last-Modified App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + last_modified = "Wed, 21 Oct 2015 07:28:00 GMT" + httpx_mock.add_response( + json=doc_data, + headers={ + "cache-control": "max-age=0", + "last-modified": last_modified, + "content-length": "200", + }, + ) + httpx_mock.add_response( + status_code=304, + headers={"cache-control": "max-age=120", "content-length": "0"}, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + requests = httpx_mock.get_requests() + + assert first.client_name == "Last-Modified App" + assert second.client_name == "Last-Modified App" + assert len(requests) == 2 + assert requests[1].headers.get("if-modified-since") == last_modified + + async def test_fetch_cache_control_no_store( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Cache-Control no-store should prevent storing CIMD documents.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "No-Store App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={"cache-control": "no-store", "content-length": "200"}, + ) + httpx_mock.add_response( + json=doc_data, + headers={"cache-control": "no-store", "content-length": "200"}, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + + assert first.client_name == second.client_name + assert len(httpx_mock.get_requests()) == 2 + + async def test_fetch_cache_control_no_cache( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """Cache-Control no-cache should force revalidation on each fetch.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "No-Cache App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={ + "cache-control": "no-cache", + "etag": '"v2"', + "content-length": "200", + }, + ) + httpx_mock.add_response( + status_code=304, + headers={ + "cache-control": "no-cache", + "etag": '"v2"', + "content-length": "0", + }, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + requests = httpx_mock.get_requests() + + assert first.client_name == "No-Cache App" + assert second.client_name == "No-Cache App" + assert len(requests) == 2 + assert requests[1].headers.get("if-none-match") == '"v2"' + + async def test_fetch_304_without_cache_headers_preserves_policy( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """304 responses without cache headers should not reset cached policy.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "No-Header-304 App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={ + "cache-control": "no-cache", + "etag": '"v3"', + "content-length": "200", + }, + ) + # Intentionally omit cache-control/expires on 304. + httpx_mock.add_response( + status_code=304, + headers={"content-length": "0"}, + ) + httpx_mock.add_response( + status_code=304, + headers={"content-length": "0"}, + ) + + first = await fetcher.fetch(url) + second = await fetcher.fetch(url) + third = await fetcher.fetch(url) + requests = httpx_mock.get_requests() + + assert first.client_name == "No-Header-304 App" + assert second.client_name == "No-Header-304 App" + assert third.client_name == "No-Header-304 App" + assert len(requests) == 3 + assert requests[1].headers.get("if-none-match") == '"v3"' + assert requests[2].headers.get("if-none-match") == '"v3"' + + async def test_fetch_304_without_cache_headers_refreshes_cached_freshness( + self, fetcher: CIMDFetcher, httpx_mock, mock_dns + ): + """A header-less 304 should renew freshness using cached lifetime.""" + url = "https://example.com/client.json" + doc_data = { + "client_id": url, + "client_name": "Headerless 304 Freshness App", + "redirect_uris": ["http://localhost:3000/callback"], + "token_endpoint_auth_method": "none", + } + httpx_mock.add_response( + json=doc_data, + headers={ + "cache-control": "max-age=60", + "etag": '"v4"', + "content-length": "200", + }, + ) + httpx_mock.add_response( + status_code=304, + headers={"content-length": "0"}, + ) + + first = await fetcher.fetch(url) + + # Simulate cache expiry so the next request triggers revalidation. + cached_entry = fetcher._cache[url] + cached_entry.expires_at = time.time() - 1 + + second = await fetcher.fetch(url) + third = await fetcher.fetch(url) + requests = httpx_mock.get_requests() + + assert first.client_name == "Headerless 304 Freshness App" + assert second.client_name == "Headerless 304 Freshness App" + assert third.client_name == "Headerless 304 Freshness App" + assert len(requests) == 2 + assert requests[1].headers.get("if-none-match") == '"v4"' + async def test_fetch_client_id_mismatch( self, fetcher: CIMDFetcher, httpx_mock, mock_dns ): diff --git a/tests/server/auth/test_oauth_proxy_redirect_validation.py b/tests/server/auth/test_oauth_proxy_redirect_validation.py index 391977b886..47ecfbe8d1 100644 --- a/tests/server/auth/test_oauth_proxy_redirect_validation.py +++ b/tests/server/auth/test_oauth_proxy_redirect_validation.py @@ -155,6 +155,23 @@ def test_cimd_none_redirect_uri_single_exact(self): result = client.validate_redirect_uri(None) assert result == AnyUrl("http://localhost:3000/callback") + def test_cimd_none_redirect_uri_respects_proxy_patterns(self): + """CIMD fallback redirect_uri must still satisfy proxy allowlist patterns.""" + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["https://evil.com/callback"], + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + allowed_redirect_uri_patterns=["http://localhost:*"], + ) + + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(None) + def test_cimd_none_redirect_uri_wildcard_rejected(self): """CIMD clients must specify redirect_uri when only wildcard patterns exist.""" cimd_doc = CIMDDocument( @@ -171,6 +188,23 @@ def test_cimd_none_redirect_uri_wildcard_rejected(self): with pytest.raises(InvalidRedirectUriError): client.validate_redirect_uri(None) + def test_cimd_empty_proxy_allowlist_rejects_redirect_uri(self): + """An explicit empty proxy allowlist should reject all CIMD redirect URIs.""" + cimd_doc = CIMDDocument( + client_id=AnyHttpUrl("https://example.com/client.json"), + redirect_uris=["http://localhost:3000/callback"], + ) + client = ProxyDCRClient( + client_id="https://example.com/client.json", + client_secret=None, + redirect_uris=None, + cimd_document=cimd_doc, + allowed_redirect_uri_patterns=[], + ) + + with pytest.raises(InvalidRedirectUriError): + client.validate_redirect_uri(AnyUrl("http://localhost:3000/callback")) + class TestOAuthProxyRedirectValidation: """Test OAuth proxy with redirect URI validation."""