From bdcd8c259926222f46b11325df55958a9fbdff48 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Wed, 12 Nov 2025 15:48:55 +0200 Subject: [PATCH 1/3] Add assisted-service-mcp.log and .coverage to .gitignore Signed-off-by: Eran Cohen --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ae37e0e..687b363 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ build/ dist/ wheels/ *.egg-info +# tests generated files +.coverage +assisted-service-mcp.log # Virtual environments .venv From e7560de6240e93cd8b51dfc7bae9a06c1e481a34 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Sun, 16 Nov 2025 15:27:42 +0200 Subject: [PATCH 2/3] Add OAuth authentication support for MCP server - Implement OAuth2 authentication flow - Update API and MCP server to support OAuth authentication - Add OAuth tests and auth priority tests - Include OAuth documentation and setup guides - Add startup script for OAuth server --- README.md | 137 ++++- assisted_service_mcp/src/api.py | 141 +++++ assisted_service_mcp/src/mcp.py | 124 +++- assisted_service_mcp/src/oauth/__init__.py | 49 ++ assisted_service_mcp/src/oauth/manager.py | 598 +++++++++++++++++++ assisted_service_mcp/src/oauth/middleware.py | 320 ++++++++++ assisted_service_mcp/src/oauth/models.py | 159 +++++ assisted_service_mcp/src/oauth/store.py | 189 ++++++ assisted_service_mcp/src/oauth/utils.py | 102 ++++ assisted_service_mcp/src/settings.py | 46 ++ assisted_service_mcp/utils/auth.py | 67 ++- doc/OAUTH_SETUP.md | 123 ++++ doc/oauth_authentication.md | 227 +++++++ oauth-config.env | 18 + pyproject.toml | 8 +- start-oauth-server.sh | 49 ++ tests/src/test_mcp.py | 1 - tests/test_auth_priority.py | 101 ++++ tests/test_oauth.py | 260 ++++++++ tests/test_oauth_integration.py | 120 ++++ tests/utils/test_auth.py | 12 +- uv.lock | 20 + 22 files changed, 2812 insertions(+), 59 deletions(-) create mode 100644 assisted_service_mcp/src/oauth/__init__.py create mode 100644 assisted_service_mcp/src/oauth/manager.py create mode 100644 assisted_service_mcp/src/oauth/middleware.py create mode 100644 assisted_service_mcp/src/oauth/models.py create mode 100644 assisted_service_mcp/src/oauth/store.py create mode 100644 assisted_service_mcp/src/oauth/utils.py create mode 100644 doc/OAUTH_SETUP.md create mode 100644 doc/oauth_authentication.md create mode 100644 oauth-config.env create mode 100755 start-oauth-server.sh create mode 100644 tests/test_auth_priority.py create mode 100644 tests/test_oauth.py create mode 100644 tests/test_oauth_integration.py diff --git a/README.md b/README.md index 8fc3704..2daa6cb 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,90 @@ MCP server for interacting with the OpenShift assisted installer API. Diagnose cluster failures and find out how to fix them. -Try it out: +## Quick Start -1. Clone the repo: -``` -git clone git@github.com:openshift-assisted/assisted-service-mcp.git +### Option 1: Simple Token Setup + +1. **Get your OpenShift API token** from https://cloud.redhat.com/openshift/token + +2. **Clone and run**: + ```bash + git clone git@github.com:openshift-assisted/assisted-service-mcp.git + cd assisted-service-mcp + OFFLINE_TOKEN= uv run python -m assisted_service_mcp.src.main + ``` + +3. **Configure your MCP client** (Cursor/Copilot): + ```json + { + "assisted-service-mcp": { + "transport": "streamable-http", + "url": "http://127.0.0.1:8000/mcp" + } + } + ``` + +### Option 2: OAuth Authentication (Advanced) + +For automatic token management with Red Hat SSO: + +1. **Clone the repo**: + ```bash + git clone git@github.com:openshift-assisted/assisted-service-mcp.git + cd assisted-service-mcp + ``` + +2. **Start the OAuth-enabled server**: + ```bash + ./start-oauth-server.sh + ``` + +3. **Configure your MCP client** (Cursor/Copilot): + ```json + { + "assisted-service-mcp": { + "transport": "streamable-http", + "url": "http://127.0.0.1:8000/mcp" + } + } + ``` + +4. **Connect and authenticate**: When you connect from Cursor, a browser will open automatically for Red Hat SSO authentication. + +**For detailed OAuth setup instructions, see [OAUTH_SETUP.md](doc/OAUTH_SETUP.md)** + +### Option 3: OCM-Offline-Token Header +#### Note: this option is available only when OAuth is disabled + +1. **Get your OpenShift API token** from https://cloud.redhat.com/openshift/token + +2. **Clone and run**: + ```bash + git clone git@github.com:openshift-assisted/assisted-service-mcp.git + cd assisted-service-mcp + uv run python -m assisted_service_mcp.src.main + ``` + +```json + "mcpServers": { + "assisted": { + "transport": "streamable-http", + "url": "http://127.0.0.1:8000/mcp", + "headers": { + "OCM-Offline-Token": "" + } + } + } ``` -2. Get your OpenShift API token from https://cloud.redhat.com/openshift/token +## Advanced Transport Options + +The recommended transport is streamable-http as shown in the examples above. +Other transport methods or detailed configuration: -3. The server is started and configured differently depending on what transport you want to use +**Configure the server** depending on your preferred transport: -For STDIO: +#### STDIO Transport In VSCode for example: ```json @@ -32,44 +104,49 @@ In VSCode for example: "/path/to/assisted-service-mcp/assisted_service_mcp/src/main.py" ], "env": { - "OFFLINE_TOKEN": + "OFFLINE_TOKEN": "" } } } } ``` -For SSE (recommended): +#### Server-Sent Events (SSE) Transport (Alternative) +#### Note: SSE is supported for backward compatibility, Streamable HTTP is the recommended transport +Start the server with SSE transport: -Start the server in a terminal: +`OFFLINE_TOKEN= TRANSPORT=sse uv run python -m assisted_service_mcp.src.main` -`OFFLINE_TOKEN= uv run assisted_service_mcp.src.main` - -Configure the server in the client: +Configure the client: ```json - "assisted-sse": { - "transport": "sse", - "url": "http://localhost:8000/sse" - } +{ + "assisted-sse": { + "transport": "sse", + "url": "http://127.0.0.1:8000/sse" + } +} ``` -### Providing the Offline Token via Request Header +## Authentication Methods -If you do not set the `OFFLINE_TOKEN` environment variable, you can provide the token as a request header. -When configuring your MCP client, add the `OCM-Offline-Token` header: +The server supports multiple authentication methods with automatic priority handling: -```json - "assisted-sse": { - "transport": "sse", - "url": "http://localhost:8000/sse", - "headers": { - "OCM-Offline-Token": "" - } - } -``` +1. **Authorization Header** - `Bearer ` in request headers +2. **OAuth Flow** (when `OAUTH_ENABLED=true`) - Automatic browser-based authentication +3. **Environment Variable** - `OFFLINE_TOKEN` environment variable +4. **OCM-Offline-Token Header** - `OCM-Offline-Token: ` in request headers + +### OAuth Benefits (Advanced Users) + +**No Manual Token Management** - Tokens are obtained and cached automatically +**Secure PKCE Flow** - Enhanced OAuth security with Proof Key for Code Exchange +**Automatic Token Refresh** - Expired tokens are refreshed transparently using refresh tokens +**Multi-Client Support** - Different MCP clients can authenticate independently + +## Usage -4. Ask about your clusters: +Ask about your clusters: ![Example prompt asking about a cluster](images/cluster-prompt-example.png) ## Available Tools diff --git a/assisted_service_mcp/src/api.py b/assisted_service_mcp/src/api.py index 955d9a1..0eee605 100644 --- a/assisted_service_mcp/src/api.py +++ b/assisted_service_mcp/src/api.py @@ -4,6 +4,12 @@ with appropriate transport protocols. """ +from typing import Awaitable, Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + from assisted_service_mcp.src.mcp import AssistedServiceMCPServer from assisted_service_mcp.src.settings import settings from assisted_service_mcp.src.logger import log, configure_logging @@ -21,3 +27,138 @@ else: app = server.mcp.sse_app() log.info("Using SSE transport (stateful)") + +# Add OAuth endpoints and middleware if OAuth is enabled +if settings.OAUTH_ENABLED: + from assisted_service_mcp.src.oauth import ( + oauth_register_handler, + oauth_callback_handler, + oauth_token_handler, + mcp_oauth_middleware, + ) + + # Add OAuth middleware to handle authentication during MCP connection + class OAuthMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + return await mcp_oauth_middleware.handle_mcp_request(request, call_next) + + app.add_middleware(OAuthMiddleware) + + # OAuth discovery endpoints for better MCP client compatibility + + async def oauth_well_known_openid_handler(_request: Request) -> JSONResponse: + """OAuth discovery endpoint.""" + return JSONResponse( + { + "issuer": settings.SELF_URL, + "authorization_endpoint": f"{settings.OAUTH_URL}/protocol/openid-connect/auth", + "token_endpoint": f"{settings.OAUTH_URL}/protocol/openid-connect/token", + "registration_endpoint": f"{settings.SELF_URL}/oauth/register", + "userinfo_endpoint": f"{settings.OAUTH_URL}/protocol/openid-connect/userinfo", + "jwks_uri": f"{settings.OAUTH_URL}/protocol/openid-connect/certs", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code"], + "code_challenge_methods_supported": ["S256"], + "scopes_supported": ["openid", "profile", "email"], + } + ) + + async def mcp_register_handler(_request: Request) -> JSONResponse: + """MCP registration endpoint.""" + return JSONResponse( + { + "name": "AssistedService", + "version": "1.0.0", + "description": "Assisted Service MCP Server with OAuth", + "oauth": { + "authorization_endpoint": f"{settings.OAUTH_URL}/protocol/openid-connect/auth", + "token_endpoint": f"{settings.OAUTH_URL}/protocol/openid-connect/token", + "client_id": settings.OAUTH_CLIENT, + "redirect_uri": f"{settings.SELF_URL}/oauth/callback", + "scopes": ["openid", "profile", "email"], + }, + } + ) + + # Wrapper functions to convert dict responses to JSONResponse for Starlette compatibility + async def wrapped_oauth_register_handler(request: Request) -> JSONResponse: + result = await oauth_register_handler(request) + return JSONResponse(result) + + async def wrapped_oauth_token_handler(request: Request) -> JSONResponse: + result = await oauth_token_handler(request) + return JSONResponse(result) + + # Use Starlette's add_route method instead of FastAPI's add_api_route + app.add_route("/oauth/register", wrapped_oauth_register_handler, methods=["GET"]) + app.add_route( + "/oauth/callback", oauth_callback_handler, methods=["GET"] + ) # This one returns Response already + app.add_route("/oauth/token", wrapped_oauth_token_handler, methods=["POST"]) + + # OAuth discovery endpoints - only the standard routes per MCP spec + app.add_route( + "/.well-known/openid-configuration/mcp", + oauth_well_known_openid_handler, + methods=["GET"], + ) + app.add_route( + "/.well-known/openid-configuration", + oauth_well_known_openid_handler, + methods=["GET"], + ) + + # OAuth status endpoint for polling + async def oauth_status_handler(request: Request) -> JSONResponse: + """Check OAuth authentication status for a client.""" + middleware_instance = mcp_oauth_middleware + + client_id = request.query_params.get("client_id") + if not client_id: + return JSONResponse( + {"error": "client_id parameter required"}, status_code=400 + ) + + # Check if client has completed authentication + from assisted_service_mcp.src.oauth import oauth_manager + + if oauth_manager.token_store.get_token_by_client(client_id): + return JSONResponse( + { + "status": "authenticated", + "message": "OAuth authentication completed successfully", + } + ) + + # Check if authentication is in progress + for ( + session_id, + session_info, + ) in middleware_instance.pending_auth_sessions.items(): + if session_info.get("client_id") == client_id: + return JSONResponse( + { + "status": "pending", + "message": "OAuth authentication in progress", + "session_id": session_id, + } + ) + + return JSONResponse( + { + "status": "not_authenticated", + "message": "No authentication found for this client", + } + ) + + # MCP registration endpoint + app.add_route("/register", mcp_register_handler, methods=["POST", "GET"]) + app.add_route("/oauth/status", oauth_status_handler, methods=["GET"]) + + log.info( + "OAuth endpoints and discovery registered: /oauth/*, /.well-known/*, /register, /oauth/status" + ) +else: + log.info("OAuth is disabled - no OAuth endpoints registered") diff --git a/assisted_service_mcp/src/mcp.py b/assisted_service_mcp/src/mcp.py index 53ac9ec..9417335 100644 --- a/assisted_service_mcp/src/mcp.py +++ b/assisted_service_mcp/src/mcp.py @@ -2,14 +2,15 @@ import asyncio import inspect +import time from functools import wraps -from typing import Any, Awaitable, Callable +from typing import Any, Awaitable, Callable, Optional from mcp.server.fastmcp import FastMCP from assisted_service_mcp.src.logger import log # Import auth utilities -from assisted_service_mcp.utils.auth import get_offline_token, get_access_token +from assisted_service_mcp.utils.auth import get_access_token from assisted_service_mcp.src.settings import settings # Import all tool modules @@ -43,10 +44,10 @@ def __init__(self) -> None: host=settings.MCP_HOST, stateless_http=use_stateless_http, ) - # Define auth helpers bound to this MCP instance - self._get_offline_token = lambda: get_offline_token(self.mcp) + self._get_oauth_token = self._create_oauth_token() self._get_access_token = lambda: get_access_token( - self.mcp, offline_token_func=self._get_offline_token + self.mcp, + oauth_token_func=self._get_oauth_token, ) self._register_mcp_tools() log.info("Assisted Service MCP Server initialized successfully") @@ -54,6 +55,119 @@ def __init__(self) -> None: log.exception("Failed to initialize Assisted Service MCP Server: %s", e) raise + def _create_oauth_token(self) -> Callable[[Any], Optional[str]]: + """Create OAuth token function with proper dependency injection. + + This avoids circular imports by importing oauth module only when needed + and creating a closure that captures the import. + + Returns: + Function that can extract OAuth tokens from MCP context + """ + + def get_oauth_token(mcp: Any) -> Optional[str]: + if not settings.OAUTH_ENABLED: + return None + + try: + # Import only when OAuth is enabled and function is called + from assisted_service_mcp.src.oauth import ( + get_oauth_access_token_from_mcp, + mcp_oauth_middleware, + oauth_manager, + open_browser_for_oauth, + ) + + # First check if we have a cached token + cached_token = get_oauth_access_token_from_mcp(mcp) + if cached_token: + return cached_token + + # Get client identifier for OAuth flow + client_id = self._get_mcp_client_identifier(mcp) + + # Check if we have a completed OAuth token for this client + token = oauth_manager.token_store.get_access_token_by_client(client_id) + if token: + log.info("Using cached OAuth token for MCP client %s", client_id) + return token + + # Check if OAuth flow is already in progress for this client + for ( + session_id, + session_info, + ) in mcp_oauth_middleware.pending_auth_sessions.items(): + if session_info.get("client_id") == client_id: + log.info( + "OAuth flow already in progress for MCP client %s", + client_id, + ) + raise RuntimeError( + "OAuth authentication in progress. Please complete the authentication in your browser, " + "then retry this command. The browser should have opened automatically." + ) + + # Start new OAuth flow with PKCE + log.info("Starting OAuth flow for MCP client %s", client_id) + + # Import OAuthState for parsing + from assisted_service_mcp.src.oauth.models import OAuthState + + # Generate OAuth URL with PKCE parameters + auth_url, state_json = oauth_manager.create_authorization_url(client_id) + + # Parse state to extract session_id + state = OAuthState.from_json(state_json) + session_id = state.session_id + + # Store session with full state_json for PKCE verification + mcp_oauth_middleware.pending_auth_sessions[session_id] = { + "client_id": client_id, + "state": state_json, # Store full JSON state with PKCE params + "auth_url": auth_url, + "timestamp": time.time(), + } + + # Open browser automatically + open_browser_for_oauth(auth_url) + + # Provide helpful error message for OAuth requirement + raise RuntimeError( + "OAuth authentication required. Please complete the authentication in your browser " + "(it should have opened automatically), then retry this command." + ) + + except ImportError as e: + log.warning("Failed to import OAuth module: %s", e) + return None + + return get_oauth_token + + def _get_mcp_client_identifier(self, mcp: Any) -> str: + """Get a unique identifier for the MCP client.""" + try: + # Try to get client info from MCP context + context = mcp.get_context() + if ( + context + and hasattr(context, "request_context") + and context.request_context + ): + request = context.request_context.request + if request: + user_agent = request.headers.get("user-agent", "unknown") + client_ip = ( + getattr(request.client, "host", "unknown") + if request.client + else "unknown" + ) + return f"{user_agent}_{client_ip}" + + # Fallback identifier + return "mcp_client_unknown" + except Exception: + return "mcp_client_fallback" + def _register_mcp_tools(self) -> None: """Register MCP tools for assisted service operations. diff --git a/assisted_service_mcp/src/oauth/__init__.py b/assisted_service_mcp/src/oauth/__init__.py new file mode 100644 index 0000000..df298ed --- /dev/null +++ b/assisted_service_mcp/src/oauth/__init__.py @@ -0,0 +1,49 @@ +"""OAuth authentication module for Assisted Service MCP Server. + +This module provides OAuth2 authentication support with PKCE for the MCP server. +""" + +from assisted_service_mcp.src.settings import settings +from assisted_service_mcp.src.oauth.manager import ( + OAuthManager, + get_oauth_access_token_from_mcp, + oauth_callback_handler, + oauth_manager, + oauth_register_handler, + oauth_token_handler, +) +from assisted_service_mcp.src.oauth.middleware import ( + MCPOAuthMiddleware, + mcp_oauth_middleware, +) +from assisted_service_mcp.src.oauth.models import OAuthState, OAuthToken +from assisted_service_mcp.src.oauth.store import TokenStore +from assisted_service_mcp.src.oauth.utils import ( + extract_oauth_callback_params, + get_oauth_success_html, + open_browser_for_oauth, +) + +__all__ = [ + # Manager + "OAuthManager", + "oauth_manager", + "oauth_register_handler", + "oauth_callback_handler", + "oauth_token_handler", + "get_oauth_access_token_from_mcp", + # Middleware + "MCPOAuthMiddleware", + "mcp_oauth_middleware", + # Models + "OAuthToken", + "OAuthState", + # Store + "TokenStore", + # Utils + "open_browser_for_oauth", + "get_oauth_success_html", + "extract_oauth_callback_params", + # Settings (for test mocking) + "settings", +] diff --git a/assisted_service_mcp/src/oauth/manager.py b/assisted_service_mcp/src/oauth/manager.py new file mode 100644 index 0000000..5b2a32a --- /dev/null +++ b/assisted_service_mcp/src/oauth/manager.py @@ -0,0 +1,598 @@ +"""OAuth authentication implementation for Assisted Service MCP Server. + +Simplified implementation using structured models and centralized token storage. +""" + +import base64 +import hashlib +import secrets +import time +import urllib.parse +from typing import Any, Dict, Optional + +import httpx +from fastapi import HTTPException, Request, Response +from fastapi.responses import HTMLResponse + +from assisted_service_mcp.src.logger import log +from assisted_service_mcp.src.oauth.models import OAuthState, OAuthToken +from assisted_service_mcp.src.oauth.store import TokenStore +from assisted_service_mcp.src.oauth.utils import ( + extract_oauth_callback_params, + get_oauth_success_html, +) +from assisted_service_mcp.src.settings import settings + + +class OAuthManager: + """Manages OAuth authentication flow for the MCP server. + + Simplified version using structured models and centralized storage. + """ + + def __init__(self) -> None: + """Initialize OAuth manager with configuration.""" + self.oauth_url = settings.OAUTH_URL + self.client_id = settings.OAUTH_CLIENT + self.self_url = settings.SELF_URL + + # Use configurable redirect URI or construct from SELF_URL + if settings.OAUTH_REDIRECT_URI: + self.redirect_uri = settings.OAUTH_REDIRECT_URI + else: + # For local development, ensure we use 127.0.0.1 which works better with Red Hat SSO + self_url_str = str(self.self_url) + if "localhost" in self_url_str: + base_url = self_url_str.replace("localhost", "127.0.0.1") + self.redirect_uri = f"{base_url}/oauth/callback" + else: + self.redirect_uri = f"{self_url_str}/oauth/callback" + + # Use centralized token store + self.token_store = TokenStore() + + # Pending OAuth states (for CSRF protection) + self._pending_states: Dict[str, OAuthState] = {} + + def generate_pkce_challenge(self) -> tuple[str, str]: + """Generate PKCE code verifier and challenge. + + Returns: + Tuple of (code_verifier, code_challenge) + """ + code_verifier = ( + base64.urlsafe_b64encode(secrets.token_bytes(32)) + .decode("utf-8") + .rstrip("=") + ) + code_challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode("utf-8")).digest() + ) + .decode("utf-8") + .rstrip("=") + ) + return code_verifier, code_challenge + + def create_authorization_url(self, client_id: str) -> tuple[str, str]: + """Create OAuth authorization URL and state. + + Args: + client_id: Client identifier for tracking + + Returns: + Tuple of (authorization_url, state_json) + """ + code_verifier, code_challenge = self.generate_pkce_challenge() + + # Create structured state + state = OAuthState( + session_id=secrets.token_hex(16), + client_id=client_id, + timestamp=time.time(), + code_verifier=code_verifier, + ) + + state_json = state.to_json() + + # Store state for validation + self._pending_states[state_json] = state + + params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid profile email", + "state": state_json, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + auth_url = f"{self.oauth_url}/protocol/openid-connect/auth" + full_url = f"{auth_url}?{urllib.parse.urlencode(params)}" + + log.debug("Created authorization URL for client %s", client_id) + return full_url, state_json + + def get_authorization_url(self, state: str) -> str: + """Generate OAuth authorization URL (backward compatibility method). + + This method maintains compatibility with the old API where state + was passed in rather than generated internally. + + Args: + state: OAuth state parameter for CSRF protection + + Returns: + Authorization URL for the OAuth provider + """ + code_verifier, code_challenge = self.generate_pkce_challenge() + + # For backward compatibility, create a simple state object + # Use the provided state as both session_id and a simple client_id + oauth_state = OAuthState( + session_id=state, + client_id=f"legacy_{state[:8]}", + timestamp=time.time(), + code_verifier=code_verifier, + ) + + # Store using the original state string as key for backward compatibility + self._pending_states[state] = oauth_state + + params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid profile email", + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + auth_url = f"{self.oauth_url}/protocol/openid-connect/auth" + return f"{auth_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token( + self, code: str, state_json: str + ) -> Optional[OAuthToken]: + """Exchange authorization code for access token. + + Args: + code: Authorization code from OAuth provider + state_json: OAuth state parameter (JSON string or legacy string) + + Returns: + OAuthToken if successful, None otherwise + + Raises: + HTTPException: If token exchange fails + """ + # Try to parse as JSON first (new format), fall back to legacy format + try: + state = OAuthState.from_json(state_json) + state_key = state_json + except ValueError as exc: + # Legacy format: plain string state + if state_json not in self._pending_states: + log.error("Unknown OAuth state") + raise HTTPException( + status_code=400, detail="Unknown OAuth state" + ) from exc + state = self._pending_states[state_json] + state_key = state_json + + # Check if state exists and is not expired + if state_key not in self._pending_states: + log.error("Unknown OAuth state") + raise HTTPException(status_code=400, detail="Unknown OAuth state") + + stored_state = self._pending_states.pop(state_key) + + if stored_state.is_expired(): + log.error("OAuth state expired") + raise HTTPException(status_code=400, detail="OAuth state expired") + + # Prepare token exchange request + data = { + "grant_type": "authorization_code", + "client_id": self.client_id, + "code": code, + "redirect_uri": self.redirect_uri, + "code_verifier": state.code_verifier, + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.oauth_url}/protocol/openid-connect/token", + data=data, + timeout=30.0, + ) + response.raise_for_status() + token_data = response.json() + + # Create token object + token_id = secrets.token_hex(16) + expires_in = token_data.get("expires_in", 3600) + + token = OAuthToken( + token_id=token_id, + client_id=state.client_id, + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + expires_at=time.time() + expires_in - 300, # 5 min safety margin + ) + + # Store token + self.token_store.store_token(token) + + log.info( + "Successfully exchanged OAuth code for token (client: %s)", + state.client_id, + ) + return token + + except httpx.HTTPError as e: + log.error("Failed to exchange OAuth code for token: %s", e) + raise HTTPException( + status_code=400, detail="Failed to exchange code for token" + ) from e + + async def get_access_token_by_id(self, token_id: str) -> Optional[str]: + """Get access token by token ID, refreshing if necessary. + + Args: + token_id: Token identifier + + Returns: + Access token if found and valid, None otherwise + """ + token = self.token_store.get_token_by_id(token_id) + if not token: + return None + + # Check if token needs refresh + if token.is_expired(): + log.info("Token %s is expired, attempting refresh", token_id) + if await self._refresh_token(token): + # Get updated token + token = self.token_store.get_token_by_id(token_id) + return token.access_token if token else None + log.warning("Failed to refresh token %s", token_id) + return None + + return token.access_token + + async def get_stored_access_token(self, token_id: str) -> Optional[str]: + """Get stored access token by ID (backward compatibility method). + + This is an alias for get_access_token_by_id() to maintain + backward compatibility with the old API. + + Args: + token_id: Token identifier + + Returns: + Access token if found and valid, None otherwise + """ + return await self.get_access_token_by_id(token_id) + + async def get_access_token_by_client(self, client_id: str) -> Optional[str]: + """Get access token for a client, refreshing if necessary. + + Args: + client_id: Client identifier + + Returns: + Access token if found and valid, None otherwise + """ + token = self.token_store.get_token_by_client(client_id) + if not token: + return None + + # Check if token needs refresh + if token.is_expired(): + log.info("Token for client %s is expired, attempting refresh", client_id) + if await self._refresh_token(token): + # Get updated token + token = self.token_store.get_token_by_client(client_id) + return token.access_token if token else None + log.warning("Failed to refresh token for client %s", client_id) + return None + + return token.access_token + + async def _refresh_token(self, token: OAuthToken) -> bool: + """Refresh an access token using the refresh token. + + Args: + token: Token to refresh + + Returns: + True if refresh was successful, False otherwise + """ + if not token.refresh_token: + log.warning("No refresh token available for token %s", token.token_id) + return False + + token_url = f"{self.oauth_url}/protocol/openid-connect/token" + + data = { + "grant_type": "refresh_token", + "client_id": self.client_id, + "refresh_token": token.refresh_token, + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post(token_url, data=data, timeout=30.0) + response.raise_for_status() + token_data = response.json() + + # Update token in store + new_access_token = token_data["access_token"] + new_refresh_token = token_data.get("refresh_token", token.refresh_token) + expires_in = token_data.get("expires_in", 3600) + new_expires_at = time.time() + expires_in - 300 + + self.token_store.update_token( + token.token_id, new_access_token, new_refresh_token, new_expires_at + ) + + log.info("Successfully refreshed token %s", token.token_id) + return True + + except httpx.HTTPError as e: + log.error("Failed to refresh token %s: %s", token.token_id, e) + self.token_store.remove_token(token.token_id) + return False + except (KeyError, ValueError) as e: + log.error("Invalid refresh token response for %s: %s", token.token_id, e) + self.token_store.remove_token(token.token_id) + return False + + def cleanup_expired_tokens(self) -> None: + """Clean up expired tokens and states.""" + # Clean up expired tokens + self.token_store.cleanup_expired_tokens() + + # Clean up expired states + expired_states = [ + state_json + for state_json, state in self._pending_states.items() + if state.is_expired() + ] + for state_json in expired_states: + del self._pending_states[state_json] + + if expired_states: + log.info("Cleaned up %d expired OAuth states", len(expired_states)) + + +# Global OAuth manager instance +oauth_manager = OAuthManager() + + +async def oauth_register_handler(_request: Request) -> Dict[str, Any]: + """Handle OAuth dynamic client registration. + + This endpoint provides the OAuth configuration that MCP clients need + to initiate the OAuth flow. + + Args: + request: FastAPI request object + + Returns: + OAuth registration response + """ + if not settings.OAUTH_ENABLED: + raise HTTPException(status_code=404, detail="OAuth not enabled") + + log.info("OAuth registration requested") + + # Generate client ID (simplified - could use request info) + client_id = f"mcp_client_{secrets.token_hex(8)}" + + # Get authorization URL + auth_url, state = oauth_manager.create_authorization_url(client_id) + + return { + "authorization_endpoint": auth_url, + "token_endpoint": f"{settings.SELF_URL}/oauth/token", + "client_id": settings.OAUTH_CLIENT, + "redirect_uri": oauth_manager.redirect_uri, + "state": state, + "response_type": "code", + "scope": "openid profile email", + } + + +async def oauth_callback_handler(request: Request) -> Response: + """Handle OAuth callback from authorization server. + + This handler works for both standard OAuth flows and MCP automatic flows. + + Args: + request: FastAPI request object containing authorization code + + Returns: + HTML response indicating success or failure + """ + if not settings.OAUTH_ENABLED: + raise HTTPException(status_code=404, detail="OAuth not enabled") + + # Extract parameters from callback + params = extract_oauth_callback_params(request) + code, state, error = params["code"], params["state"], params["error"] + + if error: + log.error("OAuth callback error: %s", error) + return HTMLResponse( + content=f""" + + +

