Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/fastmcp/server/auth/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from fastmcp.server.auth.oauth_proxy import OAuthProxy
from fastmcp.utilities.auth import parse_scopes
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.token_cache import TokenCache

logger = get_logger(__name__)

Expand All @@ -42,30 +43,49 @@ class GitHubTokenVerifier(TokenVerifier):

GitHub OAuth tokens are opaque (not JWTs), so we verify them
by calling GitHub's API to check if they're valid and get user info.

Caching is disabled by default. Set ``cache_ttl_seconds`` to a positive
integer to cache successful verification results and avoid repeated
GitHub API calls for the same token.
"""

def __init__(
    self,
    *,
    required_scopes: list[str] | None = None,
    timeout_seconds: int = 10,
    cache_ttl_seconds: int | None = None,
    max_cache_size: int | None = None,
    http_client: httpx.AsyncClient | None = None,
) -> None:
    """Initialize the GitHub token verifier.

    Args:
        required_scopes: Required OAuth scopes (e.g., ['user:email']).
        timeout_seconds: HTTP request timeout for GitHub API calls, in seconds.
        cache_ttl_seconds: How long to cache successful verification results,
            in seconds. Caching is disabled by default (None). Set to a
            positive integer to enable (e.g., 300 for 5 minutes).
        max_cache_size: Maximum number of tokens to cache. When None, the
            TokenCache default applies (documented as 10,000 — confirm
            against TokenCache).
        http_client: Optional httpx.AsyncClient for connection pooling. When
            provided, the client is reused across calls and the caller is
            responsible for its lifecycle. When None (default), a fresh
            client is created per call.
    """
    # Base class records the required scopes used during verification.
    super().__init__(required_scopes=required_scopes)
    self.timeout_seconds = timeout_seconds
    self._http_client = http_client
    # TokenCache with ttl_seconds=None means caching stays disabled.
    self._cache = TokenCache(
        ttl_seconds=cache_ttl_seconds,
        max_size=max_cache_size,
    )

async def verify_token(self, token: str) -> AccessToken | None:
"""Verify GitHub OAuth token by calling GitHub API."""
is_cached, cached_result = self._cache.get(token)
if is_cached:
logger.debug("GitHub token cache hit")
return cached_result

try:
async with (
contextlib.nullcontext(self._http_client)
Expand Down Expand Up @@ -104,6 +124,7 @@ async def verify_token(self, token: str) -> AccessToken | None:
)

# Extract scopes from X-OAuth-Scopes header if available
scopes_verified = scopes_response.status_code == 200
oauth_scopes_header = scopes_response.headers.get("x-oauth-scopes", "")
token_scopes = [
scope.strip()
Expand All @@ -128,7 +149,7 @@ async def verify_token(self, token: str) -> AccessToken | None:
return None

# Create AccessToken with GitHub user info
return AccessToken(
result = AccessToken(
token=token,
client_id=str(user_data.get("id", "unknown")), # Use GitHub user ID
scopes=token_scopes,
Expand All @@ -142,6 +163,9 @@ async def verify_token(self, token: str) -> AccessToken | None:
"github_user_data": user_data,
},
)
if scopes_verified:
self._cache.set(token, result)
return result

except httpx.RequestError as e:
logger.debug("Failed to verify GitHub token: %s", e)
Expand Down Expand Up @@ -189,6 +213,8 @@ def __init__(
redirect_path: str | None = None,
required_scopes: list[str] | None = None,
timeout_seconds: int = 10,
cache_ttl_seconds: int | None = None,
max_cache_size: int | None = None,
allowed_client_redirect_uris: list[str] | None = None,
client_storage: AsyncKeyValue | None = None,
jwt_signing_key: str | bytes | None = None,
Expand All @@ -207,6 +233,10 @@ def __init__(
redirect_path: Redirect path configured in GitHub OAuth app (defaults to "/auth/callback")
required_scopes: Required GitHub scopes (defaults to ["user"])
timeout_seconds: HTTP request timeout for GitHub API calls (defaults to 10)
cache_ttl_seconds: How long to cache token verification results in seconds.
Caching is disabled by default (None). Set to a positive integer to
enable (e.g., 300 for 5 minutes).
max_cache_size: Maximum number of tokens to cache. Default: 10 000.
allowed_client_redirect_uris: List of allowed redirect URI patterns for MCP clients.
If None (default), all URIs are allowed. If empty list, no URIs are allowed.
client_storage: Storage backend for OAuth state (client registrations, encrypted tokens).
Expand Down Expand Up @@ -234,6 +264,8 @@ def __init__(
token_verifier = GitHubTokenVerifier(
required_scopes=required_scopes_final,
timeout_seconds=timeout_seconds,
cache_ttl_seconds=cache_ttl_seconds,
max_cache_size=max_cache_size,
http_client=http_client,
)

Expand Down
111 changes: 6 additions & 105 deletions src/fastmcp/server/auth/providers/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

import base64
import contextlib
import hashlib
import time
from dataclasses import dataclass
from typing import Any, Literal, get_args

import httpx
Expand All @@ -36,18 +34,11 @@
from fastmcp.server.auth import AccessToken, TokenVerifier
from fastmcp.utilities.auth import parse_scopes
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.token_cache import TokenCache

logger = get_logger(__name__)


@dataclass
class _IntrospectionCacheEntry:
    """Cached introspection result with expiration."""

    # Deep-copied AccessToken; hits return a fresh deep copy of this value.
    result: AccessToken
    # Absolute wall-clock timestamp (time.time()) after which the entry is stale.
    expires_at: float


ClientAuthMethod = Literal["client_secret_basic", "client_secret_post"]


Expand Down Expand Up @@ -86,9 +77,6 @@ class IntrospectionTokenVerifier(TokenVerifier):
```
"""

