From b1aa57412259297da5e13fa75e98814f0c7d80e9 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:56:12 -0400 Subject: [PATCH 1/3] feat: add TokenCache utility and caching to GitHubTokenVerifier Extract the caching machinery from IntrospectionTokenVerifier into a shared TokenCache class in fastmcp.utilities.token_cache, then wire it into both IntrospectionTokenVerifier and GitHubTokenVerifier. --- src/fastmcp/server/auth/providers/github.py | 32 ++- .../server/auth/providers/introspection.py | 108 +--------- src/fastmcp/utilities/token_cache.py | 166 ++++++++++++++ tests/server/auth/providers/test_github.py | 142 +++++++++++- .../auth/providers/test_introspection.py | 33 ++- .../server/auth/providers/test_propelauth.py | 4 +- tests/utilities/test_token_cache.py | 203 ++++++++++++++++++ 7 files changed, 563 insertions(+), 125 deletions(-) create mode 100644 src/fastmcp/utilities/token_cache.py create mode 100644 tests/utilities/test_token_cache.py diff --git a/src/fastmcp/server/auth/providers/github.py b/src/fastmcp/server/auth/providers/github.py index 303d97ddd6..4127d41af8 100644 --- a/src/fastmcp/server/auth/providers/github.py +++ b/src/fastmcp/server/auth/providers/github.py @@ -33,6 +33,7 @@ from fastmcp.server.auth.oauth_proxy import OAuthProxy from fastmcp.utilities.auth import parse_scopes from fastmcp.utilities.logging import get_logger +from fastmcp.utilities.token_cache import TokenCache logger = get_logger(__name__) @@ -42,6 +43,10 @@ class GitHubTokenVerifier(TokenVerifier): GitHub OAuth tokens are opaque (not JWTs), so we verify them by calling GitHub's API to check if they're valid and get user info. + + Caching is disabled by default. Set ``cache_ttl_seconds`` to a positive + integer to cache successful verification results and avoid repeated + GitHub API calls for the same token. """ def __init__( @@ -49,6 +54,8 @@ def __init__( *, required_scopes: list[str] | None = None, timeout_seconds: int = 10, + cache_ttl_seconds: int | None = None, + max_cache_size: int | None = None, http_client: httpx.AsyncClient | None = None, ): """Initialize the GitHub token verifier. @@ -56,6 +63,10 @@ def __init__( Args: required_scopes: Required OAuth scopes (e.g., ['user:email']) timeout_seconds: HTTP request timeout + cache_ttl_seconds: How long to cache verification results in seconds. + Caching is disabled by default (None). Set to a positive integer + to enable (e.g., 300 for 5 minutes). + max_cache_size: Maximum number of tokens to cache. Default: 10 000. http_client: Optional httpx.AsyncClient for connection pooling. When provided, the client is reused across calls and the caller is responsible for its lifecycle. When None (default), a fresh client is created per call. @@ -63,9 +74,18 @@ def __init__( super().__init__(required_scopes=required_scopes) self.timeout_seconds = timeout_seconds self._http_client = http_client + self._cache = TokenCache( + ttl_seconds=cache_ttl_seconds, + max_size=max_cache_size, + ) async def verify_token(self, token: str) -> AccessToken | None: """Verify GitHub OAuth token by calling GitHub API.""" + is_cached, cached_result = self._cache.get(token) + if is_cached: + logger.debug("GitHub token cache hit") + return cached_result + try: async with ( contextlib.nullcontext(self._http_client) @@ -128,7 +148,7 @@ async def verify_token(self, token: str) -> AccessToken | None: return None # Create AccessToken with GitHub user info - return AccessToken( + result = AccessToken( token=token, client_id=str(user_data.get("id", "unknown")), # Use GitHub user ID scopes=token_scopes, @@ -142,6 +162,8 @@ async def verify_token(self, token: str) -> AccessToken | None: "github_user_data": user_data, }, ) + self._cache.set(token, result) + return result except httpx.RequestError as e: logger.debug("Failed to verify GitHub token: %s", e) @@ -189,6 +211,8 @@ def __init__( redirect_path: str | None = None, required_scopes: list[str] | None = None, timeout_seconds: int = 10, + cache_ttl_seconds: int | None = None, + max_cache_size: int | None = None, allowed_client_redirect_uris: list[str] | None = None, client_storage: AsyncKeyValue | None = None, jwt_signing_key: str | bytes | None = None, @@ -207,6 +231,10 @@ def __init__( redirect_path: Redirect path configured in GitHub OAuth app (defaults to "/auth/callback") required_scopes: Required GitHub scopes (defaults to ["user"]) timeout_seconds: HTTP request timeout for GitHub API calls (defaults to 10) + cache_ttl_seconds: How long to cache token verification results in seconds. + Caching is disabled by default (None). Set to a positive integer to + enable (e.g., 300 for 5 minutes). + max_cache_size: Maximum number of tokens to cache. Default: 10 000. allowed_client_redirect_uris: List of allowed redirect URI patterns for MCP clients. If None (default), all URIs are allowed. If empty list, no URIs are allowed. client_storage: Storage backend for OAuth state (client registrations, encrypted tokens). @@ -234,6 +262,8 @@ def __init__( token_verifier = GitHubTokenVerifier( required_scopes=required_scopes_final, timeout_seconds=timeout_seconds, + cache_ttl_seconds=cache_ttl_seconds, + max_cache_size=max_cache_size, http_client=http_client, ) diff --git a/src/fastmcp/server/auth/providers/introspection.py b/src/fastmcp/server/auth/providers/introspection.py index 09e20abfde..bb49c539f4 100644 --- a/src/fastmcp/server/auth/providers/introspection.py +++ b/src/fastmcp/server/auth/providers/introspection.py @@ -25,9 +25,7 @@ import base64 import contextlib -import hashlib import time -from dataclasses import dataclass from typing import Any, Literal, get_args import httpx @@ -36,18 +34,11 @@ from fastmcp.server.auth import AccessToken, TokenVerifier from fastmcp.utilities.auth import parse_scopes from fastmcp.utilities.logging import get_logger +from fastmcp.utilities.token_cache import TokenCache logger = get_logger(__name__) -@dataclass -class _IntrospectionCacheEntry: - """Cached introspection result with expiration.""" - - result: AccessToken - expires_at: float - - ClientAuthMethod = Literal["client_secret_basic", "client_secret_post"] @@ -154,96 +145,9 @@ def __init__( self._http_client = http_client self.logger = get_logger(__name__) - # Cache configuration (None or 0 = disabled) - self._cache_ttl = cache_ttl_seconds or 0 - self._max_cache_size = ( - max_cache_size - if max_cache_size is not None - else self.DEFAULT_MAX_CACHE_SIZE - ) - self._cache: dict[str, _IntrospectionCacheEntry] = {} - self._last_cleanup = time.monotonic() - self._cleanup_interval = 60 # Cleanup every 60 seconds - - def _hash_token(self, token: str) -> str: - """Hash token for use as cache key. - - Using SHA-256 for memory efficiency (fixed 64-char hex digest - regardless of token length). - """ - return hashlib.sha256(token.encode("utf-8")).hexdigest() - - def _cleanup_expired_cache(self) -> None: - """Remove expired entries from cache.""" - now = time.time() - expired = [key for key, entry in self._cache.items() if entry.expires_at < now] - for key in expired: - del self._cache[key] - if expired: - self.logger.debug("Cleaned up %d expired cache entries", len(expired)) - - def _maybe_cleanup(self) -> None: - """Periodically cleanup expired entries to prevent unbounded growth.""" - now = time.monotonic() - if now - self._last_cleanup > self._cleanup_interval: - self._cleanup_expired_cache() - self._last_cleanup = now - - def _get_cached(self, token: str) -> tuple[bool, AccessToken | None]: - """Get cached introspection result. - - Returns: - Tuple of (is_cached, result): - - (True, AccessToken) if cached valid token - - (False, None) if not in cache or expired - """ - if self._cache_ttl <= 0 or self._max_cache_size <= 0: - return (False, None) # Caching disabled - - cache_key = self._hash_token(token) - entry = self._cache.get(cache_key) - - if entry is None: - return (False, None) # Not in cache - - if entry.expires_at < time.time(): - del self._cache[cache_key] - return (False, None) # Expired - - # Return a copy to prevent mutations from affecting cached value - return (True, entry.result.model_copy(deep=True)) - - def _set_cached(self, token: str, result: AccessToken) -> None: - """Cache a valid introspection result with TTL. - - Only successful validations are cached. Failures (inactive, expired, - missing scopes, errors) are never cached to avoid sticky false negatives. - """ - if self._cache_ttl <= 0 or self._max_cache_size <= 0: - return # Caching disabled - - # Periodic cleanup - self._maybe_cleanup() - - # Check cache size limit - if len(self._cache) >= self._max_cache_size: - self._cleanup_expired_cache() - # If still at limit after cleanup, evict oldest entry - if len(self._cache) >= self._max_cache_size: - oldest_key = next(iter(self._cache)) - del self._cache[oldest_key] - - cache_key = self._hash_token(token) - - # Use token's expiration if available and sooner than TTL - expires_at = time.time() + self._cache_ttl - if result.expires_at: - expires_at = min(expires_at, float(result.expires_at)) - - # Store a deep copy to prevent mutations from affecting cached value - self._cache[cache_key] = _IntrospectionCacheEntry( - result=result.model_copy(deep=True), - expires_at=expires_at, + self._cache = TokenCache( + ttl_seconds=cache_ttl_seconds, + max_size=max_cache_size, ) def _create_basic_auth_header(self) -> str: @@ -293,7 +197,7 @@ async def verify_token(self, token: str) -> AccessToken | None: AccessToken object if valid and active, None if invalid, inactive, or expired """ # Check cache first - is_cached, cached_result = self._get_cached(token) + is_cached, cached_result = self._cache.get(token) if is_cached: self.logger.debug("Token introspection cache hit") return cached_result @@ -388,7 +292,7 @@ async def verify_token(self, token: str) -> AccessToken | None: expires_at=int(exp) if exp else None, claims=introspection_data, # Store full response for extensibility ) - self._set_cached(token, result) + self._cache.set(token, result) return result except httpx.TimeoutException: diff --git a/src/fastmcp/utilities/token_cache.py b/src/fastmcp/utilities/token_cache.py new file mode 100644 index 0000000000..10c4f64a18 --- /dev/null +++ b/src/fastmcp/utilities/token_cache.py @@ -0,0 +1,166 @@ +"""In-memory cache for token verification results. + +Provides a generic TTL-based cache for ``AccessToken`` objects, designed to +reduce repeated network calls during opaque-token verification. Only +*successful* verifications should be cached; errors and failures must be +retried on every request. + +Example: + ```python + from fastmcp.utilities.token_cache import TokenCache + + cache = TokenCache(ttl_seconds=300, max_size=10000) + + # On cache miss, call the upstream verifier and store the result. + hit, token = cache.get(raw_token) + if not hit: + token = await _call_upstream(raw_token) + if token is not None: + cache.set(raw_token, token) + ``` +""" + +from __future__ import annotations + +import hashlib +import time +from dataclasses import dataclass + +from fastmcp.server.auth.auth import AccessToken +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + +DEFAULT_MAX_CACHE_SIZE = 10_000 +_CLEANUP_INTERVAL = 60 # seconds between periodic sweeps + + +@dataclass +class _CacheEntry: + """A cached token result with its absolute expiration timestamp.""" + + result: AccessToken + expires_at: float + + +class TokenCache: + """TTL-based in-memory cache for ``AccessToken`` objects. + + Features: + - SHA-256 hashed cache keys (fixed size, regardless of token length). + - Per-entry TTL that respects both the configured ``ttl_seconds`` and the + token's own ``expires_at`` claim (whichever is sooner). + - Bounded size with FIFO eviction when the cache is full. + - Periodic cleanup of expired entries to prevent unbounded growth. + - Defensive deep copies on both store and retrieve to prevent + callers from mutating cached values. + + Caching is disabled when ``ttl_seconds`` is ``None``, ``0``, or + negative, or when ``max_size`` is ``0`` or negative. + """ + + def __init__( + self, + *, + ttl_seconds: int | None = None, + max_size: int | None = None, + ) -> None: + """Initialise the cache. + + Args: + ttl_seconds: How long cached entries remain valid, in seconds. + ``None`` or ``0`` disables caching entirely. + max_size: Upper bound on the number of entries. When the limit is + reached, expired entries are swept first; if still full the + oldest entry is evicted. Defaults to 10 000. + """ + self._ttl = ttl_seconds or 0 + self._max_size = max_size if max_size is not None else DEFAULT_MAX_CACHE_SIZE + self._entries: dict[str, _CacheEntry] = {} + self._last_cleanup = time.monotonic() + + @property + def enabled(self) -> bool: + """Return whether caching is active.""" + return self._ttl > 0 and self._max_size > 0 + + # -- public API ---------------------------------------------------------- + + def get(self, token: str) -> tuple[bool, AccessToken | None]: + """Look up a cached verification result. + + Returns: + ``(True, AccessToken)`` on a cache hit, ``(False, None)`` on a miss + or when caching is disabled. The returned ``AccessToken`` is a deep + copy that is safe to mutate. + """ + if not self.enabled: + return (False, None) + + cache_key = self._hash_token(token) + entry = self._entries.get(cache_key) + + if entry is None: + return (False, None) + + if entry.expires_at < time.time(): + del self._entries[cache_key] + return (False, None) + + return (True, entry.result.model_copy(deep=True)) + + def set(self, token: str, result: AccessToken) -> None: + """Store a *successful* verification result. + + Only successful verifications should be cached. Failures (inactive + tokens, missing scopes, HTTP errors, timeouts) must **not** be cached + so that transient problems do not produce sticky false negatives. + """ + if not self.enabled: + return + + self._maybe_cleanup() + self._enforce_size_limit() + + cache_key = self._hash_token(token) + + expires_at = time.time() + self._ttl + if result.expires_at: + expires_at = min(expires_at, float(result.expires_at)) + + self._entries[cache_key] = _CacheEntry( + result=result.model_copy(deep=True), + expires_at=expires_at, + ) + + # -- internals ----------------------------------------------------------- + + @staticmethod + def _hash_token(token: str) -> str: + """Return the SHA-256 hex digest of *token*.""" + return hashlib.sha256(token.encode("utf-8")).hexdigest() + + def _cleanup_expired(self) -> None: + """Remove all entries whose TTL has elapsed.""" + now = time.time() + expired = [k for k, v in self._entries.items() if v.expires_at < now] + for key in expired: + del self._entries[key] + if expired: + logger.debug("Cleaned up %d expired cache entries", len(expired)) + + def _maybe_cleanup(self) -> None: + """Run ``_cleanup_expired`` at most once per cleanup interval.""" + now = time.monotonic() + if now - self._last_cleanup > _CLEANUP_INTERVAL: + self._cleanup_expired() + self._last_cleanup = now + + def _enforce_size_limit(self) -> None: + """Ensure there is room for at least one new entry.""" + if len(self._entries) < self._max_size: + return + self._cleanup_expired() + if len(self._entries) >= self._max_size: + oldest_key = next(iter(self._entries)) + del self._entries[oldest_key] diff --git a/tests/server/auth/providers/test_github.py b/tests/server/auth/providers/test_github.py index 05a4a72e24..848bb9a486 100644 --- a/tests/server/auth/providers/test_github.py +++ b/tests/server/auth/providers/test_github.py @@ -1,6 +1,6 @@ """Unit tests for GitHub OAuth provider.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from key_value.aio.stores.memory import MemoryStore @@ -100,8 +100,6 @@ async def test_verify_token_github_api_failure(self): async def test_verify_token_success(self): """Test successful token verification.""" - from unittest.mock import AsyncMock - verifier = GitHubTokenVerifier(required_scopes=["user"]) # Mock the httpx.AsyncClient directly @@ -139,3 +137,141 @@ async def test_verify_token_success(self): assert result.scopes == ["user", "repo"] assert result.claims["login"] == "testuser" assert result.claims["name"] == "Test User" + + +def _mock_github_success(mock_client: AsyncMock) -> None: + """Configure *mock_client* to return a successful GitHub user + scopes response.""" + user_response = MagicMock() + user_response.status_code = 200 + user_response.json.return_value = { + "id": 12345, + "login": "testuser", + "name": "Test User", + "email": "test@example.com", + "avatar_url": "https://github.com/testuser.png", + } + + scopes_response = MagicMock() + scopes_response.headers = {"x-oauth-scopes": "user,repo"} + + mock_client.get.side_effect = [user_response, scopes_response] + + +def _mock_github_failure(mock_client: AsyncMock) -> None: + """Configure *mock_client* to return a 401 GitHub response.""" + fail_response = MagicMock() + fail_response.status_code = 401 + fail_response.text = "Bad credentials" + mock_client.get.return_value = fail_response + + +class TestGitHubTokenVerifierCaching: + """Test caching behaviour on GitHubTokenVerifier.""" + + def test_cache_disabled_by_default(self): + verifier = GitHubTokenVerifier() + assert not verifier._cache.enabled + + def test_cache_enabled_with_ttl(self): + verifier = GitHubTokenVerifier(cache_ttl_seconds=300) + assert verifier._cache.enabled + + async def test_cache_hit_avoids_second_api_call(self): + verifier = GitHubTokenVerifier( + required_scopes=["user"], + cache_ttl_seconds=300, + ) + + mock_client = AsyncMock() + + with patch( + "fastmcp.server.auth.providers.github.httpx.AsyncClient" + ) as mock_cls: + mock_cls.return_value.__aenter__.return_value = mock_client + + _mock_github_success(mock_client) + result1 = await verifier.verify_token("tok-1") + assert result1 is not None + assert mock_client.get.call_count == 2 # /user + /user/repos + + result2 = await verifier.verify_token("tok-1") + assert result2 is not None + assert result2.client_id == result1.client_id + assert mock_client.get.call_count == 2 # no additional calls + + async def test_cache_disabled_makes_every_call(self): + verifier = GitHubTokenVerifier( + required_scopes=["user"], + cache_ttl_seconds=0, + ) + + mock_client = AsyncMock() + + with patch( + "fastmcp.server.auth.providers.github.httpx.AsyncClient" + ) as mock_cls: + mock_cls.return_value.__aenter__.return_value = mock_client + + _mock_github_success(mock_client) + await verifier.verify_token("tok-1") + assert mock_client.get.call_count == 2 + + _mock_github_success(mock_client) + await verifier.verify_token("tok-1") + assert mock_client.get.call_count == 4 + + async def test_failures_are_not_cached(self): + verifier = GitHubTokenVerifier(cache_ttl_seconds=300) + + mock_client = AsyncMock() + + with patch( + "fastmcp.server.auth.providers.github.httpx.AsyncClient" + ) as mock_cls: + mock_cls.return_value.__aenter__.return_value = mock_client + + _mock_github_failure(mock_client) + result1 = await verifier.verify_token("bad-tok") + assert result1 is None + + _mock_github_success(mock_client) + result2 = await verifier.verify_token("bad-tok") + assert result2 is not None + + async def test_cached_result_is_defensive_copy(self): + verifier = GitHubTokenVerifier( + required_scopes=["user"], + cache_ttl_seconds=300, + ) + + mock_client = AsyncMock() + + with patch( + "fastmcp.server.auth.providers.github.httpx.AsyncClient" + ) as mock_cls: + mock_cls.return_value.__aenter__.return_value = mock_client + + _mock_github_success(mock_client) + result1 = await verifier.verify_token("tok-1") + assert result1 is not None + result1.claims["login"] = "MUTATED" + + result2 = await verifier.verify_token("tok-1") + assert result2 is not None + assert result2.claims["login"] == "testuser" + + def test_provider_passes_cache_params(self, memory_storage: MemoryStore): + provider = GitHubProvider( + client_id="cid", + client_secret="csec", + base_url="https://example.com", + cache_ttl_seconds=120, + max_cache_size=500, + jwt_signing_key="test-secret", + client_storage=memory_storage, + ) + verifier = provider._token_validator + assert isinstance(verifier, GitHubTokenVerifier) + assert verifier._cache.enabled + assert verifier._cache._ttl == 120 + assert verifier._cache._max_size == 500 diff --git a/tests/server/auth/providers/test_introspection.py b/tests/server/auth/providers/test_introspection.py index 793570987d..27c032365a 100644 --- a/tests/server/auth/providers/test_introspection.py +++ b/tests/server/auth/providers/test_introspection.py @@ -554,8 +554,7 @@ def test_default_cache_settings(self): client_id="test-client", client_secret="test-secret", ) - assert verifier._cache_ttl == 0 # Disabled by default - assert verifier._max_cache_size == 10000 + assert not verifier._cache.enabled def test_custom_cache_settings(self): """Test that cache settings can be customized.""" @@ -566,8 +565,8 @@ def test_custom_cache_settings(self): cache_ttl_seconds=60, max_cache_size=500, ) - assert verifier._cache_ttl == 60 - assert verifier._max_cache_size == 500 + assert verifier._cache._ttl == 60 + assert verifier._cache._max_size == 500 def test_cache_disabled_with_zero_ttl(self): """Test that cache is disabled when TTL is 0 or None.""" @@ -578,7 +577,7 @@ def test_cache_disabled_with_zero_ttl(self): client_secret="test-secret", cache_ttl_seconds=0, ) - assert verifier._cache_ttl == 0 + assert not verifier._cache.enabled # Explicit None (same as default) verifier2 = IntrospectionTokenVerifier( @@ -587,7 +586,7 @@ def test_cache_disabled_with_zero_ttl(self): client_secret="test-secret", cache_ttl_seconds=None, ) - assert verifier2._cache_ttl == 0 + assert not verifier2._cache.enabled async def test_cache_disabled_with_zero_or_negative_max_size( self, httpx_mock: HTTPXMock @@ -854,9 +853,9 @@ async def test_timeout_errors_are_not_cached( def test_token_hashing(self, verifier_with_cache: IntrospectionTokenVerifier): """Test that tokens are hashed consistently.""" - hash1 = verifier_with_cache._hash_token("test-token") - hash2 = verifier_with_cache._hash_token("test-token") - hash3 = verifier_with_cache._hash_token("different-token") + hash1 = verifier_with_cache._cache._hash_token("test-token") + hash2 = verifier_with_cache._cache._hash_token("test-token") + hash3 = verifier_with_cache._cache._hash_token("different-token") # Same token produces same hash assert hash1 == hash2 @@ -885,8 +884,8 @@ async def test_cache_respects_token_expiration( await verifier_with_cache.verify_token("test-token") # Check that cache entry uses the shorter expiration - cache_key = verifier_with_cache._hash_token("test-token") - entry = verifier_with_cache._cache[cache_key] + cache_key = verifier_with_cache._cache._hash_token("test-token") + entry = verifier_with_cache._cache._entries[cache_key] # Cache expiration should be at or before token expiration assert entry.expires_at <= short_exp @@ -917,8 +916,8 @@ async def test_expired_cache_entry_triggers_new_introspection( assert len(httpx_mock.get_requests()) == 1 # Expire the cache entry manually - cache_key = verifier._hash_token("test-token") - verifier._cache[cache_key].expires_at = time.time() - 1 + cache_key = verifier._cache._hash_token("test-token") + verifier._cache._entries[cache_key].expires_at = time.time() - 1 # Second call — cache miss, new introspection await verifier.verify_token("test-token") @@ -944,15 +943,15 @@ async def test_cache_eviction_at_max_size(self, httpx_mock: HTTPXMock): # Fill cache to capacity await verifier.verify_token("token-0") await verifier.verify_token("token-1") - assert len(verifier._cache) == 2 + assert len(verifier._cache._entries) == 2 # Third token should evict the oldest entry await verifier.verify_token("token-2") - assert len(verifier._cache) == 2 + assert len(verifier._cache._entries) == 2 # token-0 should have been evicted (FIFO) - hash_0 = verifier._hash_token("token-0") - assert hash_0 not in verifier._cache + hash_0 = verifier._cache._hash_token("token-0") + assert hash_0 not in verifier._cache._entries class TestIntrospectionTokenVerifierIntegration: diff --git a/tests/server/auth/providers/test_propelauth.py b/tests/server/auth/providers/test_propelauth.py index 9d70496fc9..f06efe6857 100644 --- a/tests/server/auth/providers/test_propelauth.py +++ b/tests/server/auth/providers/test_propelauth.py @@ -133,8 +133,8 @@ def test_token_introspection_overrides_cache(self): ) assert isinstance(provider.token_verifier, IntrospectionTokenVerifier) - assert provider.token_verifier._cache_ttl == 300 - assert provider.token_verifier._max_cache_size == 500 + assert provider.token_verifier._cache._ttl == 300 + assert provider.token_verifier._cache._max_size == 500 def test_token_introspection_overrides_http_client(self): """Test that http_client override is passed to the verifier.""" diff --git a/tests/utilities/test_token_cache.py b/tests/utilities/test_token_cache.py new file mode 100644 index 0000000000..369338f77c --- /dev/null +++ b/tests/utilities/test_token_cache.py @@ -0,0 +1,203 @@ +"""Tests for the shared TokenCache utility.""" + +import time + +import pytest + +from fastmcp.server.auth.auth import AccessToken +from fastmcp.utilities.token_cache import TokenCache + + +def _make_token( + *, + token: str = "tok", + client_id: str = "client-1", + scopes: list[str] | None = None, + expires_at: int | None = None, +) -> AccessToken: + return AccessToken( + token=token, + client_id=client_id, + scopes=scopes or [], + expires_at=expires_at, + ) + + +class TestTokenCacheDisabled: + """Verify behaviour when caching is turned off.""" + + @pytest.mark.parametrize( + "ttl, max_size", + [ + (None, None), + (0, 100), + (-1, 100), + (300, 0), + (300, -1), + ], + ) + def test_disabled_configurations(self, ttl: int | None, max_size: int | None): + cache = TokenCache(ttl_seconds=ttl, max_size=max_size) + assert not cache.enabled + + def test_get_returns_miss_when_disabled(self): + cache = TokenCache(ttl_seconds=0) + cache.set("tok", _make_token()) + hit, result = cache.get("tok") + assert not hit + assert result is None + + def test_set_is_noop_when_disabled(self): + cache = TokenCache(ttl_seconds=0) + cache.set("tok", _make_token()) + assert len(cache._entries) == 0 + + +class TestTokenCacheEnabled: + """Core get/set behaviour with caching on.""" + + @pytest.fixture + def cache(self) -> TokenCache: + return TokenCache(ttl_seconds=300, max_size=100) + + def test_enabled(self, cache: TokenCache): + assert cache.enabled + + def test_set_and_get(self, cache: TokenCache): + access = _make_token(client_id="user-1") + cache.set("tok-1", access) + + hit, result = cache.get("tok-1") + assert hit + assert result is not None + assert result.client_id == "user-1" + + def test_miss_for_unknown_token(self, cache: TokenCache): + hit, result = cache.get("unknown") + assert not hit + assert result is None + + def test_different_tokens_cached_separately(self, cache: TokenCache): + cache.set("tok-a", _make_token(client_id="a")) + cache.set("tok-b", _make_token(client_id="b")) + + _, a = cache.get("tok-a") + _, b = cache.get("tok-b") + assert a is not None and a.client_id == "a" + assert b is not None and b.client_id == "b" + + +class TestTokenCacheDefensiveCopy: + """Mutating a returned token must not affect the cached value.""" + + def test_get_returns_deep_copy(self): + cache = TokenCache(ttl_seconds=300, max_size=100) + access = _make_token(client_id="orig") + access.claims = {"key": "original"} + cache.set("tok", access) + + _, first = cache.get("tok") + assert first is not None + first.claims["key"] = "mutated" + first.scopes.append("admin") + + _, second = cache.get("tok") + assert second is not None + assert second.claims["key"] == "original" + assert "admin" not in second.scopes + + def test_mutating_source_does_not_affect_cache(self): + cache = TokenCache(ttl_seconds=300, max_size=100) + access = _make_token(client_id="orig") + access.claims = {"key": "original"} + cache.set("tok", access) + + access.claims["key"] = "mutated" + + _, cached = cache.get("tok") + assert cached is not None + assert cached.claims["key"] == "original" + + +class TestTokenCacheTTL: + """Expiration and TTL behaviour.""" + + def test_expired_entry_is_evicted_on_get(self): + cache = TokenCache(ttl_seconds=300, max_size=100) + cache.set("tok", _make_token()) + + key = cache._hash_token("tok") + cache._entries[key].expires_at = time.time() - 1 + + hit, result = cache.get("tok") + assert not hit + assert result is None + assert key not in cache._entries + + def test_token_expires_at_caps_ttl(self): + cache = TokenCache(ttl_seconds=300, max_size=100) + short_exp = int(time.time()) + 30 + cache.set("tok", _make_token(expires_at=short_exp)) + + key = cache._hash_token("tok") + assert cache._entries[key].expires_at <= short_exp + + def test_ttl_used_when_no_token_expiry(self): + cache = TokenCache(ttl_seconds=60, max_size=100) + before = time.time() + cache.set("tok", _make_token(expires_at=None)) + after = time.time() + + key = cache._hash_token("tok") + entry = cache._entries[key] + assert before + 60 <= entry.expires_at <= after + 60 + + +class TestTokenCacheSizeLimit: + """Eviction and size-limit behaviour.""" + + def test_evicts_oldest_when_full(self): + cache = TokenCache(ttl_seconds=300, max_size=2) + cache.set("tok-0", _make_token(client_id="0")) + cache.set("tok-1", _make_token(client_id="1")) + cache.set("tok-2", _make_token(client_id="2")) + + assert len(cache._entries) == 2 + hit_0, _ = cache.get("tok-0") + assert not hit_0 + + hit_1, _ = cache.get("tok-1") + hit_2, _ = cache.get("tok-2") + assert hit_1 + assert hit_2 + + def test_cleanup_expired_before_eviction(self): + cache = TokenCache(ttl_seconds=300, max_size=2) + cache.set("tok-0", _make_token(client_id="0")) + cache.set("tok-1", _make_token(client_id="1")) + + key_0 = cache._hash_token("tok-0") + cache._entries[key_0].expires_at = time.time() - 1 + + cache.set("tok-2", _make_token(client_id="2")) + + assert len(cache._entries) == 2 + hit_1, _ = cache.get("tok-1") + hit_2, _ = cache.get("tok-2") + assert hit_1 + assert hit_2 + + +class TestTokenCacheHashing: + """SHA-256 key hashing.""" + + def test_consistent_hashing(self): + assert TokenCache._hash_token("abc") == TokenCache._hash_token("abc") + + def test_different_tokens_different_hashes(self): + assert TokenCache._hash_token("abc") != TokenCache._hash_token("xyz") + + def test_hash_is_64_hex_chars(self): + h = TokenCache._hash_token("anything") + assert len(h) == 64 + int(h, 16) # must be valid hex From fc9fa3a8ba127248ad6c1f32747470a016551065 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:15:11 -0400 Subject: [PATCH 2/3] Remove dead constant, validate negative cache params --- .../server/auth/providers/introspection.py | 3 -- src/fastmcp/utilities/token_cache.py | 6 +++ .../auth/providers/test_introspection.py | 47 +++++++++---------- tests/utilities/test_token_cache.py | 10 +++- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/fastmcp/server/auth/providers/introspection.py b/src/fastmcp/server/auth/providers/introspection.py index bb49c539f4..65d9d86ced 100644 --- a/src/fastmcp/server/auth/providers/introspection.py +++ b/src/fastmcp/server/auth/providers/introspection.py @@ -77,9 +77,6 @@ class IntrospectionTokenVerifier(TokenVerifier): ``` """ - # Default cache settings - DEFAULT_MAX_CACHE_SIZE = 10000 - def __init__( self, *, diff --git a/src/fastmcp/utilities/token_cache.py b/src/fastmcp/utilities/token_cache.py index 10c4f64a18..b2a9bcda9b 100644 --- a/src/fastmcp/utilities/token_cache.py +++ b/src/fastmcp/utilities/token_cache.py @@ -74,6 +74,12 @@ def __init__( reached, expired entries are swept first; if still full the oldest entry is evicted. Defaults to 10 000. """ + if ttl_seconds is not None and ttl_seconds < 0: + raise ValueError( + f"cache_ttl_seconds must be non-negative, got {ttl_seconds}" + ) + if max_size is not None and max_size < 0: + raise ValueError(f"max_cache_size must be non-negative, got {max_size}") self._ttl = ttl_seconds or 0 self._max_size = max_size if max_size is not None else DEFAULT_MAX_CACHE_SIZE self._entries: dict[str, _CacheEntry] = {} diff --git a/tests/server/auth/providers/test_introspection.py b/tests/server/auth/providers/test_introspection.py index 27c032365a..46dcd2f0d6 100644 --- a/tests/server/auth/providers/test_introspection.py +++ b/tests/server/auth/providers/test_introspection.py @@ -588,23 +588,18 @@ def test_cache_disabled_with_zero_ttl(self): ) assert not verifier2._cache.enabled - async def test_cache_disabled_with_zero_or_negative_max_size( - self, httpx_mock: HTTPXMock - ): - """Test that cache is disabled when max_cache_size is 0 or negative.""" - # Add two responses for the two verifiers - for _ in range(2): - httpx_mock.add_response( - url="https://auth.example.com/oauth/introspect", - method="POST", - json={ - "active": True, - "client_id": "user-123", - "scope": "read", - }, - ) + async def test_cache_disabled_with_zero_max_size(self, httpx_mock: HTTPXMock): + """Test that cache is disabled when max_cache_size is 0.""" + httpx_mock.add_response( + url="https://auth.example.com/oauth/introspect", + method="POST", + json={ + "active": True, + "client_id": "user-123", + "scope": "read", + }, + ) - # Zero max_cache_size should disable caching (not raise StopIteration) verifier = IntrospectionTokenVerifier( introspection_url="https://auth.example.com/oauth/introspect", client_id="test-client", @@ -616,16 +611,16 @@ async def test_cache_disabled_with_zero_or_negative_max_size( assert result is not None assert result.client_id == "user-123" - # Negative max_cache_size should also disable caching - verifier2 = IntrospectionTokenVerifier( - introspection_url="https://auth.example.com/oauth/introspect", - client_id="test-client", - client_secret="test-secret", - cache_ttl_seconds=300, - max_cache_size=-1, - ) - result2 = await verifier2.verify_token("test-token") - assert result2 is not None + def test_negative_max_cache_size_raises(self): + """Negative max_cache_size is a caller bug and should raise.""" + with pytest.raises(ValueError, match="max_cache_size must be non-negative"): + IntrospectionTokenVerifier( + introspection_url="https://auth.example.com/oauth/introspect", + client_id="test-client", + client_secret="test-secret", + cache_ttl_seconds=300, + max_cache_size=-1, + ) async def test_cache_hit_returns_cached_result( self, verifier_with_cache: IntrospectionTokenVerifier, httpx_mock: HTTPXMock diff --git a/tests/utilities/test_token_cache.py b/tests/utilities/test_token_cache.py index 369338f77c..73da40a2dc 100644 --- a/tests/utilities/test_token_cache.py +++ b/tests/utilities/test_token_cache.py @@ -31,15 +31,21 @@ class TestTokenCacheDisabled: [ (None, None), (0, 100), - (-1, 100), (300, 0), - (300, -1), ], ) def test_disabled_configurations(self, ttl: int | None, max_size: int | None): cache = TokenCache(ttl_seconds=ttl, max_size=max_size) assert not cache.enabled + def test_negative_ttl_raises(self): + with pytest.raises(ValueError, match="cache_ttl_seconds must be non-negative"): + TokenCache(ttl_seconds=-1) + + def test_negative_max_size_raises(self): + with pytest.raises(ValueError, match="max_cache_size must be non-negative"): + TokenCache(max_size=-1) + def test_get_returns_miss_when_disabled(self): cache = TokenCache(ttl_seconds=0) cache.set("tok", _make_token()) From a756e0fa488b2a41006351756edcd6eae7be5a1f Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 18 Mar 2026 12:51:24 -0400 Subject: [PATCH 3/3] Fix overwrite eviction bug, skip cache on scope lookup failure --- src/fastmcp/server/auth/providers/github.py | 4 ++- src/fastmcp/utilities/token_cache.py | 11 +++---- tests/server/auth/providers/test_github.py | 32 +++++++++++++++++++++ tests/utilities/test_token_cache.py | 17 +++++++++++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/fastmcp/server/auth/providers/github.py b/src/fastmcp/server/auth/providers/github.py index 4127d41af8..42b65a73ff 100644 --- a/src/fastmcp/server/auth/providers/github.py +++ b/src/fastmcp/server/auth/providers/github.py @@ -124,6 +124,7 @@ async def verify_token(self, token: str) -> AccessToken | None: ) # Extract scopes from X-OAuth-Scopes header if available + scopes_verified = scopes_response.status_code == 200 oauth_scopes_header = scopes_response.headers.get("x-oauth-scopes", "") token_scopes = [ scope.strip() @@ -162,7 +163,8 @@ async def verify_token(self, token: str) -> AccessToken | None: "github_user_data": user_data, }, ) - self._cache.set(token, result) + if scopes_verified: + self._cache.set(token, result) return result except httpx.RequestError as e: diff --git a/src/fastmcp/utilities/token_cache.py b/src/fastmcp/utilities/token_cache.py index b2a9bcda9b..9446090d39 100644 --- a/src/fastmcp/utilities/token_cache.py +++ b/src/fastmcp/utilities/token_cache.py @@ -55,8 +55,8 @@ class TokenCache: - Defensive deep copies on both store and retrieve to prevent callers from mutating cached values. - Caching is disabled when ``ttl_seconds`` is ``None``, ``0``, or - negative, or when ``max_size`` is ``0`` or negative. + Caching is disabled when ``ttl_seconds`` is ``None`` or ``0``, or + when ``max_size`` is ``0``. Negative values raise ``ValueError``. """ def __init__( @@ -125,11 +125,12 @@ def set(self, token: str, result: AccessToken) -> None: if not self.enabled: return - self._maybe_cleanup() - self._enforce_size_limit() - cache_key = self._hash_token(token) + self._maybe_cleanup() + if cache_key not in self._entries: + self._enforce_size_limit() + expires_at = time.time() + self._ttl if result.expires_at: expires_at = min(expires_at, float(result.expires_at)) diff --git a/tests/server/auth/providers/test_github.py b/tests/server/auth/providers/test_github.py index 848bb9a486..11683f8b6f 100644 --- a/tests/server/auth/providers/test_github.py +++ b/tests/server/auth/providers/test_github.py @@ -152,6 +152,7 @@ def _mock_github_success(mock_client: AsyncMock) -> None: } scopes_response = MagicMock() + scopes_response.status_code = 200 scopes_response.headers = {"x-oauth-scopes": "user,repo"} mock_client.get.side_effect = [user_response, scopes_response] @@ -260,6 +261,37 @@ async def test_cached_result_is_defensive_copy(self): assert result2 is not None assert result2.claims["login"] == "testuser" + async def test_scope_failure_skips_cache(self): + """Token verified with fallback scopes (scope API failed) should not be cached.""" + verifier = GitHubTokenVerifier(cache_ttl_seconds=300) + + mock_client = AsyncMock() + + user_response = MagicMock() + user_response.status_code = 200 + user_response.json.return_value = { + "id": 12345, + "login": "testuser", + "name": "Test User", + "email": "test@example.com", + "avatar_url": "https://github.com/testuser.png", + } + + scopes_response = MagicMock() + scopes_response.status_code = 500 + scopes_response.headers = {} + + with patch( + "fastmcp.server.auth.providers.github.httpx.AsyncClient" + ) as mock_cls: + mock_cls.return_value.__aenter__.return_value = mock_client + + mock_client.get.side_effect = [user_response, scopes_response] + result = await verifier.verify_token("tok-1") + assert result is not None + # Should NOT be cached because scope response was not 200 + assert not verifier._cache.enabled or len(verifier._cache._entries) == 0 + def test_provider_passes_cache_params(self, memory_storage: MemoryStore): provider = GitHubProvider( client_id="cid", diff --git a/tests/utilities/test_token_cache.py b/tests/utilities/test_token_cache.py index 73da40a2dc..93bbae40b7 100644 --- a/tests/utilities/test_token_cache.py +++ b/tests/utilities/test_token_cache.py @@ -193,6 +193,23 @@ def test_cleanup_expired_before_eviction(self): assert hit_1 assert hit_2 + def test_overwrite_does_not_evict(self): + """Overwriting an existing key should not evict another entry.""" + cache = TokenCache(ttl_seconds=300, max_size=2) + cache.set("tok-0", _make_token(client_id="0")) + cache.set("tok-1", _make_token(client_id="1")) + + # Overwrite tok-0 — should NOT evict tok-1 + cache.set("tok-0", _make_token(client_id="0-updated")) + + assert len(cache._entries) == 2 + hit_0, result_0 = cache.get("tok-0") + hit_1, _ = cache.get("tok-1") + assert hit_0 + assert hit_1 + assert result_0 is not None + assert result_0.client_id == "0-updated" + class TestTokenCacheHashing: """SHA-256 key hashing."""