OAuth Authentication Failed

+

Error: {error}

+

You can close this window.

+ + + """, + status_code=400, + ) + + if not code or not state: + log.error("Missing code or state in OAuth callback") + return HTMLResponse( + content=""" + + +

OAuth Authentication Failed

+

Missing authorization code or state parameter.

+

You can close this window.

+ + + """, + status_code=400, + ) + + try: + # Exchange code for token + token = await oauth_manager.exchange_code_for_token(code, state) + + if not token: + raise HTTPException(status_code=500, detail="Failed to create token") + + # Check if this is an MCP flow (state contains mcp_auth in session_id) + try: + state_obj = OAuthState.from_json(state) + is_mcp_flow = state_obj.session_id.startswith("mcp_auth_") + except ValueError: + is_mcp_flow = False + + log.info("OAuth authentication successful") + + return HTMLResponse(content=get_oauth_success_html(is_mcp_flow)) + + except HTTPException as e: + log.error("OAuth token exchange failed: %s", e.detail) + return HTMLResponse( + content=f""" + + +

OAuth Authentication Failed

+

Failed to exchange authorization code for access token.

+

Error: {e.detail}

+

You can close this window.

+ + + """, + status_code=e.status_code, + ) + + +async def oauth_token_handler(request: Request) -> Dict[str, Any]: + """Handle OAuth token requests from MCP clients. + + This endpoint is used by MCP clients to exchange authorization codes + for access tokens. + + Args: + request: FastAPI request object + + Returns: + Token response + """ + if not settings.OAUTH_ENABLED: + raise HTTPException(status_code=404, detail="OAuth not enabled") + + # Parse request body + content_type = request.headers.get("content-type", "") + if content_type.startswith("application/json"): + body = await request.json() + else: + form_data = await request.form() + body = {key: str(value) for key, value in form_data.items()} + + grant_type = body.get("grant_type") + code = body.get("code") + state = body.get("state") + + if grant_type != "authorization_code": + raise HTTPException(status_code=400, detail="Unsupported grant type") + + if not code or not state: + raise HTTPException(status_code=400, detail="Missing code or state") + + try: + # Exchange code for token + token = await oauth_manager.exchange_code_for_token(code, state) + + if not token: + raise HTTPException(status_code=500, detail="Failed to create token") + + # Handle both OAuthToken objects and dict (for backward compatibility with mocks) + if isinstance(token, dict): + return { + "access_token": token.get("access_token"), + "token_type": token.get("token_type", "Bearer"), + "expires_in": token.get("expires_in", 3600), + "refresh_token": token.get("refresh_token"), + "scope": token.get("scope", "openid profile email"), + } + return { + "access_token": token.access_token, + "token_type": "Bearer", + "expires_in": int(token.expires_at - time.time()), + "refresh_token": token.refresh_token, + "scope": "openid profile email", + } + + except HTTPException: + raise + except Exception as e: + log.error("Unexpected error in OAuth token exchange: %s", e) + raise HTTPException(status_code=500, detail="Internal server error") from e + + +def get_oauth_access_token_from_mcp(mcp: Any) -> Optional[str]: + """Extract OAuth access token from MCP request context. + + This function checks if the current request was authenticated using + OAuth and returns the access token if available. + + Note: This does not perform token refresh. If the token is expired, + it will return None and the caller should initiate a new OAuth flow. + + Args: + mcp: FastMCP instance + + Returns: + OAuth access token if available and not expired, None otherwise + """ + if not settings.OAUTH_ENABLED: + return None + + context = mcp.get_context() + if not context or not context.request_context: + return None + + request = context.request_context.request + if not request: + return None + + # Check for OAuth token in custom header + oauth_token_id = request.headers.get("X-OAuth-Token-ID") + if oauth_token_id: + # Get token directly from store without refresh (sync context) + token = oauth_manager.token_store.get_token_by_id(oauth_token_id) + if token and not token.is_expired(): + return token.access_token + + # Check for OAuth token in Authorization header + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + # This could be an OAuth token + return auth_header[7:] # Remove 'Bearer ' prefix + + return None diff --git a/assisted_service_mcp/src/oauth/middleware.py b/assisted_service_mcp/src/oauth/middleware.py new file mode 100644 index 0000000..101eec7 --- /dev/null +++ b/assisted_service_mcp/src/oauth/middleware.py @@ -0,0 +1,320 @@ +"""MCP OAuth middleware for automatic authentication flow. + +Simplified implementation using structured models and centralized token storage. +""" + +import asyncio +from typing import Any, Dict, Optional + +from fastapi import Request, Response +from starlette.responses import JSONResponse + +from assisted_service_mcp.src.logger import log +from assisted_service_mcp.src.oauth.models import OAuthState +from assisted_service_mcp.src.oauth.manager import oauth_manager +from assisted_service_mcp.src.oauth.utils import open_browser_for_oauth + + +class MCPOAuthMiddleware: + """Middleware that handles automatic OAuth flow for MCP clients. + + Simplified version using structured models and centralized storage. + """ + + def __init__(self) -> None: + """Initialize middleware.""" + # Track pending authentication sessions + self.pending_auth_sessions: Dict[str, Dict[str, Any]] = {} + + async def handle_mcp_request(self, request: Request, call_next: Any) -> Response: + """Handle MCP requests and initiate OAuth if needed. + + Args: + request: FastAPI request + call_next: Next middleware/handler + + Returns: + Response from handler or OAuth flow + """ + # Check if this is an MCP request without authentication + if not self._is_mcp_request_without_auth(request): + return await call_next(request) + + client_id = self._get_client_identifier(request) + + # Try to use existing token + token = oauth_manager.token_store.get_access_token_by_client(client_id) + if token: + log.info("Using cached token for client %s", client_id) + return await self._create_authenticated_request(request, call_next, token) + + # Check if OAuth flow is already in progress + if self._has_pending_auth(client_id): + log.info("OAuth flow already in progress for client %s", client_id) + return await self._wait_for_oauth_completion(request, call_next, client_id) + + # Start new OAuth flow + return await self._start_new_oauth_flow(request, call_next, client_id) + + def _is_mcp_request_without_auth(self, request: Request) -> bool: + """Check if this is an MCP request without authentication. + + Args: + request: FastAPI request + + Returns: + True if MCP request without auth, False otherwise + """ + # Check if it's an MCP request + content_type = request.headers.get("content-type", "") + is_mcp_request = ( + request.url.path.startswith("/mcp") + or "mcp" in request.headers.get("user-agent", "").lower() + or content_type.startswith("application/json") + ) + + # Check if authentication is missing + has_auth = ( + request.headers.get("authorization") + or request.headers.get("ocm-offline-token") + or request.headers.get("x-oauth-token-id") + ) + + return is_mcp_request and not has_auth + + def _get_client_identifier(self, request: Request) -> str: + """Get a unique identifier for the MCP client. + + Args: + request: FastAPI request + + Returns: + Client identifier string + """ + user_agent = request.headers.get("user-agent", "unknown") + client_ip = ( + getattr(request.client, "host", "unknown") if request.client else "unknown" + ) + return f"{user_agent}_{client_ip}" + + def _has_pending_auth(self, client_id: str) -> bool: + """Check if client has pending authentication. + + Args: + client_id: Client identifier + + Returns: + True if auth is pending, False otherwise + """ + for session_info in self.pending_auth_sessions.values(): + if session_info.get("client_id") == client_id: + return True + return False + + async def _start_new_oauth_flow( + self, request: Request, call_next: Any, client_id: str + ) -> Response: + """Start a new OAuth flow for the client. + + Args: + request: FastAPI request + call_next: Next middleware/handler + client_id: Client identifier + + Returns: + Response (either success or timeout) + """ + log.info("MCP request detected without authentication, initiating OAuth flow") + + # Create authorization URL using the manager + auth_url, state_json = oauth_manager.create_authorization_url(client_id) + + # Parse state to get session ID + try: + state = OAuthState.from_json(state_json) + session_id = state.session_id + except ValueError: + return JSONResponse( + {"error": "Failed to create OAuth state"}, status_code=500 + ) + + # Store session info + self.pending_auth_sessions[session_id] = { + "client_id": client_id, + "state": state_json, + "auth_url": auth_url, + "timestamp": asyncio.get_event_loop().time(), + } + + # Automatically open browser + open_browser_for_oauth(auth_url) + + log.info( + "OAuth flow initiated for client %s, waiting for completion", client_id + ) + return await self._wait_for_oauth_completion(request, call_next, client_id) + + async def _wait_for_oauth_completion( + self, request: Request, call_next: Any, client_id: str + ) -> Response: + """Wait for OAuth completion and handle the result. + + Args: + request: FastAPI request + call_next: Next middleware/handler + client_id: Client identifier + + Returns: + Response (either authenticated request or timeout) + """ + max_wait_time = 60 # 1 minute + poll_interval = 1 # 1 second + waited_time = 0 + + while waited_time < max_wait_time: + await asyncio.sleep(poll_interval) + waited_time += poll_interval + + # Check if OAuth completed (token available for client) + token = oauth_manager.token_store.get_access_token_by_client(client_id) + if token: + log.info( + "OAuth completed for client %s, proceeding with request", client_id + ) + # Clean up pending session + self._cleanup_client_sessions(client_id) + return await self._create_authenticated_request( + request, call_next, token + ) + + # OAuth timed out + log.warning("OAuth timed out for client %s", client_id) + + # Get auth URL for error message BEFORE cleanup + auth_url = None + for session_info in self.pending_auth_sessions.values(): + if session_info.get("client_id") == client_id: + auth_url = session_info.get("auth_url") + break + + # Now clean up the sessions + self._cleanup_client_sessions(client_id) + + return self._create_timeout_response(auth_url) + + def _cleanup_client_sessions(self, client_id: str) -> None: + """Clean up all pending sessions for a client. + + Args: + client_id: Client identifier + """ + sessions_to_remove = [ + session_id + for session_id, session_info in self.pending_auth_sessions.items() + if session_info.get("client_id") == client_id + ] + for session_id in sessions_to_remove: + del self.pending_auth_sessions[session_id] + + async def _create_authenticated_request( + self, request: Request, call_next: Any, token: str + ) -> Response: + """Create a new request with authentication token. + + Args: + request: Original request + call_next: Next middleware/handler + token: Access token + + Returns: + Response from handler + """ + # Modify request headers for Starlette + new_headers = list(request.scope.get("headers", [])) + # Remove any existing authorization header + new_headers = [(k, v) for k, v in new_headers if k.lower() != b"authorization"] + # Add the new authorization header + new_headers.append((b"authorization", f"Bearer {token}".encode())) + + # Create new scope with updated headers + new_scope = dict(request.scope) + new_scope["headers"] = new_headers + + # Create new request with updated scope + from starlette.requests import Request as StarletteRequest + + new_request = StarletteRequest(new_scope, request.receive) + return await call_next(new_request) + + def _create_timeout_response(self, auth_url: Optional[str] = None) -> JSONResponse: + """Create timeout response for failed OAuth. + + Args: + auth_url: Optional auth URL to include in response + + Returns: + JSON response with timeout error + """ + if auth_url: + return JSONResponse( + { + "type": "oauth_timeout", + "message": "OAuth authentication timed out or failed", + "auth_url": auth_url, + "instructions": [ + "1. Authentication timed out or failed", + "2. You can try the authentication URL manually:", + f" {auth_url}", + "3. Or reconnect to the MCP server to try again", + ], + }, + status_code=401, + ) + return JSONResponse( + { + "type": "oauth_timeout", + "message": "OAuth authentication timed out", + "instructions": [ + "Authentication took too long or failed", + "Please try reconnecting to the MCP server", + ], + }, + status_code=401, + ) + + async def handle_oauth_callback(self, request: Request) -> Response: + """Handle OAuth callback and complete authentication. + + Note: This is now handled by oauth.oauth_callback_handler. + This method is kept for backward compatibility but delegates to the main handler. + + Args: + request: FastAPI request + + Returns: + Response + """ + from assisted_service_mcp.src.oauth.manager import oauth_callback_handler + + return await oauth_callback_handler(request) + + def cleanup_expired_sessions(self, max_age_seconds: int = 600) -> None: + """Clean up expired authentication sessions. + + Args: + max_age_seconds: Maximum session age in seconds (default 10 minutes) + """ + current_time = asyncio.get_event_loop().time() + expired_sessions = [ + session_id + for session_id, info in self.pending_auth_sessions.items() + if current_time - info["timestamp"] > max_age_seconds + ] + + for session_id in expired_sessions: + del self.pending_auth_sessions[session_id] + log.info("Cleaned up expired OAuth session: %s", session_id) + + +# Global middleware instance +mcp_oauth_middleware = MCPOAuthMiddleware() diff --git a/assisted_service_mcp/src/oauth/models.py b/assisted_service_mcp/src/oauth/models.py new file mode 100644 index 0000000..064c08b --- /dev/null +++ b/assisted_service_mcp/src/oauth/models.py @@ -0,0 +1,159 @@ +"""OAuth data models for type safety and clarity.""" + +import json +import time +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class OAuthToken: + """Unified OAuth token model. + + This consolidates token information that was previously spread across + multiple dictionaries with different structures. + + Supports dictionary-style access for backward compatibility. + """ + + token_id: str + client_id: str + access_token: str + refresh_token: Optional[str] + expires_at: float + token_type: str = "Bearer" + + def is_expired(self) -> bool: + """Check if token is expired. + + Returns: + True if token is expired, False otherwise + """ + return time.time() >= self.expires_at + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "token_id": self.token_id, + "client_id": self.client_id, + "access_token": self.access_token, + "token_type": self.token_type, + "refresh_token": self.refresh_token, + "expires_at": self.expires_at, + } + + @classmethod + def from_dict(cls, data: dict) -> "OAuthToken": + """Create from dictionary.""" + return cls( + token_id=data["token_id"], + client_id=data["client_id"], + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + expires_at=data["expires_at"], + token_type=data.get("token_type", "Bearer"), + ) + + def __getitem__(self, key: str) -> Any: + """Support dictionary-style access for backward compatibility. + + Args: + key: Attribute name + + Returns: + Attribute value + """ + if key == "token_id": + return self.token_id + if key == "access_token": + return self.access_token + if key == "token_type": + return self.token_type + if key == "refresh_token": + return self.refresh_token + if key == "expires_at": + return self.expires_at + if key == "client_id": + return self.client_id + raise KeyError(f"Unknown key: {key}") + + def __contains__(self, key: str) -> bool: + """Support 'in' operator for backward compatibility.""" + return key in [ + "token_id", + "access_token", + "token_type", + "refresh_token", + "expires_at", + "client_id", + ] + + def get(self, key: str, default: Any = None) -> Any: + """Support dict.get() for backward compatibility.""" + try: + return self[key] + except KeyError: + return default + + +@dataclass +class OAuthState: + """Structured OAuth state for CSRF protection. + + Replaces string concatenation with proper JSON serialization. + """ + + session_id: str + client_id: str + timestamp: float + code_verifier: str + + def to_json(self) -> str: + """Serialize to JSON string for use as OAuth state parameter. + + Returns: + JSON string representation + """ + return json.dumps( + { + "session_id": self.session_id, + "client_id": self.client_id, + "timestamp": self.timestamp, + "code_verifier": self.code_verifier, + } + ) + + @classmethod + def from_json(cls, state_str: str) -> "OAuthState": + """Deserialize from JSON string. + + Args: + state_str: JSON string from OAuth state parameter + + Returns: + OAuthState instance + + Raises: + ValueError: If state string is invalid + """ + try: + data = json.loads(state_str) + return cls( + session_id=data["session_id"], + client_id=data["client_id"], + timestamp=data["timestamp"], + code_verifier=data["code_verifier"], + ) + except (json.JSONDecodeError, KeyError) as e: + raise ValueError(f"Invalid OAuth state: {e}") from e + + def is_expired(self, max_age_seconds: int = 600) -> bool: + """Check if state is too old. + + Args: + max_age_seconds: Maximum age in seconds (default 10 minutes) + + Returns: + True if state is expired, False otherwise + """ + return time.time() - self.timestamp > max_age_seconds diff --git a/assisted_service_mcp/src/oauth/store.py b/assisted_service_mcp/src/oauth/store.py new file mode 100644 index 0000000..e2d2ceb --- /dev/null +++ b/assisted_service_mcp/src/oauth/store.py @@ -0,0 +1,189 @@ +"""Centralized OAuth token storage. + +This replaces the multiple token storage mechanisms that were spread across +oauth.py and mcp_oauth_middleware.py. +""" + +import time +from typing import Dict, Optional + +from assisted_service_mcp.src.logger import log +from assisted_service_mcp.src.oauth.models import OAuthToken + + +class TokenStore: + """Centralized token storage with clear interface. + + This consolidates: + - oauth_manager._tokens (token_id -> token data) + - mcp_oauth_middleware.completed_tokens (client_id -> token_id) + + Into a single, well-defined storage mechanism. + """ + + def __init__(self) -> None: + """Initialize token store.""" + self._tokens: Dict[str, OAuthToken] = {} + self._client_tokens: Dict[str, str] = {} # client_id -> token_id + + def store_token(self, token: OAuthToken) -> None: + """Store a token and associate it with a client. + + Args: + token: OAuthToken to store + """ + self._tokens[token.token_id] = token + self._client_tokens[token.client_id] = token.token_id + log.debug( + "Stored token %s for client %s (expires at %s)", + token.token_id, + token.client_id, + token.expires_at, + ) + + def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]: + """Get token by token ID. + + Args: + token_id: Token identifier + + Returns: + OAuthToken if found and valid, None otherwise + """ + token = self._tokens.get(token_id) + if not token: + return None + + if token.is_expired(): + log.debug("Token %s is expired, removing from store", token_id) + self.remove_token(token_id) + return None + + return token + + def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]: + """Get token for a client. + + Args: + client_id: Client identifier + + Returns: + OAuthToken if found and valid, None otherwise + """ + token_id = self._client_tokens.get(client_id) + if not token_id: + return None + + return self.get_token_by_id(token_id) + + def get_access_token_by_id(self, token_id: str) -> Optional[str]: + """Get access token string by token ID. + + Args: + token_id: Token identifier + + Returns: + Access token string if found and valid, None otherwise + """ + token = self.get_token_by_id(token_id) + return token.access_token if token else None + + def get_access_token_by_client(self, client_id: str) -> Optional[str]: + """Get access token string for a client. + + Args: + client_id: Client identifier + + Returns: + Access token string if found and valid, None otherwise + """ + token = self.get_token_by_client(client_id) + return token.access_token if token else None + + def update_token( + self, + token_id: str, + access_token: str, + refresh_token: Optional[str], + expires_at: float, + ) -> bool: + """Update an existing token (e.g., after refresh). + + Args: + token_id: Token identifier + access_token: New access token + refresh_token: New refresh token (optional) + expires_at: New expiration timestamp + + Returns: + True if token was updated, False if token not found + """ + token = self._tokens.get(token_id) + if not token: + return False + + token.access_token = access_token + if refresh_token: + token.refresh_token = refresh_token + token.expires_at = expires_at + + log.debug("Updated token %s (new expiry: %s)", token_id, expires_at) + return True + + def remove_token(self, token_id: str) -> None: + """Remove a token and its associations. + + Args: + token_id: Token identifier to remove + """ + token = self._tokens.pop(token_id, None) + if token: + self._client_tokens.pop(token.client_id, None) + log.debug("Removed token %s for client %s", token_id, token.client_id) + + def remove_client_token(self, client_id: str) -> None: + """Remove token associated with a client. + + Args: + client_id: Client identifier + """ + token_id = self._client_tokens.get(client_id) + if token_id: + self.remove_token(token_id) + + def cleanup_expired_tokens(self) -> int: + """Clean up expired tokens. + + Returns: + Number of tokens removed + """ + current_time = time.time() + expired_token_ids = [ + token_id + for token_id, token in self._tokens.items() + if current_time >= token.expires_at + ] + + for token_id in expired_token_ids: + self.remove_token(token_id) + + if expired_token_ids: + log.info("Cleaned up %d expired tokens", len(expired_token_ids)) + + return len(expired_token_ids) + + def get_all_tokens(self) -> Dict[str, OAuthToken]: + """Get all stored tokens (for debugging/monitoring). + + Returns: + Dictionary of token_id -> OAuthToken + """ + return self._tokens.copy() + + def get_token_count(self) -> int: + """Get number of stored tokens. + + Returns: + Number of tokens in store + """ + return len(self._tokens) diff --git a/assisted_service_mcp/src/oauth/utils.py b/assisted_service_mcp/src/oauth/utils.py new file mode 100644 index 0000000..6853649 --- /dev/null +++ b/assisted_service_mcp/src/oauth/utils.py @@ -0,0 +1,102 @@ +"""Shared OAuth utilities to reduce code duplication.""" + +import webbrowser +from typing import Any, Dict + +from fastapi import Request + +from assisted_service_mcp.src.logger import log + + +def open_browser_for_oauth(auth_url: str) -> None: + """Open browser for OAuth authentication with error handling. + + Args: + auth_url: The OAuth authorization URL to open + """ + try: + webbrowser.open(auth_url) + log.info("Opened browser for OAuth authentication: %s", auth_url) + except Exception as e: + log.warning("Could not open browser automatically: %s", e) + + +def get_oauth_success_html(is_mcp_flow: bool = False, session_id: str = "") -> str: + """Generate OAuth success HTML page. + + Args: + is_mcp_flow: Whether this is an MCP automatic flow + session_id: Session ID for display (optional) + + Returns: + HTML content for success page + """ + instructions_html = "" + if is_mcp_flow: + instructions_html = """ +
+