# Default cache settings
DEFAULT_MAX_CACHE_SIZE = 10000

def __init__(
self,
*,
Expand Down Expand Up @@ -154,96 +142,9 @@ def __init__(
self._http_client = http_client
self.logger = get_logger(__name__)

# Cache configuration (None or 0 = disabled)
self._cache_ttl = cache_ttl_seconds or 0
self._max_cache_size = (
max_cache_size
if max_cache_size is not None
else self.DEFAULT_MAX_CACHE_SIZE
)
self._cache: dict[str, _IntrospectionCacheEntry] = {}
self._last_cleanup = time.monotonic()
self._cleanup_interval = 60 # Cleanup every 60 seconds

def _hash_token(self, token: str) -> str:
"""Hash token for use as cache key.

Using SHA-256 for memory efficiency (fixed 64-char hex digest
regardless of token length).
"""
return hashlib.sha256(token.encode("utf-8")).hexdigest()

def _cleanup_expired_cache(self) -> None:
"""Remove expired entries from cache."""
now = time.time()
expired = [key for key, entry in self._cache.items() if entry.expires_at < now]
for key in expired:
del self._cache[key]
if expired:
self.logger.debug("Cleaned up %d expired cache entries", len(expired))

def _maybe_cleanup(self) -> None:
"""Periodically cleanup expired entries to prevent unbounded growth."""
now = time.monotonic()
if now - self._last_cleanup > self._cleanup_interval:
self._cleanup_expired_cache()
self._last_cleanup = now

def _get_cached(self, token: str) -> tuple[bool, AccessToken | None]:
"""Get cached introspection result.

Returns:
Tuple of (is_cached, result):
- (True, AccessToken) if cached valid token
- (False, None) if not in cache or expired
"""
if self._cache_ttl <= 0 or self._max_cache_size <= 0:
return (False, None) # Caching disabled

cache_key = self._hash_token(token)
entry = self._cache.get(cache_key)

if entry is None:
return (False, None) # Not in cache

if entry.expires_at < time.time():
del self._cache[cache_key]
return (False, None) # Expired

# Return a copy to prevent mutations from affecting cached value
return (True, entry.result.model_copy(deep=True))

def _set_cached(self, token: str, result: AccessToken) -> None:
"""Cache a valid introspection result with TTL.

Only successful validations are cached. Failures (inactive, expired,
missing scopes, errors) are never cached to avoid sticky false negatives.
"""
if self._cache_ttl <= 0 or self._max_cache_size <= 0:
return # Caching disabled

# Periodic cleanup
self._maybe_cleanup()

# Check cache size limit
if len(self._cache) >= self._max_cache_size:
self._cleanup_expired_cache()
# If still at limit after cleanup, evict oldest entry
if len(self._cache) >= self._max_cache_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]

cache_key = self._hash_token(token)

# Use token's expiration if available and sooner than TTL
expires_at = time.time() + self._cache_ttl
if result.expires_at:
expires_at = min(expires_at, float(result.expires_at))

# Store a deep copy to prevent mutations from affecting cached value
self._cache[cache_key] = _IntrospectionCacheEntry(
result=result.model_copy(deep=True),
expires_at=expires_at,
self._cache = TokenCache(
ttl_seconds=cache_ttl_seconds,
max_size=max_cache_size,
)

def _create_basic_auth_header(self) -> str:
Expand Down Expand Up @@ -293,7 +194,7 @@ async def verify_token(self, token: str) -> AccessToken | None:
AccessToken object if valid and active, None if invalid, inactive, or expired
"""
# Check cache first
is_cached, cached_result = self._get_cached(token)
is_cached, cached_result = self._cache.get(token)
if is_cached:
self.logger.debug("Token introspection cache hit")
return cached_result
Expand Down Expand Up @@ -388,7 +289,7 @@ async def verify_token(self, token: str) -> AccessToken | None:
expires_at=int(exp) if exp else None,
claims=introspection_data, # Store full response for extensibility
)
self._set_cached(token, result)
self._cache.set(token, result)
return result

except httpx.TimeoutException:
Expand Down
Loading
Loading