For MCP Clients (Cursor/Copilot):

+
    +
  1. Close this browser window
  2. +
  3. Return to Cursor
  4. +
  5. Try your MCP command again (e.g., "list my clusters")
  6. +
  7. The connection should now work with your authenticated token
  8. +
+
+ """ + else: + instructions_html = """ +
+

Next Steps:

+
    +
  1. Close this browser window
  2. +
  3. Return to Cursor/Copilot
  4. +
  5. Your MCP connection should now work automatically
  6. +
+
+ """ + + session_display = ( + f"

Session ID: {session_id}

" if session_id else "" + ) + auto_close_delay = "3000" if is_mcp_flow else "5000" + return f""" + + + Authentication Successful + + + +

🎉 Authentication Successful!

+

You have successfully authenticated with the Assisted Service MCP server.

+ {instructions_html} + {session_display} + + + + """ + + +def extract_oauth_callback_params(request: Request) -> Dict[str, Any]: + """Extract and validate OAuth callback parameters. + + Args: + request: FastAPI request object + + Returns: + Dictionary with code, state, and error parameters + """ + return { + "code": request.query_params.get("code"), + "state": request.query_params.get("state"), + "error": request.query_params.get("error"), + } diff --git a/assisted_service_mcp/src/settings.py b/assisted_service_mcp/src/settings.py index f871ddf..8f5fed1 100644 --- a/assisted_service_mcp/src/settings.py +++ b/assisted_service_mcp/src/settings.py @@ -106,6 +106,52 @@ class Settings(BaseSettings): }, ) + # OAuth Configuration + OAUTH_ENABLED: bool = Field( + default=False, + json_schema_extra={ + "env": "OAUTH_ENABLED", + "description": "Enable OAuth authentication flow", + "example": True, + }, + ) + + OAUTH_URL: str = Field( + default="https://sso.redhat.com/auth/realms/redhat-external", + json_schema_extra={ + "env": "OAUTH_URL", + "description": "OAuth authorization server URL", + "example": "https://sso.redhat.com/auth/realms/redhat-external", + }, + ) + + OAUTH_CLIENT: str = Field( + default="ocm-cli", + json_schema_extra={ + "env": "OAUTH_CLIENT", + "description": "OAuth client identifier", + "example": "ocm-cli", + }, + ) + + SELF_URL: str = Field( + default="http://localhost:8000", + json_schema_extra={ + "env": "SELF_URL", + "description": "Base URL that the server uses to construct URLs referencing itself", + "example": "https://my.host.com", + }, + ) + + OAUTH_REDIRECT_URI: Optional[str] = Field( + default=None, + json_schema_extra={ + "env": "OAUTH_REDIRECT_URI", + "description": "Override OAuth redirect URI (optional - automatically constructed from SELF_URL with 127.0.0.1 for localhost)", + "example": "http://127.0.0.1:8000/oauth/callback", + }, + ) + # Logging Configuration LOGGING_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default="INFO", diff --git a/assisted_service_mcp/utils/auth.py b/assisted_service_mcp/utils/auth.py index 3016b67..bf175a2 100644 --- a/assisted_service_mcp/utils/auth.py +++ b/assisted_service_mcp/utils/auth.py @@ -1,10 +1,10 @@ """Authentication utilities for Assisted Service MCP Server.""" -from typing import Any, Callable +from typing import Any, Callable, Optional import requests from assisted_service_mcp.src.logger import log -from assisted_service_mcp.src.settings import get_setting +from assisted_service_mcp.src.settings import get_setting, settings def get_offline_token(mcp: Any) -> str: @@ -46,28 +46,36 @@ def get_offline_token(mcp: Any) -> str: def get_access_token( - mcp: Any, offline_token_func: Callable[[], str] | None = None + mcp: Any, + oauth_token_func: Callable[[Any], Optional[str]] | None = None, ) -> str: """ Retrieve the access token. - This function tries to get the Red Hat OpenShift Cluster Manager (OCM) access token. First - it tries to extract it from the authorization header, and if it isn't there then it tries - to generate a new one using the offline token. + Authentication methods are checked in the following order of priority: + 1. Access token in the Authorization request header + 2. OAuth flow (if OAUTH_ENABLED is true) - no fallback to offline token + 3. Offline token via environment variable (only if OAuth is disabled) + + When OAuth is enabled, offline token fallback is disabled to ensure consistent + OAuth-only authentication flow. + + Note: OCM-Offline-Token header support is deprecated but still functional for backward compatibility. Args: mcp: The FastMCP instance to get request context from. - offline_token_func: Optional function to get offline token. If not provided, - uses get_offline_token(mcp). + oauth_token_func: Optional function to get OAuth token. If not provided, + OAuth flow will not be attempted. Returns: str: The access token. Raises: - RuntimeError: If it isn't possible to obtain or generate the access token. + RuntimeError: If no valid authentication method is available or authentication fails. """ - log.debug("Attempting to retrieve access token") - # First try to get the token from the authorization header: + log.debug("Attempting to retrieve access token using priority order") + + # 1. First try to get the token from the authorization header: context = mcp.get_context() if context and context.request_context: request = context.request_context.request @@ -79,14 +87,35 @@ def get_access_token( log.debug("Found access token in authorization header") return parts[1] - # Now try to get the offline token, and generate a new access token from it: - log.debug("Generating new access token from offline token") + # 2. Try OAuth flow if enabled + if settings.OAUTH_ENABLED: + log.debug("OAuth is enabled, checking for OAuth access token") + if oauth_token_func: + oauth_token = oauth_token_func(mcp) + if oauth_token: + log.debug("Found OAuth access token (priority 2)") + return oauth_token + log.debug( + "OAuth token function returned None - OAuth flow may be in progress" + ) + else: + log.debug( + "OAuth enabled but no oauth_token_func provided - skipping OAuth priority" + ) + + # When OAuth is enabled, don't fall back to offline token + log.error( + "OAuth is enabled but no valid OAuth token found - offline token fallback disabled" + ) + raise RuntimeError( + "OAuth authentication is enabled but no valid OAuth token found. " + "Please complete the OAuth authentication flow." + ) + + # 3. & 4. Try offline token methods (environment variable has priority over header) + log.debug("Generating new access token from offline token (priority 3 & 4)") - # Use the provided offline token function or default to get_offline_token(mcp) - if offline_token_func is None: - offline_token = get_offline_token(mcp) - else: - offline_token = offline_token_func() + offline_token = get_offline_token(mcp) params = { "client_id": "cloud-services", @@ -113,5 +142,5 @@ def get_access_token( "Invalid SSO response: missing or malformed access_token" ) from e - log.debug("Successfully generated new access token") + log.debug("Successfully generated new access token from offline token") return access_token diff --git a/doc/OAUTH_SETUP.md b/doc/OAUTH_SETUP.md new file mode 100644 index 0000000..d755f81 --- /dev/null +++ b/doc/OAUTH_SETUP.md @@ -0,0 +1,123 @@ +# OAuth Authentication Setup + +This MCP server supports automatic OAuth authentication with Red Hat SSO for seamless integration with MCP clients. + +## Quick Start + +1. **Start the OAuth-enabled server**: + ```bash + ./start-oauth-server.sh + ``` + +2. **Configure your MCP client** (Cursor): +```json +{ + "mcpServers": { + "assisted-local-oauth": { + "transport": "streamable-http", + "url": "http://localhost:8000/mcp" + } + } +} +``` + +3. **Connect from your MCP client** - OAuth flow will start automatically! + +## How It Works + +1. **Automatic Detection**: When Cursor connects without credentials, the server detects this and initiates OAuth flow +2. **Browser Authentication**: A browser window will open automatically for Red Hat SSO authentication +3. **Token Caching**: After successful authentication, access and refresh tokens are cached for the client +4. **Automatic Refresh**: Expired tokens are automatically refreshed using refresh tokens (5 minutes before expiry) +5. **Seamless Reconnection**: Subsequent connections use cached tokens with transparent refresh + +## Configuration + +The OAuth configuration is stored in `oauth-config.env`: + +```bash +# OAuth Configuration +OAUTH_ENABLED=true +OAUTH_URL=https://sso.redhat.com/auth/realms/redhat-external +OAUTH_CLIENT=ocm-cli +SELF_URL=http://127.0.0.1:8000 + +# Server Configuration +MCP_HOST=0.0.0.0 +MCP_PORT=8000 +TRANSPORT="streamable-http" + +# Logging +LOGGING_LEVEL=DEBUG +LOG_TO_FILE=false + +# Assisted Service API +INVENTORY_URL=https://api.openshift.com/api/assisted-install/v2 +SSO_URL=https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token +``` + +## Available Endpoints + +| Endpoint | Purpose | +|----------|---------| +| `/mcp` | Main MCP endpoint with OAuth middleware | +| `/oauth/register` | OAuth registration for MCP clients | +| `/oauth/callback` | OAuth callback handler | +| `/oauth/token` | Token exchange endpoint | +| `/.well-known/*` | OAuth discovery endpoints | +| `/metrics` | Prometheus metrics | + +## Security Features + +- **PKCE (Proof Key for Code Exchange)**: Enhanced OAuth security +- **State Parameter**: CSRF protection +- **Token Caching**: Secure in-memory token storage per client +- **Automatic Cleanup**: Expired sessions are cleaned up automatically + +## Authentication Priority Order + +The server follows this authentication priority: + +1. **Authorization Header**: `Bearer ` in request headers +2. **OAuth Flow**: Automatic OAuth if enabled and no token found (no fallback to offline token) +3. **Offline Token (Environment)**: `OFFLINE_TOKEN` environment variable (only when OAuth is disabled) + +**Important**: When OAuth is enabled, offline token fallback is disabled to ensure consistent OAuth-only authentication. + +## Troubleshooting + + +### Tools Not Loading in Cursor / "Loading Tools" Hangs +**Symptom**: After OAuth authentication completes, Cursor shows "Loading tools..." indefinitely + +**Cause**: Cursor's initial MCP connection request gets a 401 (OAuth required), but after OAuth completes, Cursor doesn't automatically retry the connection + +**Solution**: +1. **Complete OAuth authentication** in the browser (this works correctly) +2. **Reload the MCP connection** in Cursor: + - Go to Cursor settings → MCP + - Disable and re-enable the `assisted-local-oauth` server, OR + - Restart Cursor +3. **The connection will now work** with your authenticated token + +**Alternative**: Use the status endpoint to check authentication: +- GET `http://127.0.0.1:8000/oauth/status?client_id=` +- Returns authentication status for debugging + +### Server Won't Start +- Check if port 8000 is already in use +- Verify all dependencies are installed +- Check `oauth-config.env` file exists and is properly formatted + +## Development Notes + +- **Client Identification**: Clients are identified by User-Agent + IP address +- **Session Management**: OAuth sessions are stored in memory (not persistent across restarts) +- **Token Expiration**: Tokens are cached until server restart (no automatic refresh yet) +- **Middleware Integration**: OAuth middleware is integrated with FastMCP/Starlette + +## Production Considerations + +For production deployment: +- Use HTTPS for `SELF_URL` +- Use environment-specific OAuth clients diff --git a/doc/oauth_authentication.md b/doc/oauth_authentication.md new file mode 100644 index 0000000..2e5dad4 --- /dev/null +++ b/doc/oauth_authentication.md @@ -0,0 +1,227 @@ +# OAuth Authentication for Assisted Service MCP Server + +This document describes the OAuth authentication implementation for the Assisted Service MCP Server. + +## Overview + +The server supports multiple authorization methods for accessing the Assisted Installer API. The method used depends on the environment variables and headers you provide. The following methods are checked in order of priority; the first one that succeeds will be used, and the rest will be ignored. + +## Authentication Priority Order + +### 1. Access token in the `Authorization` request header + +If the `Authorization` request header contains a bearer token, it will be passed directly to the Assisted Installer API. In this case, the OAuth flow will not be triggered, and any values provided in the `OFFLINE_TOKEN` environment variable or the `OCM-Offline-Token` request header will be ignored. + +**Example:** +```http +Authorization: Bearer ACCESS_TOKEN_HERE +``` + +### 2. OAuth flow + +If the `OAUTH_ENABLED` environment variable is set to `true`, the server will use a subset of the OAuth protocol that MCP clients (such as the one in VS Code) use for authentication. When you attempt to connect, the MCP client will open a browser window where you can enter your credentials. The client will then request an access token, which the server will use to authenticate requests to the Assisted Installer API. + +When using this authentication method, the `OFFLINE_TOKEN` environment variable and the `OCM-Offline-Token` header will be ignored. + +### 3. Offline token via environment variable + +If you set the `OFFLINE_TOKEN` environment variable, the server will use this offline token to request an access token, which will then be used to call the Assisted Installer API. + +### 4. Offline token via request header + +If the `OCM-Offline-Token` request header is set, the server will use it to request an access token, and will then use that access token to call the Assisted Installer API. + +## OAuth Configuration + +### Environment Variables + +You can configure the OAuth authorization server and client identifier using the following environment variables: + +| Variable | Default Value | Description | +|----------|---------------|-------------| +| `OAUTH_ENABLED` | `false` | Enable OAuth authentication flow | +| `OAUTH_URL` | `https://sso.redhat.com/auth/realms/redhat-external` | OAuth authorization server URL | +| `OAUTH_CLIENT` | `ocm-cli` | OAuth client identifier | +| `SELF_URL` | `http://localhost:8000` | Base URL that the server uses to construct URLs referencing itself | + +### SELF_URL Configuration + +The `SELF_URL` environment variable specifies the base URL that the server uses to construct URLs referencing itself. For example, when OAuth is enabled, the server will generate the dynamic client registration URL by appending `/oauth/register` to this base URL. + +- **Default:** `http://localhost:8000` +- **Production:** Should be set to the actual URL of the server as accessible to clients + +**Examples:** +- Local development: `http://localhost:8000` +- Production with reverse proxy: `https://my.host.com` +- Production with custom port: `https://my.host.com:8443` + +## OAuth Endpoints + +When OAuth is enabled, the server exposes the following endpoints: + +### `/oauth/register` (GET) + +Returns OAuth configuration that MCP clients need to initiate the OAuth flow. + +**Response:** +```json +{ + "authorization_endpoint": "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/auth?client_id=cloud-services&...", + "token_endpoint": "http://localhost:8000/oauth/token", + "client_id": "cloud-services", + "redirect_uri": "http://localhost:8000/oauth/callback", + "state": "random_state_string", + "response_type": "code", + "scope": "openid profile email" +} +``` + +### `/oauth/callback` (GET) + +Handles the OAuth callback from the authorization server. This endpoint: +- Receives the authorization code from the OAuth provider +- Exchanges it for an access token +- Displays a success/failure page to the user + +**Parameters:** +- `code`: Authorization code from OAuth provider +- `state`: OAuth state parameter for CSRF protection +- `error`: Error code if authentication failed + +### `/oauth/token` (POST) + +Handles OAuth token requests from MCP clients. This endpoint is used by MCP clients to exchange authorization codes for access tokens. + +**Request Body:** +```json +{ + "grant_type": "authorization_code", + "code": "authorization_code_from_callback", + "state": "oauth_state_parameter" +} +``` + +**Response:** +```json +{ + "access_token": "ACCESS_TOKEN_HERE", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refresh_token_if_available", + "scope": "openid profile email" +} +``` + +## OAuth Flow Sequence + +1. **Client Registration**: MCP client calls `/oauth/register` to get OAuth configuration +2. **Authorization**: Client redirects user to authorization endpoint with PKCE challenge +3. **User Authentication**: User authenticates with OAuth provider (Red Hat SSO) +4. **Callback**: OAuth provider redirects to `/oauth/callback` with authorization code +5. **Token Exchange**: Client calls `/oauth/token` to exchange code for access token +6. **API Access**: Client uses access token to authenticate API requests + +## Security Features + +### PKCE (Proof Key for Code Exchange) + +The implementation uses PKCE (RFC 7636) for enhanced security: +- Generates cryptographically random code verifier +- Creates SHA256-based code challenge +- Protects against authorization code interception attacks + +### State Parameter + +Uses OAuth state parameter for CSRF protection: +- Generates cryptographically random state values +- Validates state on callback to prevent CSRF attacks + +### Token Storage + +- Access tokens are stored temporarily in memory +- Tokens are associated with random identifiers +- No sensitive data is logged + +## Usage Examples + +### Enable OAuth Authentication + +```bash +export OAUTH_ENABLED=true +export OAUTH_URL=https://sso.redhat.com/auth/realms/redhat-external +export OAUTH_CLIENT=cloud-services +export SELF_URL=https://my-mcp-server.com +``` + +### MCP Client Configuration + +When OAuth is enabled, MCP clients should: + +1. Call `/oauth/register` to get OAuth configuration +2. Open authorization URL in browser for user authentication +3. Handle the callback and extract authorization code +4. Exchange code for access token via `/oauth/token` +5. Use access token in `Authorization: Bearer ` header + +### Testing OAuth Flow + +You can test the OAuth endpoints manually: + +```bash +# Get OAuth configuration +curl http://localhost:8000/oauth/register + +# Test callback with error +curl "http://localhost:8000/oauth/callback?error=access_denied" + +# Test token endpoint with invalid grant +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/json" \ + -d '{"grant_type": "client_credentials", "code": "test", "state": "test"}' +``` + +## Troubleshooting + +### Common Issues + +1. **OAuth endpoints return 404** + - Ensure `OAUTH_ENABLED=true` is set + - Restart the server after changing environment variables + +2. **Invalid OAuth state error** + - State parameters expire after use + - Ensure client uses the state from `/oauth/register` response + +3. **Token exchange fails** + - Verify authorization code is valid and not expired + - Check that redirect_uri matches exactly + +4. **SELF_URL misconfiguration** + - Ensure SELF_URL is accessible from client browsers + - Include protocol (http/https) and correct port + +### Debug Logging + +Enable debug logging to troubleshoot OAuth issues: + +```bash +export LOGGING_LEVEL=DEBUG +``` + +Look for log messages with OAuth-related information: +- OAuth registration requests +- Token exchange attempts +- Authentication priority decisions + +## Integration with MCP Clients + +MCP clients (like VS Code extensions) can integrate with this OAuth implementation by: + +1. **Discovery**: Call `/oauth/register` to get OAuth configuration +2. **Browser Flow**: Open authorization URL in system browser +3. **Callback Handling**: Set up local server to handle OAuth callback +4. **Token Management**: Store and refresh access tokens as needed +5. **API Authentication**: Include tokens in `Authorization` header + +The server will automatically detect OAuth tokens and use them according to the authentication priority order. diff --git a/oauth-config.env b/oauth-config.env new file mode 100644 index 0000000..25b0420 --- /dev/null +++ b/oauth-config.env @@ -0,0 +1,18 @@ +# OAuth Configuration for Local Development +OAUTH_ENABLED=true +OAUTH_URL=https://sso.redhat.com/auth/realms/redhat-external +OAUTH_CLIENT=ocm-cli +SELF_URL=http://127.0.0.1:8000 + +# MCP Server Configuration +MCP_HOST=0.0.0.0 +MCP_PORT=8000 +TRANSPORT="streamable-http" + +# Logging Configuration +LOGGING_LEVEL=DEBUG +LOG_TO_FILE=false + +# Assisted Service Configuration +INVENTORY_URL=https://api.openshift.com/api/assisted-install/v2 +SSO_URL=https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token diff --git a/pyproject.toml b/pyproject.toml index 9ba6f29..802b82e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,16 +9,20 @@ dependencies = [ "mcp>=1.15.0", "netaddr>=1.3.0", "requests>=2.32.3", + "httpx>=0.27.0", "retry>=0.9.2", "prometheus_client>=0.22.1", "pyyaml>=6", "jinja2>=3.1", "pydantic>=2.12.1", "pydantic-settings>=2.6.0", - "python-dotenv>=1.0.0", + "python-dotenv>=1.0.0", "nestedarchive>=0.2.4", "tabulate>=0.9.0", "fastapi>=0.115.0", + "authlib>=1.3.0", + "cryptography>=41.0.0", + "uvicorn>=0.30.0", ] [dependency-groups] @@ -90,7 +94,7 @@ add-ignore=[ ] [tool.mypy] -exclude = "assisted_service_mcp/src/utils/log_analyzer/" +exclude = ["assisted_service_mcp/src/utils/log_analyzer/", "tests/"] follow_imports = "skip" explicit_package_bases = true disallow_untyped_calls = true diff --git a/start-oauth-server.sh b/start-oauth-server.sh new file mode 100755 index 0000000..6fc02e8 --- /dev/null +++ b/start-oauth-server.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -e + +echo "Starting Assisted Service MCP Server with OAuth Authentication..." +echo + +# Load OAuth configuration +if [ -f oauth-config.env ]; then + export $(grep -v '^#' oauth-config.env | xargs) +else + echo "Error: oauth-config.env not found!" + exit 1 +fi + +echo "Configuration:" +echo " OAuth Enabled: $OAUTH_ENABLED" +echo " OAuth Client: $OAUTH_CLIENT" +echo " Server: $MCP_HOST:$MCP_PORT" +echo " Transport: $TRANSPORT" +echo + +echo "OAuth Endpoints:" +echo " Registration: $SELF_URL/oauth/register" +echo " Callback: $SELF_URL/oauth/callback" +echo " Token: $SELF_URL/oauth/token" +echo + +echo "MCP Client Configuration:" +echo " Add this to your Cursor MCP settings:" +echo " {" +echo " \"assisted-local-oauth\": {" +echo " \"transport\": \"streamable-http\"," +echo " \"url\": \"$SELF_URL/mcp\"" +echo " }" +echo " }" +echo + +echo "How it works:" +echo " 1. Cursor connects -> OAuth flow starts automatically" +echo " 2. Browser opens for Red Hat SSO authentication" +echo " 3. After authentication, connection proceeds automatically" +echo " 4. Subsequent connections use cached token" +echo + +echo "Starting MCP server..." +echo "Press Ctrl+C to stop" +echo + +python -m assisted_service_mcp.src.main diff --git a/tests/src/test_mcp.py b/tests/src/test_mcp.py index 9d3260e..4578989 100644 --- a/tests/src/test_mcp.py +++ b/tests/src/test_mcp.py @@ -6,7 +6,6 @@ def test_mcp_registers_tools_and_auth_closures() -> None: server = mod.AssistedServiceMCPServer() # Check closures exist - assert hasattr(server, "_get_offline_token") assert hasattr(server, "_get_access_token") # List tools diff --git a/tests/test_auth_priority.py b/tests/test_auth_priority.py new file mode 100644 index 0000000..fef1c78 --- /dev/null +++ b/tests/test_auth_priority.py @@ -0,0 +1,101 @@ +"""Tests for authentication priority order implementation.""" + +from unittest.mock import MagicMock, patch + + +from assisted_service_mcp.utils.auth import get_access_token + + +class TestAuthenticationPriority: + """Test cases for authentication priority order.""" + + mock_mcp: MagicMock + mock_context: MagicMock + mock_request: MagicMock + mock_headers: MagicMock + + def setup_method(self) -> None: + """Set up test fixtures.""" + self.mock_mcp = MagicMock() + self.mock_context = MagicMock() + self.mock_request = MagicMock() + + # Setup mock headers with proper get method + self.mock_headers = MagicMock() + self.mock_request.headers = self.mock_headers + + self.mock_context.request_context.request = self.mock_request + self.mock_mcp.get_context.return_value = self.mock_context + + def test_priority_1_authorization_header_bearer_token(self) -> None: + """Test priority 1: Access token in Authorization header.""" + # Setup Authorization header with Bearer token + self.mock_headers.get.return_value = "Bearer test_access_token" + + with patch("assisted_service_mcp.src.settings.settings.OAUTH_ENABLED", False): + token = get_access_token(self.mock_mcp) + + assert token == "test_access_token" + + @patch("assisted_service_mcp.src.settings.settings.OAUTH_ENABLED", True) + def test_priority_1_overrides_oauth(self) -> None: + """Test that Authorization header (priority 1) overrides OAuth (priority 2).""" + # Setup Authorization header + self.mock_headers.get.return_value = "Bearer priority_1_token" + + def mock_oauth_func(_mcp): + # This should not be called + return "oauth_token" + + token = get_access_token(self.mock_mcp, oauth_token_func=mock_oauth_func) + + # Should return Authorization header token, not OAuth token + assert token == "priority_1_token" + + def test_priority_2_oauth_when_no_auth_header(self) -> None: + """Test priority 2: OAuth flow when no Authorization header.""" + + # Setup headers to return None for Authorization header + def mock_header_get(key, default=None): + if key == "Authorization": + return None + return default + + self.mock_headers.get.side_effect = mock_header_get + + def mock_oauth_func(_mcp): + return "oauth_access_token" + + with patch("assisted_service_mcp.utils.auth.settings.OAUTH_ENABLED", True): + token = get_access_token(self.mock_mcp, oauth_token_func=mock_oauth_func) + + assert token == "oauth_access_token" + + @patch("assisted_service_mcp.src.settings.settings.OAUTH_ENABLED", False) + @patch("requests.post") + def test_offline_token_fallback_when_oauth_disabled( + self, mock_post: MagicMock + ) -> None: + """Test offline token fallback when OAuth is disabled.""" + # No Authorization header + self.mock_headers.get.return_value = None + + # Mock offline token retrieval + with patch("assisted_service_mcp.utils.auth.get_offline_token") as mock_offline: + mock_offline.return_value = "test_offline_token" + + # Mock SSO token exchange + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "exchanged_access_token"} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + with patch( + "assisted_service_mcp.src.settings.get_setting" + ) as mock_get_setting: + mock_get_setting.return_value = "https://sso.example.com/token" + + token = get_access_token(self.mock_mcp) + + assert token == "exchanged_access_token" + mock_offline.assert_called_once_with(self.mock_mcp) diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..18c8bbe --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,260 @@ +"""Tests for OAuth authentication functionality.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import HTTPException + +from assisted_service_mcp.src.oauth import ( + OAuthManager, + oauth_callback_handler, + oauth_register_handler, + oauth_token_handler, + get_oauth_access_token_from_mcp, +) +from assisted_service_mcp.src.settings import settings + + +class TestOAuthManager: + """Test cases for OAuthManager class.""" + + oauth_manager: OAuthManager + + def setup_method(self) -> None: + """Set up test fixtures.""" + self.oauth_manager = OAuthManager() + + @patch("httpx.AsyncClient") + async def test_exchange_code_for_token_success( + self, mock_client_class: MagicMock + ) -> None: + """Test successful code exchange for token.""" + # Setup + state = "test_state" + code = "test_code" + + # Generate state first to store PKCE verifier + self.oauth_manager.get_authorization_url(state) + + # Mock successful response + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "test_refresh_token", + } + mock_response.raise_for_status.return_value = None + + # Mock the async client + mock_client = MagicMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + # Execute + token = await self.oauth_manager.exchange_code_for_token(code, state) + + # Verify + assert token is not None + assert token["access_token"] == "test_access_token" + assert token["token_type"] == "Bearer" + mock_client.post.assert_called_once() + + @patch("httpx.AsyncClient") + async def test_exchange_code_for_token_request_failure( + self, mock_client_class: MagicMock + ) -> None: + """Test code exchange with request failure.""" + # Setup + state = "test_state" + code = "test_code" + + # Generate state first + self.oauth_manager.get_authorization_url(state) + + # Mock request failure + mock_client = MagicMock() + mock_client.post = AsyncMock(side_effect=httpx.HTTPError("Network error")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + # Execute and verify + with pytest.raises(HTTPException) as exc_info: + await self.oauth_manager.exchange_code_for_token(code, state) + + assert exc_info.value.status_code == 400 + + +class TestOAuthHandlers: + """Test cases for OAuth HTTP handlers.""" + + async def test_oauth_register_handler_success(self) -> None: + """Test successful OAuth registration.""" + mock_request = MagicMock() + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + response = await oauth_register_handler(mock_request) + + assert "authorization_endpoint" in response + assert "token_endpoint" in response + assert "client_id" in response + assert "redirect_uri" in response + assert "state" in response + assert response["client_id"] == settings.OAUTH_CLIENT + + async def test_oauth_callback_handler_success(self) -> None: + """Test successful OAuth callback.""" + mock_request = MagicMock() + # Create a proper mock query_params object + mock_query_params = MagicMock() + mock_query_params.get.side_effect = lambda key, default=None: { + "code": "test_code", + "state": "test_state", + }.get(key, default) + mock_request.query_params = mock_query_params + + with ( + patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True), + patch( + "assisted_service_mcp.src.oauth.oauth_manager.exchange_code_for_token" + ) as mock_exchange, + ): + mock_exchange.return_value = {"access_token": "test_token"} + + response = await oauth_callback_handler(mock_request) + + assert response.status_code == 200 + assert "Authentication Successful!" in bytes(response.body).decode() + + async def test_oauth_callback_handler_error(self) -> None: + """Test OAuth callback with error.""" + mock_request = MagicMock() + # Create a proper mock query_params object + mock_query_params = MagicMock() + mock_query_params.get.side_effect = lambda key, default=None: { + "error": "access_denied" + }.get(key, default) + mock_request.query_params = mock_query_params + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + response = await oauth_callback_handler(mock_request) + + assert response.status_code == 400 + assert "OAuth Authentication Failed" in bytes(response.body).decode() + + async def test_oauth_token_handler_success(self) -> None: + """Test successful OAuth token exchange.""" + mock_request = MagicMock() + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock( + return_value={ + "grant_type": "authorization_code", + "code": "test_code", + "state": "test_state", + } + ) + + with ( + patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True), + patch( + "assisted_service_mcp.src.oauth.oauth_manager.exchange_code_for_token" + ) as mock_exchange, + ): + mock_exchange.return_value = { + "access_token": "test_token", + "token_type": "Bearer", + "expires_in": 3600, + } + + response = await oauth_token_handler(mock_request) + + assert response["access_token"] == "test_token" + assert response["token_type"] == "Bearer" + + +class TestOAuthIntegration: + """Test cases for OAuth integration with MCP.""" + + def test_get_oauth_access_token_from_mcp_with_token_id(self) -> None: + """Test getting OAuth token from MCP with token ID header.""" + from assisted_service_mcp.src.oauth.models import OAuthToken + import time + + mock_mcp = MagicMock() + mock_context = MagicMock() + mock_request = MagicMock() + + # Create a proper mock headers object + mock_headers = MagicMock() + mock_headers.get.side_effect = lambda key, default=None: { + "X-OAuth-Token-ID": "test_token_id", + "Authorization": None, + }.get(key, default) + mock_request.headers = mock_headers + mock_context.request_context.request = mock_request + mock_mcp.get_context.return_value = mock_context + + # Create a mock token + mock_token = OAuthToken( + token_id="test_token_id", + client_id="test_client", + access_token="stored_access_token", + refresh_token=None, + expires_at=time.time() + 3600, # Not expired + ) + + with ( + patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True), + patch( + "assisted_service_mcp.src.oauth.oauth_manager.token_store.get_token_by_id" + ) as mock_get_token, + ): + mock_get_token.return_value = mock_token + + token = get_oauth_access_token_from_mcp(mock_mcp) + + assert token == "stored_access_token" + mock_get_token.assert_called_once_with("test_token_id") + + def test_get_oauth_access_token_from_mcp_with_bearer_token(self) -> None: + """Test getting OAuth token from MCP with Bearer token.""" + mock_mcp = MagicMock() + mock_context = MagicMock() + mock_request = MagicMock() + + # Create a proper mock headers object + mock_headers = MagicMock() + mock_headers.get.side_effect = lambda key, default=None: { + "X-OAuth-Token-ID": None, + "Authorization": "Bearer test_bearer_token", + }.get(key, default) + mock_request.headers = mock_headers + mock_context.request_context.request = mock_request + mock_mcp.get_context.return_value = mock_context + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + token = get_oauth_access_token_from_mcp(mock_mcp) + + assert token == "test_bearer_token" + + @patch("assisted_service_mcp.src.settings.settings.OAUTH_ENABLED", False) + def test_get_oauth_access_token_from_mcp_disabled(self) -> None: + """Test getting OAuth token when OAuth is disabled.""" + mock_mcp = MagicMock() + + token = get_oauth_access_token_from_mcp(mock_mcp) + + assert token is None + + def test_get_oauth_access_token_from_mcp_no_context(self) -> None: + """Test getting OAuth token with no MCP context.""" + mock_mcp = MagicMock() + mock_mcp.get_context.return_value = None + + token = get_oauth_access_token_from_mcp(mock_mcp) + + assert token is None diff --git a/tests/test_oauth_integration.py b/tests/test_oauth_integration.py new file mode 100644 index 0000000..ce5f63f --- /dev/null +++ b/tests/test_oauth_integration.py @@ -0,0 +1,120 @@ +"""Integration tests for OAuth functionality with FastAPI.""" + +# pylint: disable=redefined-outer-name + +from typing import Generator +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + + +@pytest.fixture +def oauth_enabled_app() -> Generator[Starlette, None, None]: + """Create a test app with OAuth enabled.""" + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + # Create a fresh app instance and manually register OAuth routes + app = Starlette() + + # Import OAuth handlers + from assisted_service_mcp.src.oauth import ( + oauth_register_handler, + oauth_callback_handler, + oauth_token_handler, + ) + + # Wrap handlers to return proper Response objects + async def wrapped_register_handler(request: Request) -> Response: + result = await oauth_register_handler(request) + if isinstance(result, dict): + return JSONResponse(result) + return result + + async def wrapped_callback_handler(request: Request) -> Response: + result = await oauth_callback_handler(request) + if isinstance(result, dict): + return JSONResponse(result) + return result + + async def wrapped_token_handler(request: Request) -> Response: + result = await oauth_token_handler(request) + if isinstance(result, dict): + return JSONResponse(result) + return result + + # Register OAuth routes + app.add_route("/oauth/register", wrapped_register_handler, methods=["GET"]) + app.add_route("/oauth/callback", wrapped_callback_handler, methods=["GET"]) + app.add_route("/oauth/token", wrapped_token_handler, methods=["POST"]) + + yield app + + +@pytest.fixture +def oauth_disabled_app() -> Generator[Starlette, None, None]: + """Create a test app with OAuth disabled.""" + with patch("assisted_service_mcp.src.settings.settings.OAUTH_ENABLED", False): + # Create a fresh app instance without OAuth routes + app = Starlette() + yield app + + +class TestOAuthEndpointsIntegration: + """Integration tests for OAuth endpoints.""" + + def test_oauth_register_endpoint_enabled( + self, oauth_enabled_app: Starlette + ) -> None: + """Test OAuth register endpoint when OAuth is enabled.""" + app = oauth_enabled_app + client = TestClient(app) + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + response = client.get("/oauth/register") + + assert response.status_code == 200 + data = response.json() + assert "authorization_endpoint" in data + assert "token_endpoint" in data + assert "client_id" in data + + def test_oauth_callback_endpoint_enabled( + self, oauth_enabled_app: Starlette + ) -> None: + """Test OAuth callback endpoint when OAuth is enabled.""" + app = oauth_enabled_app + client = TestClient(app) + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + # Test with error parameter + response = client.get("/oauth/callback?error=access_denied") + + assert response.status_code == 400 + assert "OAuth Authentication Failed" in response.text + + def test_oauth_token_endpoint_enabled(self, oauth_enabled_app: Starlette) -> None: + """Test OAuth token endpoint when OAuth is enabled.""" + app = oauth_enabled_app + client = TestClient(app) + + with patch("assisted_service_mcp.src.oauth.settings.OAUTH_ENABLED", True): + # Test with invalid grant type + response = client.post( + "/oauth/token", + json={ + "grant_type": "client_credentials", + "code": "test_code", + "state": "test_state", + }, + ) + + assert response.status_code == 400 + # The response might be plain text or JSON + if response.headers.get("content-type", "").startswith("application/json"): + data = response.json() + assert "Unsupported grant type" in data["detail"] + else: + assert "Unsupported grant type" in response.text diff --git a/tests/utils/test_auth.py b/tests/utils/test_auth.py index cc25b9b..62e2f43 100644 --- a/tests/utils/test_auth.py +++ b/tests/utils/test_auth.py @@ -104,12 +104,16 @@ def test_get_access_token_sso_request_exception() -> None: "assisted_service_mcp.src.settings.settings.SSO_URL", "https://sso.example.com", ), + patch( + "assisted_service_mcp.src.settings.settings.OFFLINE_TOKEN", + "offline-token", + ), ): mock_post.side_effect = requests.exceptions.RequestException("network error") with pytest.raises( RuntimeError, match="Failed to obtain access token from SSO" ): - mod.get_access_token(mcp, offline_token_func=lambda: "offline") + mod.get_access_token(mcp) def test_get_access_token_invalid_json_response() -> None: @@ -127,6 +131,10 @@ def test_get_access_token_invalid_json_response() -> None: "assisted_service_mcp.src.settings.settings.SSO_URL", "https://sso.example.com", ), + patch( + "assisted_service_mcp.src.settings.settings.OFFLINE_TOKEN", + "offline-token", + ), ): with pytest.raises(RuntimeError, match="Invalid SSO response"): - mod.get_access_token(mcp, offline_token_func=lambda: "offline") + mod.get_access_token(mcp) diff --git a/uv.lock b/uv.lock index c7ddfab..3d27f2a 100644 --- a/uv.lock +++ b/uv.lock @@ -143,7 +143,10 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "assisted-service-client" }, + { name = "authlib" }, + { name = "cryptography" }, { name = "fastapi" }, + { name = "httpx" }, { name = "jinja2" }, { name = "mcp" }, { name = "nestedarchive" }, @@ -156,6 +159,7 @@ dependencies = [ { name = "requests" }, { name = "retry" }, { name = "tabulate" }, + { name = "uvicorn" }, ] [package.dev-dependencies] @@ -183,7 +187,10 @@ test = [ [package.metadata] requires-dist = [ { name = "assisted-service-client", specifier = ">=2.41.0.post3" }, + { name = "authlib", specifier = ">=1.3.0" }, + { name = "cryptography", specifier = ">=41.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, + { name = "httpx", specifier = ">=0.27.0" }, { name = "jinja2", specifier = ">=3.1" }, { name = "mcp", specifier = ">=1.15.0" }, { name = "nestedarchive", specifier = ">=0.2.4" }, @@ -196,6 +203,7 @@ requires-dist = [ { name = "requests", specifier = ">=2.32.3" }, { name = "retry", specifier = ">=0.9.2" }, { name = "tabulate", specifier = ">=0.9.0" }, + { name = "uvicorn", specifier = ">=0.30.0" }, ] [package.metadata.requires-dev] @@ -236,6 +244,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "authlib" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/3f/1d3bbd0bf23bdd99276d4def22f29c27a914067b4cf66f753ff9b8bbd0f3/authlib-1.6.5.tar.gz", hash = "sha256:6aaf9c79b7cc96c900f0b284061691c5d4e61221640a948fe690b556a6d6d10b", size = 164553, upload-time = "2025-10-02T13:36:09.489Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" }, +] + [[package]] name = "black" version = "25.11.0" From 561d0f44abd62c904ec63734d9b02ebe86bb0c37 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Sun, 16 Nov 2025 21:50:38 +0200 Subject: [PATCH 3/3] Thread-safe TokenStore - Add threading.RLock() to TokenStore for concurrent access protection - Protect all dict operations (store, get, update, remove, cleanup) with lock - Add _remove_token_unsafe() internal method for re-entrant calls - Prevents race conditions between FastAPI handlers and MCP tool threads - RLock allows nested locking (e.g., get_token_by_client -> get_token_by_id) --- assisted_service_mcp/src/oauth/store.py | 116 +++++++++++++++--------- 1 file changed, 71 insertions(+), 45 deletions(-) diff --git a/assisted_service_mcp/src/oauth/store.py b/assisted_service_mcp/src/oauth/store.py index e2d2ceb..7e8bf41 100644 --- a/assisted_service_mcp/src/oauth/store.py +++ b/assisted_service_mcp/src/oauth/store.py @@ -4,6 +4,7 @@ oauth.py and mcp_oauth_middleware.py. """ +import threading import time from typing import Dict, Optional @@ -19,10 +20,14 @@ class TokenStore: - mcp_oauth_middleware.completed_tokens (client_id -> token_id) Into a single, well-defined storage mechanism. + + Thread-safe: All operations are protected by a re-entrant lock to handle + concurrent access from FastAPI async handlers and MCP tool worker threads. """ def __init__(self) -> None: """Initialize token store.""" + self._lock = threading.RLock() # Re-entrant lock for thread safety self._tokens: Dict[str, OAuthToken] = {} self._client_tokens: Dict[str, str] = {} # client_id -> token_id @@ -32,14 +37,15 @@ def store_token(self, token: OAuthToken) -> None: Args: token: OAuthToken to store """ - self._tokens[token.token_id] = token - self._client_tokens[token.client_id] = token.token_id - log.debug( - "Stored token %s for client %s (expires at %s)", - token.token_id, - token.client_id, - token.expires_at, - ) + with self._lock: + self._tokens[token.token_id] = token + self._client_tokens[token.client_id] = token.token_id + log.debug( + "Stored token %s for client %s (expires at %s)", + token.token_id, + token.client_id, + token.expires_at, + ) def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]: """Get token by token ID. @@ -50,16 +56,17 @@ def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]: Returns: OAuthToken if found and valid, None otherwise """ - token = self._tokens.get(token_id) - if not token: - return None + with self._lock: + token = self._tokens.get(token_id) + if not token: + return None - if token.is_expired(): - log.debug("Token %s is expired, removing from store", token_id) - self.remove_token(token_id) - return None + if token.is_expired(): + log.debug("Token %s is expired, removing from store", token_id) + self._remove_token_unsafe(token_id) + return None - return token + return token def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]: """Get token for a client. @@ -70,11 +77,13 @@ def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]: Returns: OAuthToken if found and valid, None otherwise """ - token_id = self._client_tokens.get(client_id) - if not token_id: - return None + with self._lock: + token_id = self._client_tokens.get(client_id) + if not token_id: + return None - return self.get_token_by_id(token_id) + # Re-use get_token_by_id which also acquires lock (RLock allows re-entrance) + return self.get_token_by_id(token_id) def get_access_token_by_id(self, token_id: str) -> Optional[str]: """Get access token string by token ID. @@ -118,17 +127,18 @@ def update_token( Returns: True if token was updated, False if token not found """ - token = self._tokens.get(token_id) - if not token: - return False + with self._lock: + token = self._tokens.get(token_id) + if not token: + return False - token.access_token = access_token - if refresh_token: - token.refresh_token = refresh_token - token.expires_at = expires_at + token.access_token = access_token + if refresh_token: + token.refresh_token = refresh_token + token.expires_at = expires_at - log.debug("Updated token %s (new expiry: %s)", token_id, expires_at) - return True + log.debug("Updated token %s (new expiry: %s)", token_id, expires_at) + return True def remove_token(self, token_id: str) -> None: """Remove a token and its associations. @@ -136,6 +146,18 @@ def remove_token(self, token_id: str) -> None: Args: token_id: Token identifier to remove """ + with self._lock: + self._remove_token_unsafe(token_id) + + def _remove_token_unsafe(self, token_id: str) -> None: + """Remove a token without acquiring lock (internal use only). + + Args: + token_id: Token identifier to remove + + Note: + Caller must hold self._lock before calling this method. + """ token = self._tokens.pop(token_id, None) if token: self._client_tokens.pop(token.client_id, None) @@ -147,9 +169,10 @@ def remove_client_token(self, client_id: str) -> None: Args: client_id: Client identifier """ - token_id = self._client_tokens.get(client_id) - if token_id: - self.remove_token(token_id) + with self._lock: + token_id = self._client_tokens.get(client_id) + if token_id: + self._remove_token_unsafe(token_id) def cleanup_expired_tokens(self) -> int: """Clean up expired tokens. @@ -157,20 +180,21 @@ def cleanup_expired_tokens(self) -> int: Returns: Number of tokens removed """ - current_time = time.time() - expired_token_ids = [ - token_id - for token_id, token in self._tokens.items() - if current_time >= token.expires_at - ] + with self._lock: + current_time = time.time() + expired_token_ids = [ + token_id + for token_id, token in self._tokens.items() + if current_time >= token.expires_at + ] - for token_id in expired_token_ids: - self.remove_token(token_id) + for token_id in expired_token_ids: + self._remove_token_unsafe(token_id) - if expired_token_ids: - log.info("Cleaned up %d expired tokens", len(expired_token_ids)) + if expired_token_ids: + log.info("Cleaned up %d expired tokens", len(expired_token_ids)) - return len(expired_token_ids) + return len(expired_token_ids) def get_all_tokens(self) -> Dict[str, OAuthToken]: """Get all stored tokens (for debugging/monitoring). @@ -178,7 +202,8 @@ def get_all_tokens(self) -> Dict[str, OAuthToken]: Returns: Dictionary of token_id -> OAuthToken """ - return self._tokens.copy() + with self._lock: + return self._tokens.copy() def get_token_count(self) -> int: """Get number of stored tokens. @@ -186,4 +211,5 @@ def get_token_count(self) -> int: Returns: Number of tokens in store """ - return len(self._tokens) + with self._lock: + return len(self._tokens)