From 59075e2ced963b5273d9e9242b35a3d205e58575 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sat, 6 Sep 2025 16:53:47 +0530 Subject: [PATCH 1/7] initial implementation for oauth + apikey providers --- src/mcpadapt/auth/__init__.py | 40 ++++++++ src/mcpadapt/auth/authenticate.py | 43 ++++++++ src/mcpadapt/auth/handlers.py | 156 ++++++++++++++++++++++++++++++ src/mcpadapt/auth/models.py | 105 ++++++++++++++++++++ src/mcpadapt/auth/oauth.py | 45 +++++++++ src/mcpadapt/auth/providers.py | 106 ++++++++++++++++++++ src/mcpadapt/core.py | 72 +++++++++++++- 7 files changed, 563 insertions(+), 4 deletions(-) create mode 100644 src/mcpadapt/auth/__init__.py create mode 100644 src/mcpadapt/auth/authenticate.py create mode 100644 src/mcpadapt/auth/handlers.py create mode 100644 src/mcpadapt/auth/models.py create mode 100644 src/mcpadapt/auth/oauth.py create mode 100644 src/mcpadapt/auth/providers.py diff --git a/src/mcpadapt/auth/__init__.py b/src/mcpadapt/auth/__init__.py new file mode 100644 index 0000000..e1fbcdf --- /dev/null +++ b/src/mcpadapt/auth/__init__.py @@ -0,0 +1,40 @@ +"""Authentication module for MCPAdapt.""" + +from .handlers import default_callback_handler, default_redirect_handler +from .oauth import InMemoryTokenStorage +from .providers import ( + ApiKeyAuthProvider, + BearerAuthProvider, + create_auth_provider, + get_auth_headers, +) +from .models import ( + ApiKeyConfig, + AuthConfig, + AuthConfigBase, + BearerAuthConfig, + CallbackHandler, + OAuthConfig, + RedirectHandler, +) + +__all__ = [ + # Types + "AuthConfig", + "AuthConfigBase", + "OAuthConfig", + "ApiKeyConfig", + "BearerAuthConfig", + "CallbackHandler", + "RedirectHandler", + # OAuth utilities + "InMemoryTokenStorage", + # Handlers + "default_callback_handler", + "default_redirect_handler", + # Providers + "ApiKeyAuthProvider", + "BearerAuthProvider", + "create_auth_provider", + "get_auth_headers", +] diff --git a/src/mcpadapt/auth/authenticate.py b/src/mcpadapt/auth/authenticate.py new file mode 100644 index 0000000..519694a --- /dev/null +++ b/src/mcpadapt/auth/authenticate.py @@ -0,0 +1,43 @@ +"""Authentication utilities for pre-authenticating providers.""" + +from typing import Any + +from mcp.client.auth import OAuthClientProvider + +from .providers import ApiKeyAuthProvider, BearerAuthProvider, create_auth_provider +from .models import AuthConfig, OAuthConfig + + +async def authenticate( + auth_config: AuthConfig, + server_url: str +) -> OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider: + """ + Create and prepare an auth provider for use with MCPAdapt. + + For OAuth: Creates a configured OAuth provider that will perform the OAuth flow + (browser redirect, callback, token exchange) when first used by MCPAdapt + For API Key/Bearer: Creates a ready-to-use provider (no additional flow needed) + + Args: + auth_config: Authentication configuration + server_url: Server URL (needed for OAuth server endpoint discovery) + + Returns: + Auth provider ready to use with MCPAdapt + + Example: + >>> # Prepare OAuth provider + >>> oauth_config = OAuthConfig(client_metadata={...}) + >>> auth_provider = await authenticate(oauth_config, "https://mcp.canva.com/mcp") + >>> + >>> # Use with MCPAdapt - OAuth flow will happen during connection + >>> with MCPAdapt(server_config, adapter, auth_provider=auth_provider) as tools: + >>> print(tools) + + Note: + For OAuth, the actual authentication flow (browser redirect, token exchange) + occurs when MCPAdapt makes its first connection to the MCP server. This function + prepares the OAuth provider with all necessary configuration. + """ + return await create_auth_provider(auth_config, server_url) diff --git a/src/mcpadapt/auth/handlers.py b/src/mcpadapt/auth/handlers.py new file mode 100644 index 0000000..4286b19 --- /dev/null +++ b/src/mcpadapt/auth/handlers.py @@ -0,0 +1,156 @@ +"""Default OAuth callback and redirect handlers.""" + +import threading +import time +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any +from urllib.parse import parse_qs, urlparse + + +class CallbackHandler(BaseHTTPRequestHandler): + """Simple HTTP handler to capture OAuth callback.""" + + def __init__(self, request, client_address, server, callback_data): + """Initialize with callback data storage.""" + self.callback_data = callback_data + super().__init__(request, client_address, server) + + def do_GET(self): + """Handle GET request from OAuth redirect.""" + parsed = urlparse(self.path) + query_params = parse_qs(parsed.query) + + if "code" in query_params: + self.callback_data["authorization_code"] = query_params["code"][0] + self.callback_data["state"] = query_params.get("state", [None])[0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b""" + + +

Authorization Successful!

+

You can close this window and return to the terminal.

+ + + + """) + elif "error" in query_params: + self.callback_data["error"] = query_params["error"][0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + f""" + + +

Authorization Failed

+

Error: {query_params["error"][0]}

+

You can close this window and return to the terminal.

+ + + """.encode() + ) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + """Suppress default logging.""" + pass + + +class LocalCallbackServer: + """Simple server to handle OAuth callbacks.""" + + def __init__(self, port: int = 3030): + """Initialize callback server. + + Args: + port: Port to listen on for OAuth callbacks + """ + self.port = port + self.server = None + self.thread = None + self.callback_data = {"authorization_code": None, "state": None, "error": None} + + def _create_handler_with_data(self): + """Create a handler class with access to callback data.""" + callback_data = self.callback_data + + class DataCallbackHandler(CallbackHandler): + def __init__(self, request, client_address, server): + super().__init__(request, client_address, server, callback_data) + + return DataCallbackHandler + + def start(self) -> None: + """Start the callback server in a background thread.""" + handler_class = self._create_handler_with_data() + self.server = HTTPServer(("localhost", self.port), handler_class) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + + def stop(self) -> None: + """Stop the callback server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.thread: + self.thread.join(timeout=1) + + def wait_for_callback(self, timeout: int = 300) -> str: + """Wait for OAuth callback with timeout. + + Args: + timeout: Maximum time to wait for callback in seconds + + Returns: + Authorization code from OAuth callback + + Raises: + Exception: If timeout occurs or OAuth error is received + """ + start_time = time.time() + while time.time() - start_time < timeout: + if self.callback_data["authorization_code"]: + return self.callback_data["authorization_code"] + elif self.callback_data["error"]: + raise Exception(f"OAuth error: {self.callback_data['error']}") + time.sleep(0.1) + raise Exception("Timeout waiting for OAuth callback") + + def get_state(self) -> str | None: + """Get the received state parameter. + + Returns: + OAuth state parameter or None + """ + return self.callback_data["state"] + + +async def default_redirect_handler(authorization_url: str) -> None: + """Default redirect handler that opens the URL in a browser. + + Args: + authorization_url: OAuth authorization URL to open + """ + webbrowser.open(authorization_url) + + +async def default_callback_handler() -> tuple[str, str | None]: + """Default callback handler that starts local server and waits for OAuth callback. + + Returns: + Tuple of (authorization_code, state) + """ + callback_server = LocalCallbackServer(port=3030) + callback_server.start() + + try: + auth_code = callback_server.wait_for_callback(timeout=300) + state = callback_server.get_state() + return auth_code, state + finally: + callback_server.stop() diff --git a/src/mcpadapt/auth/models.py b/src/mcpadapt/auth/models.py new file mode 100644 index 0000000..dd5b7ce --- /dev/null +++ b/src/mcpadapt/auth/models.py @@ -0,0 +1,105 @@ +"""Authentication configuration types and protocols for MCPAdapt.""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Coroutine, Protocol, Union + +from mcp.shared.auth import OAuthClientMetadata + + +class AuthConfigBase(ABC): + """Base class for all authentication configurations.""" + + @abstractmethod + def get_auth_type(self) -> str: + """Return the authentication type identifier.""" + pass + + +class OAuthConfig(AuthConfigBase): + """OAuth authentication configuration.""" + + def __init__( + self, + client_metadata: OAuthClientMetadata | dict[str, Any], + callback_handler: "CallbackHandler | None" = None, + redirect_handler: "RedirectHandler | None" = None, + ): + """Initialize OAuth configuration. + + Args: + client_metadata: OAuth client metadata or dict representation + callback_handler: Optional custom callback handler + redirect_handler: Optional custom redirect handler + """ + if isinstance(client_metadata, dict): + self.client_metadata = OAuthClientMetadata.model_validate(client_metadata) + else: + self.client_metadata = client_metadata + self.callback_handler = callback_handler + self.redirect_handler = redirect_handler + + def get_auth_type(self) -> str: + """Return OAuth auth type.""" + return "oauth" + + +class ApiKeyConfig(AuthConfigBase): + """API Key authentication configuration.""" + + def __init__(self, header_name: str, header_value: str): + """Initialize API key configuration. + + Args: + header_name: Name of the header to send the API key in + header_value: The API key value + """ + self.header_name = header_name + self.header_value = header_value + + def get_auth_type(self) -> str: + """Return API key auth type.""" + return "api_key" + + +class BearerAuthConfig(AuthConfigBase): + """Bearer token authentication configuration.""" + + def __init__(self, token: str): + """Initialize bearer auth configuration. + + Args: + token: The bearer token + """ + self.token = token + + def get_auth_type(self) -> str: + """Return bearer auth type.""" + return "bearer" + + +# Union type for all auth configurations +AuthConfig = Union[OAuthConfig, ApiKeyConfig, BearerAuthConfig] + + +class CallbackHandler(Protocol): + """Protocol for OAuth callback handlers.""" + + async def __call__(self) -> tuple[str, str | None]: + """Handle OAuth callback and return authorization code and state. + + Returns: + Tuple of (authorization_code, state) + """ + ... + + +class RedirectHandler(Protocol): + """Protocol for OAuth redirect handlers.""" + + async def __call__(self, authorization_url: str) -> None: + """Handle OAuth redirect by opening authorization URL. + + Args: + authorization_url: The OAuth authorization URL to redirect to + """ + ... diff --git a/src/mcpadapt/auth/oauth.py b/src/mcpadapt/auth/oauth.py new file mode 100644 index 0000000..6195b8b --- /dev/null +++ b/src/mcpadapt/auth/oauth.py @@ -0,0 +1,45 @@ +"""OAuth token storage and utility implementations.""" + +from mcp.client.auth import TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + + +class InMemoryTokenStorage(TokenStorage): + """Simple in-memory token storage implementation.""" + + def __init__(self): + """Initialize empty token storage.""" + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + """Get stored OAuth tokens. + + Returns: + Stored OAuth tokens or None if not available + """ + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store OAuth tokens. + + Args: + tokens: OAuth tokens to store + """ + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored OAuth client information. + + Returns: + Stored OAuth client information or None if not available + """ + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store OAuth client information. + + Args: + client_info: OAuth client information to store + """ + self._client_info = client_info diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py new file mode 100644 index 0000000..12a7700 --- /dev/null +++ b/src/mcpadapt/auth/providers.py @@ -0,0 +1,106 @@ +"""Authentication provider factory and integration logic.""" + +from typing import Any +from urllib.parse import urlparse + +from mcp.client.auth import OAuthClientProvider + +from .handlers import default_callback_handler, default_redirect_handler +from .oauth import InMemoryTokenStorage +from .models import ApiKeyConfig, AuthConfig, BearerAuthConfig, OAuthConfig + + +class ApiKeyAuthProvider: + """Simple API key authentication provider.""" + + def __init__(self, config: ApiKeyConfig): + """Initialize with API key configuration. + + Args: + config: API key configuration + """ + self.config = config + + def get_headers(self) -> dict[str, str]: + """Get authentication headers. + + Returns: + Dictionary of headers to add to requests + """ + return {self.config.header_name: self.config.header_value} + + +class BearerAuthProvider: + """Simple Bearer token authentication provider.""" + + def __init__(self, config: BearerAuthConfig): + """Initialize with Bearer token configuration. + + Args: + config: Bearer token configuration + """ + self.config = config + + def get_headers(self) -> dict[str, str]: + """Get authentication headers. + + Returns: + Dictionary of headers to add to requests + """ + return {"Authorization": f"Bearer {self.config.token}"} + + +async def create_auth_provider( + auth_config: AuthConfig, server_url: str +) -> OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider: + """Factory function to create appropriate auth provider from config. + + Args: + auth_config: Authentication configuration + server_url: Server URL for OAuth (needed to determine OAuth server endpoint) + + Returns: + Appropriate auth provider instance + + Raises: + ValueError: If auth configuration type is not supported + """ + if isinstance(auth_config, OAuthConfig): + # Use provided handlers or default ones + callback_handler = auth_config.callback_handler or default_callback_handler + redirect_handler = auth_config.redirect_handler or default_redirect_handler + + # Create OAuth provider with domain root only + parsed_url = urlparse(server_url) + oauth_server_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + return OAuthClientProvider( + server_url=oauth_server_url, + client_metadata=auth_config.client_metadata, + storage=InMemoryTokenStorage(), + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + elif isinstance(auth_config, ApiKeyConfig): + return ApiKeyAuthProvider(auth_config) + + elif isinstance(auth_config, BearerAuthConfig): + return BearerAuthProvider(auth_config) + + else: + raise ValueError(f"Unsupported auth configuration type: {type(auth_config)}") + + +def get_auth_headers(auth_provider: Any) -> dict[str, str]: + """Get authentication headers from provider. + + Args: + auth_provider: Authentication provider instance + + Returns: + Dictionary of headers to add to requests + """ + if isinstance(auth_provider, (ApiKeyAuthProvider, BearerAuthProvider)): + return auth_provider.get_headers() + return {} diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index cdf2d76..f56b17b 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -15,10 +15,14 @@ import mcp from mcp import ClientSession, StdioServerParameters +from mcp.client.auth import OAuthClientProvider from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client +from .auth import AuthConfig, create_auth_provider, get_auth_headers +from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider + class ToolAdapter(ABC): """A basic interface for adapting tools from MCP to the desired Agent framework.""" @@ -75,6 +79,7 @@ def async_adapt( async def mcptools( serverparams: StdioServerParameters | dict[str, Any], client_session_timeout_seconds: float | timedelta | None = 5, + auth_provider: Any = None, ) -> AsyncGenerator[tuple[ClientSession, list[mcp.types.Tool]], None]: """Async context manager that yields tools from an MCP server. @@ -86,6 +91,7 @@ async def mcptools( * if StdioServerParameters, run the MCP server using the stdio protocol. * if dict, assume the dict corresponds to parameters to an sse MCP server. client_session_timeout_seconds: Timeout for MCP ClientSession calls + auth_provider: Optional authentication provider for securing connections Yields: A tuple of (MCP Client Session, list of MCP tools) available on the MCP server. @@ -100,6 +106,19 @@ async def mcptools( # Create a deep copy to avoid modifying the original dict client_params = copy.deepcopy(serverparams) transport = client_params.pop("transport", "sse") + + # Add authentication if provided + if auth_provider is not None: + if isinstance(auth_provider, OAuthClientProvider): + client_params["auth"] = auth_provider + elif isinstance(auth_provider, (ApiKeyAuthProvider, BearerAuthProvider)): + # Add custom headers for API Key and Bearer auth + headers = get_auth_headers(auth_provider) + if "headers" in client_params: + client_params["headers"].update(headers) + else: + client_params["headers"] = headers + if transport == "sse": client = sse_client(**client_params) elif transport == "streamable-http": @@ -180,6 +199,7 @@ def __init__( adapter: ToolAdapter, connect_timeout: int = 30, client_session_timeout_seconds: float | timedelta | None = 5, + auth_config: AuthConfig | list[AuthConfig | None] | None = None, ): """ Manage the MCP server / client lifecycle and expose tools adapted with the adapter. @@ -190,6 +210,8 @@ def __init__( adapter (ToolAdapter): Adapter to use to convert MCP tools call into agentic framework tools. connect_timeout (int): Connection timeout in seconds to the mcp server (default is 30s). client_session_timeout_seconds: Timeout for MCP ClientSession calls + auth_config: Optional authentication configuration. Can be a single config for all servers, + a list matching serverparams length, or None for no authentication. Raises: TimeoutError: When the connection to the mcp server time out. @@ -202,9 +224,23 @@ def __init__( self.adapter = adapter + # Handle auth_config - ensure it matches serverparams length + if auth_config is None: + self.auth_configs = [None] * len(self.serverparams) + elif isinstance(auth_config, list): + if len(auth_config) != len(self.serverparams): + raise ValueError( + f"auth_config list length ({len(auth_config)}) must match serverparams length ({len(self.serverparams)})" + ) + self.auth_configs = auth_config + else: + # Single auth config for all servers + self.auth_configs = [auth_config] * len(self.serverparams) + # session and tools get set by the async loop during initialization. self.sessions: list[ClientSession] = [] self.mcp_tools: list[list[mcp.types.Tool]] = [] + self.auth_providers: list[Any] = [] # all attributes used to manage the async loop and separate thread. self.loop = asyncio.new_event_loop() @@ -221,12 +257,26 @@ def _run_loop(self): asyncio.set_event_loop(self.loop) async def setup(): + # Create auth providers if needed (only if not already provided) + if not self.auth_providers: + auth_providers = [] + for params, auth_config in zip(self.serverparams, self.auth_configs): + if auth_config is not None: + # Get server URL from params for OAuth + server_url = params.get("url", "") if isinstance(params, dict) else "" + auth_provider = await create_auth_provider(auth_config, server_url) + auth_providers.append(auth_provider) + else: + auth_providers.append(None) + + self.auth_providers = auth_providers + async with AsyncExitStack() as stack: connections = [ await stack.enter_async_context( - mcptools(params, self.client_session_timeout_seconds) + mcptools(params, self.client_session_timeout_seconds, auth_provider) ) - for params in self.serverparams + for params, auth_provider in zip(self.serverparams, self.auth_providers) ] self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] self.ready.set() # Signal initialization is complete @@ -330,11 +380,25 @@ async def atools(self) -> list[Any]: async def __aenter__(self) -> list[Any]: self._ctxmanager = AsyncExitStack() + # Create auth providers if needed (only if not already provided) + if not self.auth_providers: + auth_providers = [] + for params, auth_config in zip(self.serverparams, self.auth_configs): + if auth_config is not None: + # Get server URL from params for OAuth + server_url = params.get("url", "") if isinstance(params, dict) else "" + auth_provider = await create_auth_provider(auth_config, server_url) + auth_providers.append(auth_provider) + else: + auth_providers.append(None) + + self.auth_providers = auth_providers + connections = [ await self._ctxmanager.enter_async_context( - mcptools(params, self.client_session_timeout_seconds) + mcptools(params, self.client_session_timeout_seconds, auth_provider) ) - for params in self.serverparams + for params, auth_provider in zip(self.serverparams, self.auth_providers) ] self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] From e39acbcae594a61de583a82bb4b282a4e7a8a305 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sat, 6 Sep 2025 19:03:13 +0530 Subject: [PATCH 2/7] Nicer error handling --- src/mcpadapt/auth/__init__.py | 140 ++++++++++++++++++----- src/mcpadapt/auth/exceptions.py | 85 ++++++++++++++ src/mcpadapt/auth/handlers.py | 191 +++++++++++++++++++++++++++----- src/mcpadapt/auth/models.py | 116 ++----------------- src/mcpadapt/auth/providers.py | 69 ++---------- src/mcpadapt/core.py | 55 +++------ 6 files changed, 399 insertions(+), 257 deletions(-) create mode 100644 src/mcpadapt/auth/exceptions.py diff --git a/src/mcpadapt/auth/__init__.py b/src/mcpadapt/auth/__init__.py index e1fbcdf..d068f04 100644 --- a/src/mcpadapt/auth/__init__.py +++ b/src/mcpadapt/auth/__init__.py @@ -1,40 +1,128 @@ -"""Authentication module for MCPAdapt.""" +"""Authentication module for MCPAdapt. + +This module provides OAuth, API Key, and Bearer token authentication support +for MCP servers. + +Example usage with OAuth: + +```python +from mcp.client.auth import OAuthClientProvider +from mcp.shared.auth import OAuthClientMetadata +from pydantic import HttpUrl + +from mcpadapt.auth import ( + InMemoryTokenStorage, + LocalBrowserOAuthHandler +) +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create OAuth provider directly +client_metadata = OAuthClientMetadata( + client_name="My App", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +token_storage = InMemoryTokenStorage() + +oauth_provider = OAuthClientProvider( + server_url="https://example.com", + client_metadata=client_metadata, + storage=token_storage, + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, +) + +# Use with MCPAdapt +with MCPAdapt( + serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +Example usage with API Key: + +```python +from mcpadapt.auth import ApiKeyAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create API Key provider +api_key_provider = ApiKeyAuthProvider( + header_name="X-API-Key", + header_value="your-api-key-here" +) + +with MCPAdapt( + serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=api_key_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +For custom implementations, extend BaseOAuthHandler: + +```python +from mcpadapt.auth import BaseOAuthHandler + +class CustomOAuthHandler(BaseOAuthHandler): + async def handle_redirect(self, authorization_url: str) -> None: + # Custom redirect logic (e.g., print URL for headless environments) + print(f"Please open: {authorization_url}") + + async def handle_callback(self) -> tuple[str, str | None]: + # Custom callback logic (e.g., manual code input) + auth_code = input("Enter authorization code: ") + return auth_code, None +``` +""" -from .handlers import default_callback_handler, default_redirect_handler from .oauth import InMemoryTokenStorage +from .handlers import ( + BaseOAuthHandler, + LocalBrowserOAuthHandler, + LocalCallbackServer, +) from .providers import ( ApiKeyAuthProvider, BearerAuthProvider, - create_auth_provider, get_auth_headers, ) -from .models import ( - ApiKeyConfig, - AuthConfig, - AuthConfigBase, - BearerAuthConfig, - CallbackHandler, - OAuthConfig, - RedirectHandler, +from .exceptions import ( + OAuthError, + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthConfigurationError, + OAuthServerError, + OAuthCallbackError, ) __all__ = [ - # Types - "AuthConfig", - "AuthConfigBase", - "OAuthConfig", - "ApiKeyConfig", - "BearerAuthConfig", - "CallbackHandler", - "RedirectHandler", - # OAuth utilities - "InMemoryTokenStorage", - # Handlers - "default_callback_handler", - "default_redirect_handler", - # Providers + # Handler classes + "BaseOAuthHandler", + "LocalBrowserOAuthHandler", + "LocalCallbackServer", + # Provider classes "ApiKeyAuthProvider", "BearerAuthProvider", - "create_auth_provider", + # Default implementations + "InMemoryTokenStorage", + # Provider functions "get_auth_headers", + # Exception classes + "OAuthError", + "OAuthTimeoutError", + "OAuthCancellationError", + "OAuthNetworkError", + "OAuthConfigurationError", + "OAuthServerError", + "OAuthCallbackError", ] diff --git a/src/mcpadapt/auth/exceptions.py b/src/mcpadapt/auth/exceptions.py new file mode 100644 index 0000000..c06fad5 --- /dev/null +++ b/src/mcpadapt/auth/exceptions.py @@ -0,0 +1,85 @@ +"""Custom exceptions for OAuth authentication errors.""" + + +class OAuthError(Exception): + """Base class for all OAuth authentication errors.""" + + def __init__(self, message: str, error_code: str | None = None, context: dict | None = None): + """Initialize OAuth error. + + Args: + message: Human-readable error message + error_code: Machine-readable error code (optional) + context: Additional context about the error (optional) + """ + super().__init__(message) + self.error_code = error_code + self.context = context or {} + + +class OAuthTimeoutError(OAuthError): + """Raised when OAuth callback doesn't arrive within the specified timeout.""" + + def __init__(self, timeout_seconds: int, context: dict | None = None): + message = ( + f"OAuth authentication timed out after {timeout_seconds} seconds. " + f"The user may have closed the browser window or the OAuth server may be unreachable. " + f"Try refreshing the browser or check your network connection." + ) + super().__init__(message, "oauth_timeout", context) + self.timeout_seconds = timeout_seconds + + +class OAuthCancellationError(OAuthError): + """Raised when the user cancels or denies the OAuth authorization.""" + + def __init__(self, error_details: str | None = None, context: dict | None = None): + base_message = "OAuth authorization was cancelled or denied by the user." + if error_details: + message = f"{base_message} Details: {error_details}" + else: + message = base_message + super().__init__(message, "oauth_cancelled", context) + self.error_details = error_details + + +class OAuthNetworkError(OAuthError): + """Raised when network-related issues prevent OAuth completion.""" + + def __init__(self, original_error: Exception, context: dict | None = None): + message = ( + f"OAuth authentication failed due to network error: {str(original_error)}. " + f"Check your internet connection and try again." + ) + super().__init__(message, "oauth_network_error", context) + self.original_error = original_error + + +class OAuthConfigurationError(OAuthError): + """Raised when OAuth configuration is invalid or incomplete.""" + + def __init__(self, config_issue: str, context: dict | None = None): + message = f"OAuth configuration error: {config_issue}" + super().__init__(message, "oauth_config_error", context) + self.config_issue = config_issue + + +class OAuthServerError(OAuthError): + """Raised when the OAuth server returns an error response.""" + + def __init__(self, server_error: str, error_description: str | None = None, context: dict | None = None): + message = f"OAuth server error: {server_error}" + if error_description: + message += f" - {error_description}" + super().__init__(message, "oauth_server_error", context) + self.server_error = server_error + self.error_description = error_description + + +class OAuthCallbackError(OAuthError): + """Raised when there's an issue with the OAuth callback handling.""" + + def __init__(self, callback_issue: str, context: dict | None = None): + message = f"OAuth callback error: {callback_issue}" + super().__init__(message, "oauth_callback_error", context) + self.callback_issue = callback_issue diff --git a/src/mcpadapt/auth/handlers.py b/src/mcpadapt/auth/handlers.py index 4286b19..5de9173 100644 --- a/src/mcpadapt/auth/handlers.py +++ b/src/mcpadapt/auth/handlers.py @@ -1,12 +1,21 @@ -"""Default OAuth callback and redirect handlers.""" +"""OAuth handlers for managing authentication flows.""" import threading import time import webbrowser +from abc import ABC, abstractmethod from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any from urllib.parse import parse_qs, urlparse +from .exceptions import ( + OAuthCallbackError, + OAuthCancellationError, + OAuthNetworkError, + OAuthServerError, + OAuthTimeoutError, +) + class CallbackHandler(BaseHTTPRequestHandler): """Simple HTTP handler to capture OAuth callback.""" @@ -86,11 +95,32 @@ def __init__(self, request, client_address, server): return DataCallbackHandler def start(self) -> None: - """Start the callback server in a background thread.""" - handler_class = self._create_handler_with_data() - self.server = HTTPServer(("localhost", self.port), handler_class) - self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) - self.thread.start() + """Start the callback server in a background thread. + + Raises: + OAuthCallbackError: If server cannot be started + """ + try: + handler_class = self._create_handler_with_data() + self.server = HTTPServer(("localhost", self.port), handler_class) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + except OSError as e: + if e.errno == 48: # Address already in use + raise OAuthCallbackError( + f"Port {self.port} is already in use. Try using a different port or check if another OAuth flow is running.", + context={"port": self.port, "original_error": str(e)} + ) + else: + raise OAuthCallbackError( + f"Failed to start OAuth callback server on port {self.port}: {str(e)}", + context={"port": self.port, "original_error": str(e)} + ) + except Exception as e: + raise OAuthCallbackError( + f"Unexpected error starting OAuth callback server: {str(e)}", + context={"port": self.port, "original_error": str(e)} + ) def stop(self) -> None: """Stop the callback server.""" @@ -110,16 +140,36 @@ def wait_for_callback(self, timeout: int = 300) -> str: Authorization code from OAuth callback Raises: - Exception: If timeout occurs or OAuth error is received + OAuthTimeoutError: If timeout occurs + OAuthCancellationError: If user cancels authorization + OAuthServerError: If OAuth server returns an error """ start_time = time.time() while time.time() - start_time < timeout: if self.callback_data["authorization_code"]: return self.callback_data["authorization_code"] elif self.callback_data["error"]: - raise Exception(f"OAuth error: {self.callback_data['error']}") + error = self.callback_data["error"] + context = {"port": self.port, "timeout": timeout} + + # Map common OAuth error codes to specific exceptions + if error in ["access_denied", "user_cancelled"]: + raise OAuthCancellationError( + error_details=error, + context=context + ) + else: + # Generic server error for other OAuth errors + raise OAuthServerError( + server_error=error, + context=context + ) time.sleep(0.1) - raise Exception("Timeout waiting for OAuth callback") + + raise OAuthTimeoutError( + timeout_seconds=timeout, + context={"port": self.port} + ) def get_state(self) -> str | None: """Get the received state parameter. @@ -130,27 +180,116 @@ def get_state(self) -> str | None: return self.callback_data["state"] -async def default_redirect_handler(authorization_url: str) -> None: - """Default redirect handler that opens the URL in a browser. +class BaseOAuthHandler(ABC): + """Base class for OAuth authentication handlers. - Args: - authorization_url: OAuth authorization URL to open + Combines redirect and callback handling into a single cohesive interface. + Subclasses should implement both the redirect flow (opening authorization URL) + and callback flow (receiving authorization code). """ - webbrowser.open(authorization_url) + + @abstractmethod + async def handle_redirect(self, authorization_url: str) -> None: + """Handle OAuth redirect to authorization URL. + + Args: + authorization_url: The OAuth authorization URL to redirect the user to + """ + pass + + @abstractmethod + async def handle_callback(self) -> tuple[str, str | None]: + """Handle OAuth callback and return authorization code and state. + + Returns: + Tuple of (authorization_code, state) received from OAuth provider + + Raises: + Exception: If OAuth flow fails or times out + """ + pass -async def default_callback_handler() -> tuple[str, str | None]: - """Default callback handler that starts local server and waits for OAuth callback. +class LocalBrowserOAuthHandler(BaseOAuthHandler): + """Default OAuth handler using local browser and callback server. - Returns: - Tuple of (authorization_code, state) + Opens authorization URL in the user's default browser and starts a local + HTTP server to receive the OAuth callback. This is the most user-friendly + approach for desktop applications. """ - callback_server = LocalCallbackServer(port=3030) - callback_server.start() - try: - auth_code = callback_server.wait_for_callback(timeout=300) - state = callback_server.get_state() - return auth_code, state - finally: - callback_server.stop() + def __init__(self, callback_port: int = 3030, timeout: int = 300): + """Initialize the local browser OAuth handler. + + Args: + callback_port: Port to run the local callback server on + timeout: Maximum time to wait for OAuth callback in seconds + """ + self.callback_port = callback_port + self.timeout = timeout + self.callback_server: LocalCallbackServer | None = None + + async def handle_redirect(self, authorization_url: str) -> None: + """Open authorization URL in the user's default browser. + + Args: + authorization_url: OAuth authorization URL to open + + Raises: + OAuthNetworkError: If browser cannot be opened + """ + print(f"Opening OAuth authorization URL: {authorization_url}") + + try: + success = webbrowser.open(authorization_url) + if not success: + print("Failed to automatically open browser. Please manually open the URL above.") + raise OAuthNetworkError( + Exception("Failed to open browser - no suitable browser found"), + context={"authorization_url": authorization_url} + ) + except Exception as e: + if isinstance(e, OAuthNetworkError): + raise + print("Failed to automatically open browser. Please manually open the URL above.") + raise OAuthNetworkError( + e, + context={"authorization_url": authorization_url} + ) + + async def handle_callback(self) -> tuple[str, str | None]: + """Start local server and wait for OAuth callback. + + Returns: + Tuple of (authorization_code, state) from OAuth callback + + Raises: + OAuthCallbackError: If callback server cannot be started + OAuthTimeoutError: If callback doesn't arrive within timeout + OAuthCancellationError: If user cancels authorization + OAuthServerError: If OAuth server returns an error + """ + try: + self.callback_server = LocalCallbackServer(port=self.callback_port) + self.callback_server.start() + + auth_code = self.callback_server.wait_for_callback(timeout=self.timeout) + state = self.callback_server.get_state() + return auth_code, state + + except (OAuthTimeoutError, OAuthCancellationError, OAuthServerError, OAuthCallbackError): + # Re-raise OAuth-specific exceptions as-is + raise + except Exception as e: + # Wrap unexpected errors + raise OAuthCallbackError( + f"Unexpected error during OAuth callback handling: {str(e)}", + context={ + "port": self.callback_port, + "timeout": self.timeout, + "original_error": str(e) + } + ) + finally: + if self.callback_server: + self.callback_server.stop() diff --git a/src/mcpadapt/auth/models.py b/src/mcpadapt/auth/models.py index dd5b7ce..bfb5277 100644 --- a/src/mcpadapt/auth/models.py +++ b/src/mcpadapt/auth/models.py @@ -1,105 +1,11 @@ -"""Authentication configuration types and protocols for MCPAdapt.""" - -from abc import ABC, abstractmethod -from typing import Any, Callable, Coroutine, Protocol, Union - -from mcp.shared.auth import OAuthClientMetadata - - -class AuthConfigBase(ABC): - """Base class for all authentication configurations.""" - - @abstractmethod - def get_auth_type(self) -> str: - """Return the authentication type identifier.""" - pass - - -class OAuthConfig(AuthConfigBase): - """OAuth authentication configuration.""" - - def __init__( - self, - client_metadata: OAuthClientMetadata | dict[str, Any], - callback_handler: "CallbackHandler | None" = None, - redirect_handler: "RedirectHandler | None" = None, - ): - """Initialize OAuth configuration. - - Args: - client_metadata: OAuth client metadata or dict representation - callback_handler: Optional custom callback handler - redirect_handler: Optional custom redirect handler - """ - if isinstance(client_metadata, dict): - self.client_metadata = OAuthClientMetadata.model_validate(client_metadata) - else: - self.client_metadata = client_metadata - self.callback_handler = callback_handler - self.redirect_handler = redirect_handler - - def get_auth_type(self) -> str: - """Return OAuth auth type.""" - return "oauth" - - -class ApiKeyConfig(AuthConfigBase): - """API Key authentication configuration.""" - - def __init__(self, header_name: str, header_value: str): - """Initialize API key configuration. - - Args: - header_name: Name of the header to send the API key in - header_value: The API key value - """ - self.header_name = header_name - self.header_value = header_value - - def get_auth_type(self) -> str: - """Return API key auth type.""" - return "api_key" - - -class BearerAuthConfig(AuthConfigBase): - """Bearer token authentication configuration.""" - - def __init__(self, token: str): - """Initialize bearer auth configuration. - - Args: - token: The bearer token - """ - self.token = token - - def get_auth_type(self) -> str: - """Return bearer auth type.""" - return "bearer" - - -# Union type for all auth configurations -AuthConfig = Union[OAuthConfig, ApiKeyConfig, BearerAuthConfig] - - -class CallbackHandler(Protocol): - """Protocol for OAuth callback handlers.""" - - async def __call__(self) -> tuple[str, str | None]: - """Handle OAuth callback and return authorization code and state. - - Returns: - Tuple of (authorization_code, state) - """ - ... - - -class RedirectHandler(Protocol): - """Protocol for OAuth redirect handlers.""" - - async def __call__(self, authorization_url: str) -> None: - """Handle OAuth redirect by opening authorization URL. - - Args: - authorization_url: The OAuth authorization URL to redirect to - """ - ... +"""Authentication models for MCPAdapt - Legacy file kept for backwards compatibility.""" + +# This file previously contained OAuth/API Key/Bearer auth configuration classes +# but they have been removed in favor of direct auth provider usage. +# +# Users now create providers directly: +# - OAuthClientProvider (from MCP SDK) +# - ApiKeyAuthProvider/BearerAuthProvider (from mcpadapt.auth.providers) +# +# This file is kept empty to avoid breaking existing imports, +# but may be removed in a future version. diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py index 12a7700..3c014e7 100644 --- a/src/mcpadapt/auth/providers.py +++ b/src/mcpadapt/auth/providers.py @@ -1,25 +1,20 @@ -"""Authentication provider factory and integration logic.""" +"""Authentication provider classes for MCPAdapt.""" from typing import Any -from urllib.parse import urlparse - -from mcp.client.auth import OAuthClientProvider - -from .handlers import default_callback_handler, default_redirect_handler -from .oauth import InMemoryTokenStorage -from .models import ApiKeyConfig, AuthConfig, BearerAuthConfig, OAuthConfig class ApiKeyAuthProvider: """Simple API key authentication provider.""" - def __init__(self, config: ApiKeyConfig): + def __init__(self, header_name: str, header_value: str): """Initialize with API key configuration. Args: - config: API key configuration + header_name: Name of the header to send the API key in + header_value: The API key value """ - self.config = config + self.header_name = header_name + self.header_value = header_value def get_headers(self) -> dict[str, str]: """Get authentication headers. @@ -27,19 +22,19 @@ def get_headers(self) -> dict[str, str]: Returns: Dictionary of headers to add to requests """ - return {self.config.header_name: self.config.header_value} + return {self.header_name: self.header_value} class BearerAuthProvider: """Simple Bearer token authentication provider.""" - def __init__(self, config: BearerAuthConfig): + def __init__(self, token: str): """Initialize with Bearer token configuration. Args: - config: Bearer token configuration + token: The bearer token """ - self.config = config + self.token = token def get_headers(self) -> dict[str, str]: """Get authentication headers. @@ -47,49 +42,7 @@ def get_headers(self) -> dict[str, str]: Returns: Dictionary of headers to add to requests """ - return {"Authorization": f"Bearer {self.config.token}"} - - -async def create_auth_provider( - auth_config: AuthConfig, server_url: str -) -> OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider: - """Factory function to create appropriate auth provider from config. - - Args: - auth_config: Authentication configuration - server_url: Server URL for OAuth (needed to determine OAuth server endpoint) - - Returns: - Appropriate auth provider instance - - Raises: - ValueError: If auth configuration type is not supported - """ - if isinstance(auth_config, OAuthConfig): - # Use provided handlers or default ones - callback_handler = auth_config.callback_handler or default_callback_handler - redirect_handler = auth_config.redirect_handler or default_redirect_handler - - # Create OAuth provider with domain root only - parsed_url = urlparse(server_url) - oauth_server_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - return OAuthClientProvider( - server_url=oauth_server_url, - client_metadata=auth_config.client_metadata, - storage=InMemoryTokenStorage(), - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - - elif isinstance(auth_config, ApiKeyConfig): - return ApiKeyAuthProvider(auth_config) - - elif isinstance(auth_config, BearerAuthConfig): - return BearerAuthProvider(auth_config) - - else: - raise ValueError(f"Unsupported auth configuration type: {type(auth_config)}") + return {"Authorization": f"Bearer {self.token}"} def get_auth_headers(auth_provider: Any) -> dict[str, str]: diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index f56b17b..d112233 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -20,8 +20,8 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from .auth import AuthConfig, create_auth_provider, get_auth_headers -from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider +from .auth.exceptions import OAuthError +from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider, get_auth_headers class ToolAdapter(ABC): @@ -199,7 +199,7 @@ def __init__( adapter: ToolAdapter, connect_timeout: int = 30, client_session_timeout_seconds: float | timedelta | None = 5, - auth_config: AuthConfig | list[AuthConfig | None] | None = None, + auth_provider: OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | list[OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | None] | None = None, ): """ Manage the MCP server / client lifecycle and expose tools adapted with the adapter. @@ -210,7 +210,7 @@ def __init__( adapter (ToolAdapter): Adapter to use to convert MCP tools call into agentic framework tools. connect_timeout (int): Connection timeout in seconds to the mcp server (default is 30s). client_session_timeout_seconds: Timeout for MCP ClientSession calls - auth_config: Optional authentication configuration. Can be a single config for all servers, + auth_provider: Optional authentication provider. Can be a single provider for all servers, a list matching serverparams length, or None for no authentication. Raises: @@ -224,23 +224,22 @@ def __init__( self.adapter = adapter - # Handle auth_config - ensure it matches serverparams length - if auth_config is None: - self.auth_configs = [None] * len(self.serverparams) - elif isinstance(auth_config, list): - if len(auth_config) != len(self.serverparams): + # Handle auth_provider - ensure it matches serverparams length + if auth_provider is None: + self.auth_providers = [None] * len(self.serverparams) + elif isinstance(auth_provider, list): + if len(auth_provider) != len(self.serverparams): raise ValueError( - f"auth_config list length ({len(auth_config)}) must match serverparams length ({len(self.serverparams)})" + f"auth_provider list length ({len(auth_provider)}) must match serverparams length ({len(self.serverparams)})" ) - self.auth_configs = auth_config + self.auth_providers = auth_provider else: - # Single auth config for all servers - self.auth_configs = [auth_config] * len(self.serverparams) + # Single auth provider for all servers + self.auth_providers = [auth_provider] * len(self.serverparams) # session and tools get set by the async loop during initialization. self.sessions: list[ClientSession] = [] self.mcp_tools: list[list[mcp.types.Tool]] = [] - self.auth_providers: list[Any] = [] # all attributes used to manage the async loop and separate thread. self.loop = asyncio.new_event_loop() @@ -257,20 +256,6 @@ def _run_loop(self): asyncio.set_event_loop(self.loop) async def setup(): - # Create auth providers if needed (only if not already provided) - if not self.auth_providers: - auth_providers = [] - for params, auth_config in zip(self.serverparams, self.auth_configs): - if auth_config is not None: - # Get server URL from params for OAuth - server_url = params.get("url", "") if isinstance(params, dict) else "" - auth_provider = await create_auth_provider(auth_config, server_url) - auth_providers.append(auth_provider) - else: - auth_providers.append(None) - - self.auth_providers = auth_providers - async with AsyncExitStack() as stack: connections = [ await stack.enter_async_context( @@ -380,20 +365,6 @@ async def atools(self) -> list[Any]: async def __aenter__(self) -> list[Any]: self._ctxmanager = AsyncExitStack() - # Create auth providers if needed (only if not already provided) - if not self.auth_providers: - auth_providers = [] - for params, auth_config in zip(self.serverparams, self.auth_configs): - if auth_config is not None: - # Get server URL from params for OAuth - server_url = params.get("url", "") if isinstance(params, dict) else "" - auth_provider = await create_auth_provider(auth_config, server_url) - auth_providers.append(auth_provider) - else: - auth_providers.append(None) - - self.auth_providers = auth_providers - connections = [ await self._ctxmanager.enter_async_context( mcptools(params, self.client_session_timeout_seconds, auth_provider) From cb3a9c0f64306cc7b12576c616f9d5a94b4acb44 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sun, 7 Sep 2025 03:34:16 +0530 Subject: [PATCH 3/7] docs --- docs/auth/api-key.md | 149 +++++ docs/auth/bearer-token.md | 135 +++++ docs/auth/custom-handlers.md | 161 ++++++ docs/auth/oauth.md | 360 ++++++++++++ docs/auth/overview.md | 61 ++ docs/auth/quickstart.md | 135 +++++ examples/canva_oauth_example.py | 105 ++++ examples/oauth_with_credentials_example.py | 133 +++++ impl_sample/mcp_sample.py | 363 ++++++++++++ mkdocs.yml | 9 +- pyproject.toml | 3 + src/mcpadapt/auth/__init__.py | 101 +--- src/mcpadapt/auth/authenticate.py | 43 -- src/mcpadapt/auth/exceptions.py | 27 +- src/mcpadapt/auth/handlers.py | 103 ++-- src/mcpadapt/auth/models.py | 11 - src/mcpadapt/auth/oauth.py | 19 +- src/mcpadapt/auth/providers.py | 12 +- src/mcpadapt/core.py | 19 +- tests/auth/conftest.py | 115 ++++ tests/auth/test_core_auth.py | 574 +++++++++++++++++++ tests/auth/test_exceptions.py | 295 ++++++++++ tests/auth/test_handlers.py | 623 +++++++++++++++++++++ tests/auth/test_oauth.py | 524 +++++++++++++++++ tests/auth/test_providers.py | 297 ++++++++++ uv.lock | 122 ++++ 26 files changed, 4282 insertions(+), 217 deletions(-) create mode 100644 docs/auth/api-key.md create mode 100644 docs/auth/bearer-token.md create mode 100644 docs/auth/custom-handlers.md create mode 100644 docs/auth/oauth.md create mode 100644 docs/auth/overview.md create mode 100644 docs/auth/quickstart.md create mode 100644 examples/canva_oauth_example.py create mode 100644 examples/oauth_with_credentials_example.py create mode 100644 impl_sample/mcp_sample.py delete mode 100644 src/mcpadapt/auth/authenticate.py delete mode 100644 src/mcpadapt/auth/models.py create mode 100644 tests/auth/conftest.py create mode 100644 tests/auth/test_core_auth.py create mode 100644 tests/auth/test_exceptions.py create mode 100644 tests/auth/test_handlers.py create mode 100644 tests/auth/test_oauth.py create mode 100644 tests/auth/test_providers.py diff --git a/docs/auth/api-key.md b/docs/auth/api-key.md new file mode 100644 index 0000000..ea48e19 --- /dev/null +++ b/docs/auth/api-key.md @@ -0,0 +1,149 @@ +# API Key Authentication + +API Key authentication allows you to pass in an API Key as a header to the MCP server for authentication + +## Basic Usage + +```python +from mcpadapt.auth import ApiKeyAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create API Key provider +api_key_provider = ApiKeyAuthProvider( + header_name="X-API-Key", + header_value="your-api-key-here" +) + +with MCPAdapt( + serverparams={"url": "https://api.example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=api_key_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +## Custom Header Names + +Different APIs use different header names for API keys: + +```python +from mcpadapt.auth import ApiKeyAuthProvider + +# Common API key header variations +providers = [ + ApiKeyAuthProvider("X-API-Key", "key123"), # Most common + ApiKeyAuthProvider("Authorization", "key123"), # Simple auth header + ApiKeyAuthProvider("X-Auth-Token", "key123"), # Auth token variant + ApiKeyAuthProvider("X-Custom-Auth", "key123"), # Custom header +] +``` + +## Environment Variables + +Store API keys securely using environment variables: + +```python +import os +from mcpadapt.auth import ApiKeyAuthProvider + +# Load API key from environment +api_key = os.getenv("MY_API_KEY") +if not api_key: + raise ValueError("MY_API_KEY environment variable is required") + +api_key_provider = ApiKeyAuthProvider( + header_name="X-API-Key", + header_value=api_key +) +``` + +## Multiple APIs with Different Keys + +Use different API keys for different MCP servers: + +```python +import os +from mcpadapt.auth import ApiKeyAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Different API keys for different services +auth_providers = [ + ApiKeyAuthProvider("X-API-Key", os.getenv("SERVICE_A_KEY")), + ApiKeyAuthProvider("X-Auth-Token", os.getenv("SERVICE_B_KEY")), + None, # No authentication for third service +] + +server_configs = [ + {"url": "https://service-a.com/mcp", "transport": "streamable-http"}, + {"url": "https://service-b.com/mcp", "transport": "streamable-http"}, + {"url": "http://localhost:8000/sse"}, +] + +with MCPAdapt( + serverparams=server_configs, + adapter=SmolAgentsAdapter(), + auth_provider=auth_providers, +) as tools: + print(f"Connected to {len(server_configs)} servers") +``` + +## API Key Formats + +### Simple API Key + +```python +ApiKeyAuthProvider("X-API-Key", "abc123def456") +``` + +### Prefixed API Key + +```python +ApiKeyAuthProvider("X-API-Key", "Bearer abc123def456") +ApiKeyAuthProvider("Authorization", "API-Key abc123def456") +``` + +### Base64 Encoded Credentials + +```python +import base64 + +credentials = base64.b64encode(b"username:password").decode() +ApiKeyAuthProvider("Authorization", f"Basic {credentials}") +``` + +## Best Practices + +### Security +- Never hard-code API keys in source code +- Use environment variables or secure configuration files +- Rotate API keys regularly +- Use the principle of least privilege for API key permissions + +### Configuration +- Use descriptive environment variable names +- Document required API keys in your README +- Provide clear error messages for missing keys +- Validate API key format before using + +## Integration Examples + +### With Different Frameworks + +```python +# SmolAgents +from mcpadapt.smolagents_adapter import SmolAgentsAdapter +adapter = SmolAgentsAdapter() + +# CrewAI +from mcpadapt.crewai_adapter import CrewAIAdapter +adapter = CrewAIAdapter() + +# LangChain +from mcpadapt.langchain_adapter import LangChainAdapter +adapter = LangChainAdapter() + +# All use the same API key provider +api_key_provider = ApiKeyAuthProvider("X-API-Key", os.getenv("API_KEY")) +``` diff --git a/docs/auth/bearer-token.md b/docs/auth/bearer-token.md new file mode 100644 index 0000000..cc79431 --- /dev/null +++ b/docs/auth/bearer-token.md @@ -0,0 +1,135 @@ +# Bearer Token Authentication + +Bearer token authentication uses standard Authorization headers with Bearer tokens for MCP server authentication. + +## Basic Usage + +```python +from mcpadapt.auth import BearerAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create Bearer token provider +bearer_provider = BearerAuthProvider(token="your-bearer-token") + +with MCPAdapt( + serverparams={"url": "https://api.example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=bearer_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +## JWT Tokens + +Bearer tokens are commonly used with JWT (JSON Web Tokens): + +```python +from mcpadapt.auth import BearerAuthProvider + +# JWT token example +jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." +bearer_provider = BearerAuthProvider(token=jwt_token) +``` + +## Environment Variables + +Store bearer tokens securely: + +```python +import os +from mcpadapt.auth import BearerAuthProvider + +# Load token from environment +bearer_token = os.getenv("BEARER_TOKEN") +if not bearer_token: + raise ValueError("BEARER_TOKEN environment variable is required") + +bearer_provider = BearerAuthProvider(token=bearer_token) +``` + +## Multiple Services + +Use different bearer tokens for different MCP servers: + +```python +import os +from mcpadapt.auth import BearerAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Different tokens for different services +auth_providers = [ + BearerAuthProvider(os.getenv("SERVICE_A_TOKEN")), + BearerAuthProvider(os.getenv("SERVICE_B_TOKEN")), + None, # No authentication for third service +] + +server_configs = [ + {"url": "https://service-a.com/mcp", "transport": "streamable-http"}, + {"url": "https://service-b.com/mcp", "transport": "streamable-http"}, + {"url": "http://localhost:8000/sse"}, +] + +with MCPAdapt( + serverparams=server_configs, + adapter=SmolAgentsAdapter(), + auth_provider=auth_providers, +) as tools: + print(f"Connected to {len(server_configs)} servers") +``` + +## Token Formats + +Bearer tokens can have different formats: + +```python +from mcpadapt.auth import BearerAuthProvider + +# Standard JWT +BearerAuthProvider("eyJhbGciOiJIUzI1NiIs...") + +# Simple token +BearerAuthProvider("abc123def456ghi789") + +# API-specific format +BearerAuthProvider("sk-1234567890abcdef") +``` + +## Best Practices + +### Security +- Never hard-code bearer tokens in source code +- Use environment variables or secure configuration management +- Implement token rotation when possible +- Monitor token expiration and refresh as needed + +### Configuration +- Use descriptive environment variable names +- Validate token format before using +- Handle token expiration gracefully +- Log authentication failures for debugging + +## Integration Examples + +### With Different Frameworks + +```python +from mcpadapt.auth import BearerAuthProvider +import os + +# Same bearer provider works with all frameworks +bearer_provider = BearerAuthProvider(os.getenv("BEARER_TOKEN")) + +# SmolAgents +from mcpadapt.smolagents_adapter import SmolAgentsAdapter +adapter = SmolAgentsAdapter() + +# CrewAI +from mcpadapt.crewai_adapter import CrewAIAdapter +adapter = CrewAIAdapter() + +# LangChain +from mcpadapt.langchain_adapter import LangChainAdapter +adapter = LangChainAdapter() +``` diff --git a/docs/auth/custom-handlers.md b/docs/auth/custom-handlers.md new file mode 100644 index 0000000..5b64892 --- /dev/null +++ b/docs/auth/custom-handlers.md @@ -0,0 +1,161 @@ +# Creating Custom Handlers + +Custom OAuth handlers allow you to implement specialized authentication flows for different environments and use cases. + +## BaseOAuthHandler Interface + +All custom OAuth handlers must extend the `BaseOAuthHandler` abstract class: + +```python +from mcpadapt.auth import BaseOAuthHandler + +class CustomOAuthHandler(BaseOAuthHandler): + async def handle_redirect(self, authorization_url: str) -> None: + """Handle OAuth redirect to authorization URL.""" + # Your custom redirect logic here + pass + + async def handle_callback(self) -> tuple[str, str | None]: + """Handle OAuth callback and return authorization code and state.""" + # Your custom callback logic here + return authorization_code, state +``` + +## Headless Environment Handler + +For server environments without a browser: + +```python +from mcpadapt.auth import BaseOAuthHandler + +class HeadlessOAuthHandler(BaseOAuthHandler): + """OAuth handler for headless environments.""" + + async def handle_redirect(self, authorization_url: str) -> None: + print(f"Please open this URL in your browser:") + print(f"{authorization_url}") + print() + + async def handle_callback(self) -> tuple[str, str | None]: + auth_code = input("Enter the authorization code from the callback URL: ") + state = input("Enter the state parameter (or press Enter to skip): ").strip() + return auth_code, state or None +``` + +## Custom Callback Handler + +For applications with existing web servers: + +```python +from mcpadapt.auth import BaseOAuthHandler +import asyncio + +class CustomCallbackHandler(BaseOAuthHandler): + """OAuth handler that integrates with existing web application.""" + + def __init__(self, callback_url: str): + self.callback_url = callback_url + self.callback_data = {} + self.callback_received = asyncio.Event() + + async def handle_redirect(self, authorization_url: str) -> None: + # In a real app, you might redirect the user's current request + print(f"Redirecting to: {authorization_url}") + # Your web framework's redirect logic here + + async def handle_callback(self) -> tuple[str, str | None]: + # Wait for callback to be received by your web server + await self.callback_received.wait() + + auth_code = self.callback_data.get('code') + state = self.callback_data.get('state') + + if not auth_code: + raise ValueError("No authorization code received") + + return auth_code, state + + def receive_callback(self, code: str, state: str | None = None): + """Call this method from your web server's callback endpoint.""" + self.callback_data = {'code': code, 'state': state} + self.callback_received.set() +``` + +## CLI Integration Handler + +For command-line applications: + +```python +from mcpadapt.auth import BaseOAuthHandler +import webbrowser +import urllib.parse + +class CLIHandler(BaseOAuthHandler): + """OAuth handler optimized for CLI applications.""" + + def __init__(self, auto_open_browser: bool = True): + self.auto_open_browser = auto_open_browser + + async def handle_redirect(self, authorization_url: str) -> None: + if self.auto_open_browser: + try: + webbrowser.open(authorization_url) + print("Opening browser for authentication...") + except Exception: + print("Could not open browser automatically.") + print(f"Please open: {authorization_url}") + else: + print(f"Please open: {authorization_url}") + + async def handle_callback(self) -> tuple[str, str | None]: + print() + print("After authorizing, copy the full callback URL from your browser.") + callback_url = input("Callback URL: ").strip() + + # Parse the callback URL to extract code and state + parsed = urllib.parse.urlparse(callback_url) + query_params = urllib.parse.parse_qs(parsed.query) + + if 'code' not in query_params: + raise ValueError("No authorization code found in callback URL") + + auth_code = query_params['code'][0] + state = query_params.get('state', [None])[0] + + return auth_code, state +``` + +## Using Custom Handlers + +```python +from mcpadapt.auth import OAuthClientProvider, OAuthClientMetadata, InMemoryTokenStorage +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter +from pydantic import HttpUrl + +# Use your custom handler +custom_handler = HeadlessOAuthHandler() + +client_metadata = OAuthClientMetadata( + client_name="My Application", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +oauth_provider = OAuthClientProvider( + server_url="https://oauth-server.com", + client_metadata=client_metadata, + storage=InMemoryTokenStorage(), + redirect_handler=custom_handler.handle_redirect, + callback_handler=custom_handler.handle_callback, +) + +with MCPAdapt( + serverparams={"url": "https://oauth-server.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, +) as tools: + print(f"Connected with custom handler: {len(tools)} tools") +``` diff --git a/docs/auth/oauth.md b/docs/auth/oauth.md new file mode 100644 index 0000000..cfaf698 --- /dev/null +++ b/docs/auth/oauth.md @@ -0,0 +1,360 @@ +# OAuth 2.0 Authentication + +OAuth 2.0 provides secure authorization for MCP servers requiring user consent. MCPAdapt implements the authorization code flow with automatic token refresh. + +## How OAuth Works with MCPAdapt + +The built in provider helps you perform the following sequence: + +1. **Dynamic Client Registration**: Register your application with the OAuth server (if supported) + - If your OAuth Server does not support Dynamic Client Registration, be sure to populate the client credentials manually +2. **Authorization Flow**: User authorizes your application in their browser +3. **Token Exchange**: Exchange authorization code for access tokens +4. **Automatic Refresh**: Tokens are refreshed automatically when needed + +## Basic OAuth Setup + +```python +from pydantic import HttpUrl + +from mcpadapt.auth import ( + OAuthClientProvider, + OAuthClientMetadata, + InMemoryTokenStorage, + LocalBrowserOAuthHandler, +) +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Configure client metadata +client_metadata = OAuthClientMetadata( + client_name="My Application", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +# Set up OAuth components +oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +token_storage = InMemoryTokenStorage() + +# Create OAuth provider +oauth_provider = OAuthClientProvider( + server_url="https://oauth-server.com", + client_metadata=client_metadata, + storage=token_storage, + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, +) + +# Use with MCPAdapt +with MCPAdapt( + serverparams={"url": "https://oauth-server.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, +) as tools: + # OAuth flow happens automatically + print(f"Authenticated with {len(tools)} tools") +``` + +## OAuth Components + +### OAuthClientMetadata + +Configure your application's OAuth settings: + +```python +from pydantic import HttpUrl +from mcpadapt.auth import OAuthClientMetadata + +client_metadata = OAuthClientMetadata( + client_name="Your App Name", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + # Optional fields: + client_uri=HttpUrl("https://yourapp.com"), + tos_uri=HttpUrl("https://yourapp.com/terms"), + policy_uri=HttpUrl("https://yourapp.com/privacy"), +) +``` + +### LocalBrowserOAuthHandler + +Handles the OAuth flow using the user's browser: + +```python +from mcpadapt.auth import LocalBrowserOAuthHandler + +# Default configuration +oauth_handler = LocalBrowserOAuthHandler() + +# Custom configuration +oauth_handler = LocalBrowserOAuthHandler( + callback_port=8080, # Custom port + timeout=600, # 10 minute timeout +) +``` + +### InMemoryTokenStorage + +Simple token storage for development and testing: + +```python +from mcpadapt.auth import InMemoryTokenStorage + +token_storage = InMemoryTokenStorage() +``` + +## Using Pre-configured OAuth Credentials + +When the OAuth server doesn't support Dynamic Client Registration (DCR), or when you have existing OAuth application credentials, you can pre-configure the token storage with your client information to skip the registration step. + +### Basic Pre-configured Setup + +```python +import os +from pydantic import HttpUrl +from mcp.shared.auth import OAuthClientInformationFull + +from mcpadapt.auth import ( + OAuthClientProvider, + OAuthClientMetadata, + InMemoryTokenStorage, + LocalBrowserOAuthHandler, +) + +# Get your OAuth app credentials (from environment variables or secure storage) +CLIENT_ID = os.getenv("OAUTH_CLIENT_ID") +CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET") +REDIRECT_URI = "http://localhost:3030/callback" + +# Create pre-configured client information +client_info = OAuthClientInformationFull( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + redirect_uris=[HttpUrl(REDIRECT_URI)] +) + +# Create token storage with pre-configured credentials +token_storage = InMemoryTokenStorage(client_info=client_info) + +# Configure client metadata (still needed for OAuth flow) +client_metadata = OAuthClientMetadata( + client_name="My Pre-configured App", + redirect_uris=[HttpUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +# Set up OAuth handler +oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) + +# Create OAuth provider +oauth_provider = OAuthClientProvider( + server_url="https://oauth-server.com", + client_metadata=client_metadata, + storage=token_storage, # Contains pre-configured credentials + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, +) + +# Use with MCPAdapt - DCR will be skipped +with MCPAdapt( + serverparams={"url": "https://oauth-server.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, +) as tools: + print(f"Authenticated with pre-configured credentials: {len(tools)} tools") +``` + +### When to Use Pre-configured Credentials + +Use pre-configured credentials when: + +- **Server doesn't support DCR**: Some OAuth servers don't implement Dynamic Client Registration +- **Existing OAuth app**: You already have a registered OAuth application with client credentials +- **Compliance requirements**: Your organization requires using specific pre-registered applications + +### Environment Variables for Credentials + +Store your OAuth credentials securely using environment variables: + +```bash +# Set environment variables +export OAUTH_CLIENT_ID="your-actual-client-id" +export OAUTH_CLIENT_SECRET="your-actual-client-secret" +export OAUTH_REDIRECT_URI="http://localhost:3030/callback" +``` + +Then reference them in your code: + +```python +import os +from pydantic import HttpUrl +from mcp.shared.auth import OAuthClientInformationFull +from mcpadapt.auth import InMemoryTokenStorage + +# Load from environment +client_info = OAuthClientInformationFull( + client_id=os.getenv("OAUTH_CLIENT_ID"), + client_secret=os.getenv("OAUTH_CLIENT_SECRET"), + redirect_uris=[HttpUrl(os.getenv("OAUTH_REDIRECT_URI", "http://localhost:3030/callback"))] +) + +# Create storage with pre-configured credentials +token_storage = InMemoryTokenStorage(client_info=client_info) +``` + +### Complete Example + +See `examples/oauth_with_credentials_example.py` for a complete working example of using pre-configured OAuth credentials. + +## Custom OAuth Handlers + +Create custom OAuth handlers for production environments or when you are integrating into a larger app: + +```python +from mcpadapt.auth import BaseOAuthHandler + +class HeadlessOAuthHandler(BaseOAuthHandler): + """OAuth handler for headless environments.""" + + async def handle_redirect(self, authorization_url: str) -> None: + print(f"Open this URL in your browser: {authorization_url}") + + async def handle_callback(self) -> tuple[str, str | None]: + auth_code = input("Enter the authorization code: ") + return auth_code, None + +# Use custom handler +custom_handler = HeadlessOAuthHandler() +oauth_provider = OAuthClientProvider( + server_url="https://oauth-server.com", + client_metadata=client_metadata, + storage=token_storage, + redirect_handler=custom_handler.handle_redirect, + callback_handler=custom_handler.handle_callback, +) +``` + +## Token Storage Options + +### In-Memory Storage (Development) + +```python +from mcpadapt.auth import InMemoryTokenStorage + +# Simple in-memory storage - tokens lost when application exits +storage = InMemoryTokenStorage() +``` + +### Custom Persistent Storage + +```python +from mcpadapt.auth import TokenStorage, OAuthClientInformationFull, OAuthToken +import json + +class FileTokenStorage(TokenStorage): + """File-based token storage.""" + + def __init__(self, filepath: str): + self.filepath = filepath + + async def get_tokens(self) -> OAuthToken | None: + try: + with open(self.filepath, 'r') as f: + data = json.load(f) + return OAuthToken(**data['tokens']) + except FileNotFoundError: + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: + data = {'tokens': tokens.model_dump()} + with open(self.filepath, 'w') as f: + json.dump(data, f) + + async def get_client_info(self) -> OAuthClientInformationFull | None: + try: + with open(self.filepath, 'r') as f: + data = json.load(f) + return OAuthClientInformationFull(**data['client_info']) + except (FileNotFoundError, KeyError): + return None + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + try: + with open(self.filepath, 'r') as f: + data = json.load(f) + except FileNotFoundError: + data = {} + + data['client_info'] = client_info.model_dump() + with open(self.filepath, 'w') as f: + json.dump(data, f) + +# Use custom storage +storage = FileTokenStorage("oauth_tokens.json") +``` + +## Error Handling + +Handle OAuth-specific errors gracefully: + +```python +from mcpadapt.auth import ( + OAuthError, + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthConfigurationError, + OAuthServerError, +) + +try: + with MCPAdapt( + serverparams=server_config, + adapter=adapter, + auth_provider=oauth_provider, + ) as tools: + # Use tools + pass + +except OAuthTimeoutError as e: + print(f"Authentication timed out after {e.timeout_seconds} seconds") + +except OAuthCancellationError as e: + print(f"User cancelled authorization: {e.error_details}") + +except OAuthNetworkError as e: + print(f"Network error during OAuth: {e.original_error}") + +except OAuthConfigurationError as e: + print(f"OAuth configuration error: {e.config_issue}") + +except OAuthServerError as e: + print(f"OAuth server error: {e.server_error}") + if e.error_description: + print(f"Description: {e.error_description}") + +except OAuthError as e: + print(f"General OAuth error: {e}") + if e.context: + print(f"Context: {e.context}") +``` + +## Configuration Tips + +### Port Selection +- Default port is 3030 +- Ensure the port is available and not blocked by firewalls +- Use different ports for different applications + +## Security Considerations + +- Store tokens securely in production environments, you may want to use something like Vault +- Use HTTPS for all OAuth flows in production +- Use appropriate OAuth scopes (minimal permissions) diff --git a/docs/auth/overview.md b/docs/auth/overview.md new file mode 100644 index 0000000..3a68bd1 --- /dev/null +++ b/docs/auth/overview.md @@ -0,0 +1,61 @@ +# Authentication Overview + +MCPAdapt builds upon the offical MCP python SDK and provides authentication support for connecting to secure MCP servers that require access control, rate limiting, or premium features. + +## Supported Authentication Methods + +### OAuth 2.0 +Secure authorization code flow with automatic token refresh. Best for production applications and user-facing services requiring consent. + +### API Key +Header-based authentication using API keys. Ideal for server-to-server communication and development environments. + +### Bearer Token +Standard Bearer token authentication for JWT-based systems and modern API patterns. + +## Core Components + +**Authentication Providers:** +- `OAuthClientProvider` - OAuth 2.0 authentication +- `ApiKeyAuthProvider` - API key authentication +- `BearerAuthProvider` - Bearer token authentication + +**OAuth Handlers:** +- `LocalBrowserOAuthHandler` - Browser-based OAuth flow +- `BaseOAuthHandler` - Base class for custom implementations + +**Token Storage:** +- `InMemoryTokenStorage` - In-memory token storage +- `TokenStorage` - Base class for custom implementations + +**Error Handling:** +- `OAuthTimeoutError` - Authentication timeout +- `OAuthCancellationError` - User cancelled authorization +- `OAuthNetworkError` - Network issues +- `OAuthConfigurationError` - Configuration problems +- `OAuthServerError` - Server-side errors +- `OAuthCallbackError` - Callback handling issues + +## Basic Usage + +Authentication integrates transparently into MCPAdapt: + +```python +with MCPAdapt( + serverparams=server_config, + adapter=YourAdapter(), + auth_provider=your_auth_provider, +) as tools: + # Authentication handled automatically + result = tools[0]({"param": "value"}) +``` + +## Next Steps + +- [Quick Start Guide](quickstart.md) - Get started immediately +- [OAuth 2.0 Guide](oauth.md) - Complete OAuth implementation +- [API Key Guide](api-key.md) - Key-based authentication +- [Bearer Token Guide](bearer-token.md) - Token authentication +- [Custom Handlers](custom-handlers.md) - Custom authentication flows +- [Error Handling](error-handling.md) - Handle authentication errors +- [Examples](examples.md) - Real-world examples diff --git a/docs/auth/quickstart.md b/docs/auth/quickstart.md new file mode 100644 index 0000000..0840196 --- /dev/null +++ b/docs/auth/quickstart.md @@ -0,0 +1,135 @@ +# Quick Start Guide + +Get authentication working quickly with these minimal examples. + +## OAuth 2.0 with a provider like Canva + +```python +from mcp.client.auth import OAuthClientProvider +from mcp.shared.auth import OAuthClientMetadata +from pydantic import HttpUrl + +from mcpadapt.auth import InMemoryTokenStorage, LocalBrowserOAuthHandler +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Configure OAuth +client_metadata = OAuthClientMetadata( + client_name="My App", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +token_storage = InMemoryTokenStorage() + +oauth_provider = OAuthClientProvider( + server_url="https://mcp.canva.com", + client_metadata=client_metadata, + storage=token_storage, + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, +) + +# Use with MCPAdapt +with MCPAdapt( + serverparams={"url": "https://mcp.canva.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +## API Key Authentication + +```python +from mcpadapt.auth import ApiKeyAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create API Key provider +api_key_provider = ApiKeyAuthProvider( + header_name="X-API-Key", + header_value="your-api-key-here" +) + +with MCPAdapt( + serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=api_key_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +## Bearer Token Authentication + +```python +from mcpadapt.auth import BearerAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Create Bearer token provider +bearer_provider = BearerAuthProvider(token="your-bearer-token") + +with MCPAdapt( + serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, + adapter=SmolAgentsAdapter(), + auth_provider=bearer_provider, +) as tools: + print(f"Connected with {len(tools)} tools") +``` + +## Multiple Servers with Different Authentication + +```python +from mcpadapt.auth import ApiKeyAuthProvider, BearerAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + +# Different auth for each server +auth_providers = [ + ApiKeyAuthProvider("X-API-Key", "key1"), + BearerAuthProvider("token2"), + None, # No auth for third server +] + +server_configs = [ + {"url": "https://api1.com/mcp", "transport": "streamable-http"}, + {"url": "https://api2.com/mcp", "transport": "streamable-http"}, + {"url": "http://localhost:8000/sse"}, +] + +with MCPAdapt( + serverparams=server_configs, + adapter=SmolAgentsAdapter(), + auth_provider=auth_providers, +) as tools: + print(f"Connected to {len(server_configs)} servers with {len(tools)} total tools") +``` + +## Error Handling + +```python +from mcpadapt.auth import ( + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, +) + +try: + with MCPAdapt( + serverparams=server_config, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, + ) as tools: + # Use tools + pass +except OAuthTimeoutError: + print("Authentication timed out - try again") +except OAuthCancellationError: + print("User cancelled authorization") +except OAuthNetworkError as e: + print(f"Network error: {e}") +``` diff --git a/examples/canva_oauth_example.py b/examples/canva_oauth_example.py new file mode 100644 index 0000000..c0546be --- /dev/null +++ b/examples/canva_oauth_example.py @@ -0,0 +1,105 @@ +"""Example demonstrating OAuth authentication with Canva MCP server. + +This example shows how to connect to the Canva MCP server (https://mcp.canva.com/mcp) +using OAuth authentication through MCPAdapt. + +The Canva MCP server provides tools for creating and managing Canva designs. +It fully complies with OAuth 2.0 Dynamic Client Registration +""" + +from pydantic import HttpUrl + +from mcpadapt.auth import ( + InMemoryTokenStorage, + LocalBrowserOAuthHandler, + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthConfigurationError, + OAuthClientProvider, + OAuthClientMetadata, +) +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter + + +def main(): + """Main example function demonstrating Canva OAuth connection.""" + print("Canva MCP OAuth Example") + print("=" * 40) + + # Create OAuth client metadata + client_metadata = OAuthClientMetadata( + client_name="MCPAdapt Canva Example", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + # Create OAuth handler and token storage + oauth_handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + token_storage = InMemoryTokenStorage() + + # Create OAuth provider directly + oauth_provider = OAuthClientProvider( + server_url="https://mcp.canva.com", + client_metadata=client_metadata, + storage=token_storage, + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, + ) + + # Server configuration for Canva MCP + server_config = { + "url": "https://mcp.canva.com/mcp", + "transport": "streamable-http", + } + + print("Connecting to Canva MCP server with OAuth...") + print("This will open your browser for OAuth authorization") + print() + + try: + # Connect to Canva MCP server with OAuth authentication (sync) + with MCPAdapt( + serverparams=server_config, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, + ) as tools: + print("Successfully connected to Canva MCP server!") + print(f"Found {len(tools)} available tools:") + print() + + # List available tools + for i, tool in enumerate(tools, 1): + print(f"{i}. {tool.name}") + if hasattr(tool, "description") and tool.description: + print(f" Description: {tool.description}") + print() + + print("Connection successful! Tools are ready to use.") + + except OAuthTimeoutError as e: + print(f"OAuth authentication timed out: {e}") + print("Try again and complete the authorization in your browser quickly.") + except OAuthCancellationError as e: + print(f"OAuth authorization was cancelled: {e}") + print("You need to authorize the application to access Canva.") + except OAuthNetworkError as e: + print(f"Network error during OAuth: {e}") + print("Check your internet connection and try again.") + except OAuthConfigurationError as e: + print(f"OAuth configuration error: {e}") + print("Please check your OAuth settings.") + except Exception as e: + print(f"Failed to connect to Canva MCP server: {e}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nExample cancelled by user") + except Exception as e: + print(f"Example failed: {e}") diff --git a/examples/oauth_with_credentials_example.py b/examples/oauth_with_credentials_example.py new file mode 100644 index 0000000..1b50b8b --- /dev/null +++ b/examples/oauth_with_credentials_example.py @@ -0,0 +1,133 @@ +"""Example demonstrating OAuth with pre-configured client credentials. + +This example shows how to use your own OAuth client credentials +instead of relying on Dynamic Client Registration (DCR). This is useful when: +- The server doesn't support DCR +- You have a pre-registered OAuth application +- You want to use specific client credentials +""" + +import os +from pydantic import HttpUrl + +from mcpadapt.auth import ( + InMemoryTokenStorage, + LocalBrowserOAuthHandler, + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthConfigurationError, + OAuthClientProvider, + OAuthClientMetadata, +) +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter +from mcp.shared.auth import OAuthClientInformationFull + + +def main(): + """Main example function demonstrating OAuth with pre-configured credentials.""" + print("OAuth with Pre-configured Credentials Example") + print("=" * 40) + + # Get credentials from environment variables + CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "your-client-id") + CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "your-client-secret") + REDIRECT_URI = os.getenv("OAUTH_REDIRECT_URI", "http://localhost:3030/callback") + + print(f"Using client ID: {CLIENT_ID}") + print(f"Using redirect URI: {REDIRECT_URI}") + print() + + # Create pre-configured client information + client_info = OAuthClientInformationFull( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + redirect_uris=[HttpUrl(REDIRECT_URI)], + ) + + # Create OAuth client metadata (still needed for the OAuth flow) + client_metadata = OAuthClientMetadata( + client_name="MCPAdapt Pre-configured OAuth Example", + redirect_uris=[HttpUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + # Create OAuth handler + oauth_handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + + # Create token storage WITH pre-configured client information + # This is the key difference - we pass the client_info object + token_storage = InMemoryTokenStorage(client_info=client_info) + + # Create OAuth provider + oauth_provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=token_storage, # Storage contains pre-configured credentials + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, + ) + + # Server configuration + server_config = { + "url": "https://api.example.com/mcp", + "transport": "streamable-http", + } + + print("Connecting with pre-configured OAuth credentials...") + print("This will skip Dynamic Client Registration and use your credentials") + print("This will open your browser for OAuth authorization") + print() + + try: + # Connect to MCP server with pre-configured OAuth credentials + with MCPAdapt( + serverparams=server_config, + adapter=SmolAgentsAdapter(), + auth_provider=oauth_provider, + ) as tools: + print("Successfully connected with pre-configured credentials!") + print(f"Found {len(tools)} available tools:") + print() + + # List available tools + for i, tool in enumerate(tools, 1): + print(f"{i}. {tool.name}") + if hasattr(tool, "description") and tool.description: + print(f" Description: {tool.description}") + print() + + print("Connection successful! Tools are ready to use.") + + except OAuthTimeoutError as e: + print(f"OAuth authentication timed out: {e}") + print("Try again and complete the authorization in your browser quickly.") + except OAuthCancellationError as e: + print(f"OAuth authorization was cancelled: {e}") + print("You need to authorize the application to access the service.") + except OAuthNetworkError as e: + print(f"Network error during OAuth: {e}") + print("Check your internet connection and try again.") + except OAuthConfigurationError as e: + print(f"OAuth configuration error: {e}") + print("Please check your OAuth settings and credentials.") + except Exception as e: + print(f"Failed to connect: {e}") + + +if __name__ == "__main__": + print("Set environment variables for your OAuth credentials:") + print("export OAUTH_CLIENT_ID='your-client-id'") + print("export OAUTH_CLIENT_SECRET='your-client-secret'") + print("export OAUTH_REDIRECT_URI='http://localhost:3030/callback'") + print() + + try: + main() + except KeyboardInterrupt: + print("\nExample cancelled by user") + except Exception as e: + print(f"Example failed: {e}") diff --git a/impl_sample/mcp_sample.py b/impl_sample/mcp_sample.py new file mode 100644 index 0000000..7a9e322 --- /dev/null +++ b/impl_sample/mcp_sample.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Simple MCP client example with OAuth authentication support. + +This client connects to an MCP server using streamable HTTP transport with OAuth. + +""" + +import asyncio +import os +import threading +import time +import webbrowser +from datetime import timedelta +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any +from urllib.parse import parse_qs, urlparse + +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +class InMemoryTokenStorage(TokenStorage): + """Simple in-memory token storage implementation.""" + + def __init__(self): + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +class CallbackHandler(BaseHTTPRequestHandler): + """Simple HTTP handler to capture OAuth callback.""" + + def __init__(self, request, client_address, server, callback_data): + """Initialize with callback data storage.""" + self.callback_data = callback_data + super().__init__(request, client_address, server) + + def do_GET(self): + """Handle GET request from OAuth redirect.""" + parsed = urlparse(self.path) + query_params = parse_qs(parsed.query) + + if "code" in query_params: + self.callback_data["authorization_code"] = query_params["code"][0] + self.callback_data["state"] = query_params.get("state", [None])[0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b""" + + +

Authorization Successful!

+

You can close this window and return to the terminal.

+ + + + """) + elif "error" in query_params: + self.callback_data["error"] = query_params["error"][0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + f""" + + +

Authorization Failed

+

Error: {query_params["error"][0]}

+

You can close this window and return to the terminal.

+ + + """.encode() + ) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + """Suppress default logging.""" + pass + + +class CallbackServer: + """Simple server to handle OAuth callbacks.""" + + def __init__(self, port=3000): + self.port = port + self.server = None + self.thread = None + self.callback_data = {"authorization_code": None, "state": None, "error": None} + + def _create_handler_with_data(self): + """Create a handler class with access to callback data.""" + callback_data = self.callback_data + + class DataCallbackHandler(CallbackHandler): + def __init__(self, request, client_address, server): + super().__init__(request, client_address, server, callback_data) + + return DataCallbackHandler + + def start(self): + """Start the callback server in a background thread.""" + handler_class = self._create_handler_with_data() + self.server = HTTPServer(("localhost", self.port), handler_class) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + print(f"🖥️ Started callback server on http://localhost:{self.port}") + + def stop(self): + """Stop the callback server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.thread: + self.thread.join(timeout=1) + + def wait_for_callback(self, timeout=300): + """Wait for OAuth callback with timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + if self.callback_data["authorization_code"]: + return self.callback_data["authorization_code"] + elif self.callback_data["error"]: + raise Exception(f"OAuth error: {self.callback_data['error']}") + time.sleep(0.1) + raise Exception("Timeout waiting for OAuth callback") + + def get_state(self): + """Get the received state parameter.""" + return self.callback_data["state"] + + +class SimpleAuthClient: + """Simple MCP client with auth support.""" + + def __init__(self, server_url: str, transport_type: str = "streamable_http"): + self.server_url = server_url + self.transport_type = transport_type + self.session: ClientSession | None = None + + async def connect(self): + """Connect to the MCP server.""" + print(f"🔗 Attempting to connect to {self.server_url}...") + + try: + callback_server = CallbackServer(port=3030) + callback_server.start() + + async def callback_handler() -> tuple[str, str | None]: + """Wait for OAuth callback and return auth code and state.""" + print("⏳ Waiting for authorization callback...") + try: + auth_code = callback_server.wait_for_callback(timeout=300) + return auth_code, callback_server.get_state() + finally: + callback_server.stop() + + client_metadata_dict = { + "client_name": "Simple Auth Client", + "redirect_uris": ["http://localhost:3030/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + } + + async def _default_redirect_handler(authorization_url: str) -> None: + """Default redirect handler that opens the URL in a browser.""" + print(f"Opening browser for authorization: {authorization_url}") + webbrowser.open(authorization_url) + + # Create OAuth authentication handler using the new interface + oauth_auth = OAuthClientProvider( + server_url=self.server_url.replace("/mcp", ""), + client_metadata=OAuthClientMetadata.model_validate( + client_metadata_dict + ), + storage=InMemoryTokenStorage(), + redirect_handler=_default_redirect_handler, + callback_handler=callback_handler, + ) + + # Create transport with auth handler based on transport type + if self.transport_type == "sse": + print("📡 Opening SSE transport connection with auth...") + async with sse_client( + url=self.server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + await self._run_session(read_stream, write_stream, None) + else: + print("📡 Opening StreamableHTTP transport connection with auth...") + async with streamablehttp_client( + url=self.server_url, + auth=oauth_auth, + timeout=timedelta(seconds=60), + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) + + except Exception as e: + print(f"❌ Failed to connect: {e}") + import traceback + + traceback.print_exc() + + async def _run_session(self, read_stream, write_stream, get_session_id): + """Run the MCP session with the given streams.""" + print("🤝 Initializing MCP session...") + async with ClientSession(read_stream, write_stream) as session: + self.session = session + print("⚡ Starting session initialization...") + await session.initialize() + print("✨ Session initialization complete!") + + print(f"\n✅ Connected to MCP server at {self.server_url}") + if get_session_id: + session_id = get_session_id() + if session_id: + print(f"Session ID: {session_id}") + + # Run interactive loop + await self.interactive_loop() + + async def list_tools(self): + """List available tools from the server.""" + if not self.session: + print("❌ Not connected to server") + return + + try: + result = await self.session.list_tools() + if hasattr(result, "tools") and result.tools: + print("\n📋 Available tools:") + for i, tool in enumerate(result.tools, 1): + print(f"{i}. {tool.name}") + if tool.description: + print(f" Description: {tool.description}") + print() + else: + print("No tools available") + except Exception as e: + print(f"❌ Failed to list tools: {e}") + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + """Call a specific tool.""" + if not self.session: + print("❌ Not connected to server") + return + + try: + result = await self.session.call_tool(tool_name, arguments or {}) + print(f"\n🔧 Tool '{tool_name}' result:") + if hasattr(result, "content"): + for content in result.content: + if content.type == "text": + print(content.text) + else: + print(content) + else: + print(result) + except Exception as e: + print(f"❌ Failed to call tool '{tool_name}': {e}") + + async def interactive_loop(self): + """Run interactive command loop.""" + print("\n🎯 Interactive MCP Client") + print("Commands:") + print(" list - List available tools") + print(" call [args] - Call a tool") + print(" quit - Exit the client") + print() + + while True: + try: + command = input("mcp> ").strip() + + if not command: + continue + + if command == "quit": + break + + elif command == "list": + await self.list_tools() + + elif command.startswith("call "): + parts = command.split(maxsplit=2) + tool_name = parts[1] if len(parts) > 1 else "" + + if not tool_name: + print("❌ Please specify a tool name") + continue + + # Parse arguments (simple JSON-like format) + arguments = {} + if len(parts) > 2: + import json + + try: + arguments = json.loads(parts[2]) + except json.JSONDecodeError: + print("❌ Invalid arguments format (expected JSON)") + continue + + await self.call_tool(tool_name, arguments) + + else: + print( + "❌ Unknown command. Try 'list', 'call ', or 'quit'" + ) + + except KeyboardInterrupt: + print("\n\n👋 Goodbye!") + break + except EOFError: + break + + +async def main(): + """Main entry point.""" + # Default server URL - can be overridden with environment variable + # Most MCP streamable HTTP servers use /mcp as the endpoint + server_url = os.getenv("MCP_SERVER_PORT", 8000) + transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http") + server_url = ( + f"http://localhost:{server_url}/mcp" + if transport_type == "streamable_http" + else f"http://localhost:{server_url}/sse" + ) + + print("🚀 Simple MCP Auth Client") + print(f"Connecting to: {server_url}") + print(f"Transport type: {transport_type}") + + # Start connection flow - OAuth will be handled automatically + client = SimpleAuthClient(server_url, transport_type) + await client.connect() + + +def cli(): + """CLI entry point for uv script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + cli() diff --git a/mkdocs.yml b/mkdocs.yml index 4488318..a1a6877 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,13 @@ theme: nav: - Intro: index.md - Quickstart: quickstart.md + - Authentication: + - Overview: auth/overview.md + - Quick Start: auth/quickstart.md + - OAuth 2.0: auth/oauth.md + - API Key: auth/api-key.md + - Bearer Token: auth/bearer-token.md + - Custom Handlers: auth/custom-handlers.md - Guided Examples: - guide/smolagents.md - guide/crewai.md @@ -94,4 +101,4 @@ extra_css: - stylesheets/extra.css watch: - - "src/mcpadapt" \ No newline at end of file + - "src/mcpadapt" diff --git a/pyproject.toml b/pyproject.toml index 07a0934..fb42c4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ test = [ "pytest-asyncio>=0.25.2", "pytest>=8.3.4", "pytest-datadir>=1.7.2", + "pytest-cov>=4.0.0", + "pytest-mock>=3.10.0", + "coverage[toml]>=7.0.0", "mcpadapt[langchain]", "mcpadapt[smolagents]", "mcpadapt[crewai]", diff --git a/src/mcpadapt/auth/__init__.py b/src/mcpadapt/auth/__init__.py index d068f04..d1f783d 100644 --- a/src/mcpadapt/auth/__init__.py +++ b/src/mcpadapt/auth/__init__.py @@ -2,89 +2,19 @@ This module provides OAuth, API Key, and Bearer token authentication support for MCP servers. - -Example usage with OAuth: - -```python -from mcp.client.auth import OAuthClientProvider -from mcp.shared.auth import OAuthClientMetadata -from pydantic import HttpUrl - -from mcpadapt.auth import ( - InMemoryTokenStorage, - LocalBrowserOAuthHandler -) -from mcpadapt.core import MCPAdapt -from mcpadapt.smolagents_adapter import SmolAgentsAdapter - -# Create OAuth provider directly -client_metadata = OAuthClientMetadata( - client_name="My App", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - token_endpoint_auth_method="client_secret_post", -) - -oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) -token_storage = InMemoryTokenStorage() - -oauth_provider = OAuthClientProvider( - server_url="https://example.com", - client_metadata=client_metadata, - storage=token_storage, - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, -) - -# Use with MCPAdapt -with MCPAdapt( - serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, - adapter=SmolAgentsAdapter(), - auth_provider=oauth_provider, -) as tools: - print(f"Connected with {len(tools)} tools") -``` - -Example usage with API Key: - -```python -from mcpadapt.auth import ApiKeyAuthProvider -from mcpadapt.core import MCPAdapt -from mcpadapt.smolagents_adapter import SmolAgentsAdapter - -# Create API Key provider -api_key_provider = ApiKeyAuthProvider( - header_name="X-API-Key", - header_value="your-api-key-here" -) - -with MCPAdapt( - serverparams={"url": "https://example.com/mcp", "transport": "streamable-http"}, - adapter=SmolAgentsAdapter(), - auth_provider=api_key_provider, -) as tools: - print(f"Connected with {len(tools)} tools") -``` - -For custom implementations, extend BaseOAuthHandler: - -```python -from mcpadapt.auth import BaseOAuthHandler - -class CustomOAuthHandler(BaseOAuthHandler): - async def handle_redirect(self, authorization_url: str) -> None: - # Custom redirect logic (e.g., print URL for headless environments) - print(f"Please open: {authorization_url}") - - async def handle_callback(self) -> tuple[str, str | None]: - # Custom callback logic (e.g., manual code input) - auth_code = input("Enter authorization code: ") - return auth_code, None -``` """ from .oauth import InMemoryTokenStorage +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthToken, + InvalidScopeError, + OAuthClientMetadata, + InvalidRedirectUriError, + OAuthMetadata, + ProtectedResourceMetadata, +) +from mcp.client.auth import TokenStorage, OAuthClientProvider from .handlers import ( BaseOAuthHandler, LocalBrowserOAuthHandler, @@ -125,4 +55,15 @@ async def handle_callback(self) -> tuple[str, str | None]: "OAuthConfigurationError", "OAuthServerError", "OAuthCallbackError", + # Re-exported classes from mcp.client.auth + "TokenStorage", + "OAuthClientProvider", + # Re-exported classes from mcp.shared.auth + "OAuthClientInformationFull", + "OAuthToken", + "InvalidScopeError", + "OAuthClientMetadata", + "InvalidRedirectUriError", + "OAuthMetadata", + "ProtectedResourceMetadata", ] diff --git a/src/mcpadapt/auth/authenticate.py b/src/mcpadapt/auth/authenticate.py deleted file mode 100644 index 519694a..0000000 --- a/src/mcpadapt/auth/authenticate.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Authentication utilities for pre-authenticating providers.""" - -from typing import Any - -from mcp.client.auth import OAuthClientProvider - -from .providers import ApiKeyAuthProvider, BearerAuthProvider, create_auth_provider -from .models import AuthConfig, OAuthConfig - - -async def authenticate( - auth_config: AuthConfig, - server_url: str -) -> OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider: - """ - Create and prepare an auth provider for use with MCPAdapt. - - For OAuth: Creates a configured OAuth provider that will perform the OAuth flow - (browser redirect, callback, token exchange) when first used by MCPAdapt - For API Key/Bearer: Creates a ready-to-use provider (no additional flow needed) - - Args: - auth_config: Authentication configuration - server_url: Server URL (needed for OAuth server endpoint discovery) - - Returns: - Auth provider ready to use with MCPAdapt - - Example: - >>> # Prepare OAuth provider - >>> oauth_config = OAuthConfig(client_metadata={...}) - >>> auth_provider = await authenticate(oauth_config, "https://mcp.canva.com/mcp") - >>> - >>> # Use with MCPAdapt - OAuth flow will happen during connection - >>> with MCPAdapt(server_config, adapter, auth_provider=auth_provider) as tools: - >>> print(tools) - - Note: - For OAuth, the actual authentication flow (browser redirect, token exchange) - occurs when MCPAdapt makes its first connection to the MCP server. This function - prepares the OAuth provider with all necessary configuration. - """ - return await create_auth_provider(auth_config, server_url) diff --git a/src/mcpadapt/auth/exceptions.py b/src/mcpadapt/auth/exceptions.py index c06fad5..6593827 100644 --- a/src/mcpadapt/auth/exceptions.py +++ b/src/mcpadapt/auth/exceptions.py @@ -3,10 +3,12 @@ class OAuthError(Exception): """Base class for all OAuth authentication errors.""" - - def __init__(self, message: str, error_code: str | None = None, context: dict | None = None): + + def __init__( + self, message: str, error_code: str | None = None, context: dict | None = None + ): """Initialize OAuth error. - + Args: message: Human-readable error message error_code: Machine-readable error code (optional) @@ -19,7 +21,7 @@ def __init__(self, message: str, error_code: str | None = None, context: dict | class OAuthTimeoutError(OAuthError): """Raised when OAuth callback doesn't arrive within the specified timeout.""" - + def __init__(self, timeout_seconds: int, context: dict | None = None): message = ( f"OAuth authentication timed out after {timeout_seconds} seconds. " @@ -32,7 +34,7 @@ def __init__(self, timeout_seconds: int, context: dict | None = None): class OAuthCancellationError(OAuthError): """Raised when the user cancels or denies the OAuth authorization.""" - + def __init__(self, error_details: str | None = None, context: dict | None = None): base_message = "OAuth authorization was cancelled or denied by the user." if error_details: @@ -45,7 +47,7 @@ def __init__(self, error_details: str | None = None, context: dict | None = None class OAuthNetworkError(OAuthError): """Raised when network-related issues prevent OAuth completion.""" - + def __init__(self, original_error: Exception, context: dict | None = None): message = ( f"OAuth authentication failed due to network error: {str(original_error)}. " @@ -57,7 +59,7 @@ def __init__(self, original_error: Exception, context: dict | None = None): class OAuthConfigurationError(OAuthError): """Raised when OAuth configuration is invalid or incomplete.""" - + def __init__(self, config_issue: str, context: dict | None = None): message = f"OAuth configuration error: {config_issue}" super().__init__(message, "oauth_config_error", context) @@ -66,8 +68,13 @@ def __init__(self, config_issue: str, context: dict | None = None): class OAuthServerError(OAuthError): """Raised when the OAuth server returns an error response.""" - - def __init__(self, server_error: str, error_description: str | None = None, context: dict | None = None): + + def __init__( + self, + server_error: str, + error_description: str | None = None, + context: dict | None = None, + ): message = f"OAuth server error: {server_error}" if error_description: message += f" - {error_description}" @@ -78,7 +85,7 @@ def __init__(self, server_error: str, error_description: str | None = None, cont class OAuthCallbackError(OAuthError): """Raised when there's an issue with the OAuth callback handling.""" - + def __init__(self, callback_issue: str, context: dict | None = None): message = f"OAuth callback error: {callback_issue}" super().__init__(message, "oauth_callback_error", context) diff --git a/src/mcpadapt/auth/handlers.py b/src/mcpadapt/auth/handlers.py index 5de9173..1efd7e8 100644 --- a/src/mcpadapt/auth/handlers.py +++ b/src/mcpadapt/auth/handlers.py @@ -5,7 +5,6 @@ import webbrowser from abc import ABC, abstractmethod from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any from urllib.parse import parse_qs, urlparse from .exceptions import ( @@ -75,7 +74,7 @@ class LocalCallbackServer: def __init__(self, port: int = 3030): """Initialize callback server. - + Args: port: Port to listen on for OAuth callbacks """ @@ -96,30 +95,32 @@ def __init__(self, request, client_address, server): def start(self) -> None: """Start the callback server in a background thread. - + Raises: OAuthCallbackError: If server cannot be started """ try: handler_class = self._create_handler_with_data() self.server = HTTPServer(("localhost", self.port), handler_class) - self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread = threading.Thread( + target=self.server.serve_forever, daemon=True + ) self.thread.start() except OSError as e: if e.errno == 48: # Address already in use raise OAuthCallbackError( f"Port {self.port} is already in use. Try using a different port or check if another OAuth flow is running.", - context={"port": self.port, "original_error": str(e)} + context={"port": self.port, "original_error": str(e)}, ) else: raise OAuthCallbackError( f"Failed to start OAuth callback server on port {self.port}: {str(e)}", - context={"port": self.port, "original_error": str(e)} + context={"port": self.port, "original_error": str(e)}, ) except Exception as e: raise OAuthCallbackError( f"Unexpected error starting OAuth callback server: {str(e)}", - context={"port": self.port, "original_error": str(e)} + context={"port": self.port, "original_error": str(e)}, ) def stop(self) -> None: @@ -132,13 +133,13 @@ def stop(self) -> None: def wait_for_callback(self, timeout: int = 300) -> str: """Wait for OAuth callback with timeout. - + Args: timeout: Maximum time to wait for callback in seconds - + Returns: Authorization code from OAuth callback - + Raises: OAuthTimeoutError: If timeout occurs OAuthCancellationError: If user cancels authorization @@ -151,29 +152,20 @@ def wait_for_callback(self, timeout: int = 300) -> str: elif self.callback_data["error"]: error = self.callback_data["error"] context = {"port": self.port, "timeout": timeout} - + # Map common OAuth error codes to specific exceptions if error in ["access_denied", "user_cancelled"]: - raise OAuthCancellationError( - error_details=error, - context=context - ) + raise OAuthCancellationError(error_details=error, context=context) else: # Generic server error for other OAuth errors - raise OAuthServerError( - server_error=error, - context=context - ) + raise OAuthServerError(server_error=error, context=context) time.sleep(0.1) - - raise OAuthTimeoutError( - timeout_seconds=timeout, - context={"port": self.port} - ) + + raise OAuthTimeoutError(timeout_seconds=timeout, context={"port": self.port}) def get_state(self) -> str | None: """Get the received state parameter. - + Returns: OAuth state parameter or None """ @@ -182,28 +174,27 @@ def get_state(self) -> str | None: class BaseOAuthHandler(ABC): """Base class for OAuth authentication handlers. - - Combines redirect and callback handling into a single cohesive interface. + Subclasses should implement both the redirect flow (opening authorization URL) and callback flow (receiving authorization code). """ - + @abstractmethod async def handle_redirect(self, authorization_url: str) -> None: """Handle OAuth redirect to authorization URL. - + Args: authorization_url: The OAuth authorization URL to redirect the user to """ pass - + @abstractmethod async def handle_callback(self) -> tuple[str, str | None]: """Handle OAuth callback and return authorization code and state. - + Returns: Tuple of (authorization_code, state) received from OAuth provider - + Raises: Exception: If OAuth flow fails or times out """ @@ -212,15 +203,15 @@ async def handle_callback(self) -> tuple[str, str | None]: class LocalBrowserOAuthHandler(BaseOAuthHandler): """Default OAuth handler using local browser and callback server. - + Opens authorization URL in the user's default browser and starts a local HTTP server to receive the OAuth callback. This is the most user-friendly approach for desktop applications. """ - + def __init__(self, callback_port: int = 3030, timeout: int = 300): """Initialize the local browser OAuth handler. - + Args: callback_port: Port to run the local callback server on timeout: Maximum time to wait for OAuth callback in seconds @@ -228,41 +219,42 @@ def __init__(self, callback_port: int = 3030, timeout: int = 300): self.callback_port = callback_port self.timeout = timeout self.callback_server: LocalCallbackServer | None = None - + async def handle_redirect(self, authorization_url: str) -> None: """Open authorization URL in the user's default browser. - + Args: authorization_url: OAuth authorization URL to open - + Raises: OAuthNetworkError: If browser cannot be opened """ print(f"Opening OAuth authorization URL: {authorization_url}") - + try: success = webbrowser.open(authorization_url) if not success: - print("Failed to automatically open browser. Please manually open the URL above.") + print( + "Failed to automatically open browser. Please manually open the URL above." + ) raise OAuthNetworkError( Exception("Failed to open browser - no suitable browser found"), - context={"authorization_url": authorization_url} + context={"authorization_url": authorization_url}, ) except Exception as e: if isinstance(e, OAuthNetworkError): raise - print("Failed to automatically open browser. Please manually open the URL above.") - raise OAuthNetworkError( - e, - context={"authorization_url": authorization_url} + print( + "Failed to automatically open browser. Please manually open the URL above." ) - + raise OAuthNetworkError(e, context={"authorization_url": authorization_url}) + async def handle_callback(self) -> tuple[str, str | None]: """Start local server and wait for OAuth callback. - + Returns: Tuple of (authorization_code, state) from OAuth callback - + Raises: OAuthCallbackError: If callback server cannot be started OAuthTimeoutError: If callback doesn't arrive within timeout @@ -272,12 +264,17 @@ async def handle_callback(self) -> tuple[str, str | None]: try: self.callback_server = LocalCallbackServer(port=self.callback_port) self.callback_server.start() - + auth_code = self.callback_server.wait_for_callback(timeout=self.timeout) state = self.callback_server.get_state() return auth_code, state - - except (OAuthTimeoutError, OAuthCancellationError, OAuthServerError, OAuthCallbackError): + + except ( + OAuthTimeoutError, + OAuthCancellationError, + OAuthServerError, + OAuthCallbackError, + ): # Re-raise OAuth-specific exceptions as-is raise except Exception as e: @@ -287,8 +284,8 @@ async def handle_callback(self) -> tuple[str, str | None]: context={ "port": self.callback_port, "timeout": self.timeout, - "original_error": str(e) - } + "original_error": str(e), + }, ) finally: if self.callback_server: diff --git a/src/mcpadapt/auth/models.py b/src/mcpadapt/auth/models.py deleted file mode 100644 index bfb5277..0000000 --- a/src/mcpadapt/auth/models.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Authentication models for MCPAdapt - Legacy file kept for backwards compatibility.""" - -# This file previously contained OAuth/API Key/Bearer auth configuration classes -# but they have been removed in favor of direct auth provider usage. -# -# Users now create providers directly: -# - OAuthClientProvider (from MCP SDK) -# - ApiKeyAuthProvider/BearerAuthProvider (from mcpadapt.auth.providers) -# -# This file is kept empty to avoid breaking existing imports, -# but may be removed in a future version. diff --git a/src/mcpadapt/auth/oauth.py b/src/mcpadapt/auth/oauth.py index 6195b8b..7a2b20a 100644 --- a/src/mcpadapt/auth/oauth.py +++ b/src/mcpadapt/auth/oauth.py @@ -7,14 +7,19 @@ class InMemoryTokenStorage(TokenStorage): """Simple in-memory token storage implementation.""" - def __init__(self): - """Initialize empty token storage.""" + def __init__(self, client_info: OAuthClientInformationFull | None = None): + """Initialize token storage, optionally with pre-configured client credentials. + + Args: + client_info: Optional OAuth client information to pre-configure. + If provided, skips Dynamic Client Registration. + """ self._tokens: OAuthToken | None = None - self._client_info: OAuthClientInformationFull | None = None + self._client_info = client_info async def get_tokens(self) -> OAuthToken | None: """Get stored OAuth tokens. - + Returns: Stored OAuth tokens or None if not available """ @@ -22,7 +27,7 @@ async def get_tokens(self) -> OAuthToken | None: async def set_tokens(self, tokens: OAuthToken) -> None: """Store OAuth tokens. - + Args: tokens: OAuth tokens to store """ @@ -30,7 +35,7 @@ async def set_tokens(self, tokens: OAuthToken) -> None: async def get_client_info(self) -> OAuthClientInformationFull | None: """Get stored OAuth client information. - + Returns: Stored OAuth client information or None if not available """ @@ -38,7 +43,7 @@ async def get_client_info(self) -> OAuthClientInformationFull | None: async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: """Store OAuth client information. - + Args: client_info: OAuth client information to store """ diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py index 3c014e7..5c82a90 100644 --- a/src/mcpadapt/auth/providers.py +++ b/src/mcpadapt/auth/providers.py @@ -8,7 +8,7 @@ class ApiKeyAuthProvider: def __init__(self, header_name: str, header_value: str): """Initialize with API key configuration. - + Args: header_name: Name of the header to send the API key in header_value: The API key value @@ -18,7 +18,7 @@ def __init__(self, header_name: str, header_value: str): def get_headers(self) -> dict[str, str]: """Get authentication headers. - + Returns: Dictionary of headers to add to requests """ @@ -30,7 +30,7 @@ class BearerAuthProvider: def __init__(self, token: str): """Initialize with Bearer token configuration. - + Args: token: The bearer token """ @@ -38,7 +38,7 @@ def __init__(self, token: str): def get_headers(self) -> dict[str, str]: """Get authentication headers. - + Returns: Dictionary of headers to add to requests """ @@ -47,10 +47,10 @@ def get_headers(self) -> dict[str, str]: def get_auth_headers(auth_provider: Any) -> dict[str, str]: """Get authentication headers from provider. - + Args: auth_provider: Authentication provider instance - + Returns: Dictionary of headers to add to requests """ diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index d112233..d3246b5 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -20,7 +20,6 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from .auth.exceptions import OAuthError from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider, get_auth_headers @@ -106,7 +105,7 @@ async def mcptools( # Create a deep copy to avoid modifying the original dict client_params = copy.deepcopy(serverparams) transport = client_params.pop("transport", "sse") - + # Add authentication if provided if auth_provider is not None: if isinstance(auth_provider, OAuthClientProvider): @@ -118,7 +117,7 @@ async def mcptools( client_params["headers"].update(headers) else: client_params["headers"] = headers - + if transport == "sse": client = sse_client(**client_params) elif transport == "streamable-http": @@ -199,7 +198,11 @@ def __init__( adapter: ToolAdapter, connect_timeout: int = 30, client_session_timeout_seconds: float | timedelta | None = 5, - auth_provider: OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | list[OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | None] | None = None, + auth_provider: OAuthClientProvider + | ApiKeyAuthProvider + | BearerAuthProvider + | list[OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | None] + | None = None, ): """ Manage the MCP server / client lifecycle and expose tools adapted with the adapter. @@ -259,9 +262,13 @@ async def setup(): async with AsyncExitStack() as stack: connections = [ await stack.enter_async_context( - mcptools(params, self.client_session_timeout_seconds, auth_provider) + mcptools( + params, self.client_session_timeout_seconds, auth_provider + ) + ) + for params, auth_provider in zip( + self.serverparams, self.auth_providers ) - for params, auth_provider in zip(self.serverparams, self.auth_providers) ] self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] self.ready.set() # Signal initialization is complete diff --git a/tests/auth/conftest.py b/tests/auth/conftest.py new file mode 100644 index 0000000..694f341 --- /dev/null +++ b/tests/auth/conftest.py @@ -0,0 +1,115 @@ +"""Shared fixtures for authentication tests.""" + +import pytest +from unittest.mock import Mock, AsyncMock +from mcp.shared.auth import OAuthClientMetadata, OAuthToken, OAuthClientInformationFull +from pydantic import HttpUrl + + +@pytest.fixture +def mock_oauth_client_metadata(): + """Mock OAuth client metadata.""" + return OAuthClientMetadata( + client_name="Test App", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + +@pytest.fixture +def mock_oauth_token(): + """Mock OAuth token.""" + return OAuthToken( + access_token="test_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="test_refresh_token", + scope="read write", + ) + + +@pytest.fixture +def mock_oauth_client_info(): + """Mock OAuth client information.""" + return OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + client_id_issued_at=1234567890, + client_secret_expires_at=0, + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + ) + + +@pytest.fixture +def mock_webbrowser(mocker): + """Mock webbrowser module.""" + return mocker.patch("mcpadapt.auth.handlers.webbrowser") + + +@pytest.fixture +def mock_http_server(mocker): + """Mock HTTPServer.""" + mock_server = Mock() + mock_server.serve_forever = Mock() + mock_server.shutdown = Mock() + mock_server.server_close = Mock() + + mock_server_class = mocker.patch("mcpadapt.auth.handlers.HTTPServer") + mock_server_class.return_value = mock_server + + return mock_server, mock_server_class + + +@pytest.fixture +def mock_threading(mocker): + """Mock threading module.""" + mock_thread = Mock() + mock_thread.start = Mock() + mock_thread.join = Mock() + + mock_thread_class = mocker.patch("mcpadapt.auth.handlers.threading.Thread") + mock_thread_class.return_value = mock_thread + + return mock_thread, mock_thread_class + + +@pytest.fixture +def mock_time(mocker): + """Mock time module.""" + return mocker.patch("mcpadapt.auth.handlers.time") + + +@pytest.fixture +def callback_data_success(): + """Mock successful callback data.""" + return { + "authorization_code": "test_auth_code", + "state": "test_state", + "error": None, + } + + +@pytest.fixture +def callback_data_error(): + """Mock error callback data.""" + return {"authorization_code": None, "state": None, "error": "access_denied"} + + +@pytest.fixture +def mock_mcp_client_session(): + """Mock MCP ClientSession.""" + session = AsyncMock() + session.initialize = AsyncMock() + session.list_tools = AsyncMock() + session.call_tool = AsyncMock() + return session + + +@pytest.fixture +def mock_oauth_client_provider(): + """Mock OAuthClientProvider.""" + provider = Mock() + provider.get_headers = Mock(return_value={"Authorization": "Bearer test_token"}) + return provider diff --git a/tests/auth/test_core_auth.py b/tests/auth/test_core_auth.py new file mode 100644 index 0000000..accc5bf --- /dev/null +++ b/tests/auth/test_core_auth.py @@ -0,0 +1,574 @@ +"""Tests for authentication integration with MCPAdapt core.""" + +import pytest +from typing import Any, Callable, Coroutine +from unittest.mock import Mock, patch, AsyncMock + +from mcpadapt.core import MCPAdapt, mcptools, ToolAdapter +from mcpadapt.auth.providers import ApiKeyAuthProvider, BearerAuthProvider + + +class DummyAdapter(ToolAdapter): + """Dummy adapter for testing.""" + + def adapt(self, func: Callable[[dict[str, Any] | None], Any], mcp_tool: Any) -> Any: + return func + + def async_adapt( + self, + afunc: Callable[[dict[str, Any] | None], Coroutine[Any, Any, Any]], + mcp_tool: Any, + ) -> Any: + return afunc + + +class TestMCPAdaptAuthIntegration: + """Test authentication integration with MCPAdapt.""" + + def test_api_key_auth_provider_sync(self): + """Test API key authentication with sync MCPAdapt.""" + auth_provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") + + # Mock server parameters for streamable-http + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + # Mock the mcptools function to verify auth headers are passed + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + # Verify auth provider is stored + assert adapter.auth_providers[0] is auth_provider + + def test_bearer_auth_provider_sync(self): + """Test Bearer token authentication with sync MCPAdapt.""" + auth_provider = BearerAuthProvider("bearer-token-456") + + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + assert adapter.auth_providers[0] is auth_provider + + def test_multiple_servers_different_auth(self): + """Test multiple servers with different authentication.""" + api_key_provider = ApiKeyAuthProvider("X-API-Key", "api-key") + bearer_provider = BearerAuthProvider("bearer-token") + + server_params = [ + {"url": "https://api1.example.com/mcp", "transport": "streamable-http"}, + {"url": "https://api2.example.com/mcp", "transport": "streamable-http"}, + ] + + auth_providers = [api_key_provider, bearer_provider] + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_providers, + ) + + assert len(adapter.auth_providers) == 2 + assert adapter.auth_providers[0] is api_key_provider + assert adapter.auth_providers[1] is bearer_provider + + def test_auth_provider_list_length_mismatch(self): + """Test auth provider list length mismatch raises error.""" + server_params = [ + {"url": "https://api1.example.com/mcp", "transport": "streamable-http"}, + {"url": "https://api2.example.com/mcp", "transport": "streamable-http"}, + ] + + # Only one auth provider for two servers + auth_providers = [ApiKeyAuthProvider("X-API-Key", "key")] + + with pytest.raises(ValueError) as exc_info: + MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_providers, + ) + + assert ( + "auth_provider list length (1) must match serverparams length (2)" + in str(exc_info.value) + ) + + def test_single_auth_provider_for_multiple_servers(self): + """Test single auth provider applied to multiple servers.""" + auth_provider = ApiKeyAuthProvider("X-API-Key", "shared-key") + + server_params = [ + {"url": "https://api1.example.com/mcp", "transport": "streamable-http"}, + {"url": "https://api2.example.com/mcp", "transport": "streamable-http"}, + ] + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + # Should replicate auth provider for each server + assert len(adapter.auth_providers) == 2 + assert adapter.auth_providers[0] is auth_provider + assert adapter.auth_providers[1] is auth_provider + + def test_no_auth_provider(self): + """Test MCPAdapt without authentication.""" + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt(serverparams=server_params, adapter=DummyAdapter()) + + # Should have None for auth providers + assert len(adapter.auth_providers) == 1 + assert adapter.auth_providers[0] is None + + +class TestMCPToolsAuthIntegration: + """Test authentication integration with mcptools function.""" + + @pytest.mark.asyncio + async def test_mcptools_with_api_key_auth(self): + """Test mcptools with API key authentication.""" + auth_provider = ApiKeyAuthProvider("X-API-Key", "test-key") + + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + async with mcptools(server_params, auth_provider=auth_provider) as ( + session, + tools, + ): + assert session is mock_session + assert tools == [] + + # Verify auth provider was used in client setup + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert "headers" in call_kwargs + assert call_kwargs["headers"]["X-API-Key"] == "test-key" + + @pytest.mark.asyncio + async def test_mcptools_with_bearer_auth(self): + """Test mcptools with Bearer token authentication.""" + auth_provider = BearerAuthProvider("test-token") + + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + async with mcptools(server_params, auth_provider=auth_provider) as ( + session, + tools, + ): + assert session is mock_session + assert tools == [] + + # Verify auth provider was used + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert "headers" in call_kwargs + assert ( + call_kwargs["headers"]["Authorization"] == "Bearer test-token" + ) + + @pytest.mark.asyncio + async def test_mcptools_oauth_provider(self, mock_oauth_client_provider): + """Test mcptools with OAuth provider.""" + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + async with mcptools( + server_params, auth_provider=mock_oauth_client_provider + ) as (session, tools): + assert session is mock_session + assert tools == [] + + # Verify OAuth provider was passed to auth parameter + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert call_kwargs["auth"] is mock_oauth_client_provider + + @pytest.mark.asyncio + async def test_mcptools_sse_transport_with_auth(self): + """Test mcptools with SSE transport and authentication.""" + auth_provider = ApiKeyAuthProvider("X-API-Key", "sse-key") + + server_params = {"url": "https://example.com/sse", "transport": "sse"} + + with patch("mcpadapt.core.sse_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + async with mcptools(server_params, auth_provider=auth_provider) as ( + session, + tools, + ): + assert session is mock_session + + # Verify SSE client was used with auth headers + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert "headers" in call_kwargs + assert call_kwargs["headers"]["X-API-Key"] == "sse-key" + + +class TestAuthErrorHandling: + """Test error handling in authentication scenarios.""" + + def test_invalid_transport_with_auth(self): + """Test invalid transport parameter with authentication.""" + # Note: MCPAdapt doesn't validate transport until connection time + # So this test should verify the auth provider is stored correctly even with invalid transport + auth_provider = ApiKeyAuthProvider("X-API-Key", "key") + + server_params = { + "url": "https://example.com/mcp", + "transport": "invalid-transport", + } + + # MCPAdapt should accept any transport type - validation happens later + with patch("mcpadapt.core.streamablehttp_client"): + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + # Verify auth provider is stored correctly + assert adapter.auth_providers[0] is auth_provider + + @pytest.mark.asyncio + async def test_auth_provider_header_merge(self): + """Test that auth headers are properly merged with existing headers.""" + auth_provider = ApiKeyAuthProvider("X-API-Key", "merge-test") + + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + "headers": {"User-Agent": "TestAgent", "Accept": "application/json"}, + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + async with mcptools(server_params, auth_provider=auth_provider) as ( + session, + tools, + ): + # Verify headers were merged correctly + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + + expected_headers = { + "User-Agent": "TestAgent", + "Accept": "application/json", + "X-API-Key": "merge-test", + } + assert call_kwargs["headers"] == expected_headers + + +class TestAuthProviderEdgeCases: + """Test edge cases with auth providers.""" + + def test_mixed_auth_provider_list(self): + """Test mixed list of auth providers and None values.""" + api_provider = ApiKeyAuthProvider("X-API-Key", "key1") + bearer_provider = BearerAuthProvider("token2") + + server_params = [ + {"url": "https://api1.example.com/mcp", "transport": "streamable-http"}, + {"url": "https://api2.example.com/mcp", "transport": "streamable-http"}, + {"url": "https://api3.example.com/mcp", "transport": "streamable-http"}, + ] + + # Mixed auth providers with None + auth_providers = [api_provider, None, bearer_provider] + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider=auth_providers, + ) + + assert len(adapter.auth_providers) == 3 + assert adapter.auth_providers[0] is api_provider + assert adapter.auth_providers[1] is None + assert adapter.auth_providers[2] is bearer_provider + + def test_auth_with_stdio_server(self): + """Test that auth providers work with stdio servers.""" + from mcp import StdioServerParameters + + auth_provider = ApiKeyAuthProvider("X-API-Key", "stdio-key") + + # For stdio servers, auth should be stored but not used + with patch("mcpadapt.core.stdio_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + adapter = MCPAdapt( + serverparams=StdioServerParameters(command="echo", args=["test"]), + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + # Auth provider should be stored even for stdio + assert adapter.auth_providers[0] is auth_provider + + +class TestAuthenticationFlowIntegration: + """Test complete authentication flows with MCPAdapt.""" + + @pytest.mark.asyncio + async def test_complete_oauth_flow_mock(self): + """Test complete OAuth flow integration (mocked).""" + # This would be used with a real OAuth provider in practice + mock_oauth_provider = Mock() + mock_oauth_provider.get_headers = Mock( + return_value={"Authorization": "Bearer oauth_token"} + ) + + server_params = { + "url": "https://oauth.example.com/mcp", + "transport": "streamable-http", + } + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_session = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + with patch("mcpadapt.core.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + return_value=mock_session + ) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock() + mock_session.list_tools.return_value.tools = [] + + # Use OAuth provider that looks like an OAuthClientProvider + async with mcptools( + server_params, auth_provider=mock_oauth_provider + ) as (session, tools): + # Since it's not a real OAuthClientProvider, it should be treated as unknown + # and passed to the auth parameter + mock_client.assert_called_once() + call_kwargs = mock_client.call_args[1] + assert call_kwargs["auth"] is mock_oauth_provider + + def test_auth_provider_parameter_validation(self): + """Test auth provider parameter validation.""" + server_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + } + + # Test with invalid auth provider type + with patch("mcpadapt.core.streamablehttp_client"): + # String should not cause error, just be ignored + adapter = MCPAdapt( + serverparams=server_params, + adapter=DummyAdapter(), + auth_provider="invalid", + ) + assert adapter.auth_providers[0] == "invalid" + + def test_auth_deep_copy_serverparams(self): + """Test that serverparams are deep copied when adding auth.""" + original_params = { + "url": "https://example.com/mcp", + "transport": "streamable-http", + "headers": {"Original": "Header"}, + } + + auth_provider = ApiKeyAuthProvider("X-API-Key", "test-key") + + with patch("mcpadapt.core.streamablehttp_client") as mock_client: + mock_client.return_value.__aenter__ = AsyncMock( + return_value=(Mock(), Mock()) + ) + mock_client.return_value.__aexit__ = AsyncMock() + + MCPAdapt( + serverparams=original_params, + adapter=DummyAdapter(), + auth_provider=auth_provider, + ) + + # Original params should be unchanged + assert "X-API-Key" not in original_params.get("headers", {}) + assert original_params["headers"] == {"Original": "Header"} diff --git a/tests/auth/test_exceptions.py b/tests/auth/test_exceptions.py new file mode 100644 index 0000000..f3a6411 --- /dev/null +++ b/tests/auth/test_exceptions.py @@ -0,0 +1,295 @@ +"""Tests for authentication exception classes.""" + +from mcpadapt.auth.exceptions import ( + OAuthError, + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthConfigurationError, + OAuthServerError, + OAuthCallbackError, +) + + +class TestOAuthError: + """Test base OAuth error class.""" + + def test_basic_initialization(self): + """Test basic error initialization.""" + error = OAuthError("Test message") + assert str(error) == "Test message" + assert error.error_code is None + assert error.context == {} + + def test_initialization_with_code(self): + """Test error initialization with error code.""" + error = OAuthError("Test message", "test_code") + assert str(error) == "Test message" + assert error.error_code == "test_code" + assert error.context == {} + + def test_initialization_with_context(self): + """Test error initialization with context.""" + context = {"key": "value", "number": 42} + error = OAuthError("Test message", "test_code", context) + assert str(error) == "Test message" + assert error.error_code == "test_code" + assert error.context == context + + def test_inheritance(self): + """Test that OAuthError inherits from Exception.""" + error = OAuthError("Test message") + assert isinstance(error, Exception) + + +class TestOAuthTimeoutError: + """Test OAuth timeout error class.""" + + def test_basic_initialization(self): + """Test basic timeout error initialization.""" + error = OAuthTimeoutError(300) + expected_message = ( + "OAuth authentication timed out after 300 seconds. " + "The user may have closed the browser window or the OAuth server may be unreachable. " + "Try refreshing the browser or check your network connection." + ) + assert str(error) == expected_message + assert error.error_code == "oauth_timeout" + assert error.timeout_seconds == 300 + assert error.context == {} + + def test_initialization_with_context(self): + """Test timeout error initialization with context.""" + context = {"port": 3030} + error = OAuthTimeoutError(120, context) + expected_message = ( + "OAuth authentication timed out after 120 seconds. " + "The user may have closed the browser window or the OAuth server may be unreachable. " + "Try refreshing the browser or check your network connection." + ) + assert str(error) == expected_message + assert error.error_code == "oauth_timeout" + assert error.timeout_seconds == 120 + assert error.context == context + + def test_inheritance(self): + """Test that OAuthTimeoutError inherits from OAuthError.""" + error = OAuthTimeoutError(300) + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestOAuthCancellationError: + """Test OAuth cancellation error class.""" + + def test_basic_initialization(self): + """Test basic cancellation error initialization.""" + error = OAuthCancellationError() + expected_message = "OAuth authorization was cancelled or denied by the user." + assert str(error) == expected_message + assert error.error_code == "oauth_cancelled" + assert error.error_details is None + assert error.context == {} + + def test_initialization_with_details(self): + """Test cancellation error initialization with details.""" + error = OAuthCancellationError("access_denied") + expected_message = "OAuth authorization was cancelled or denied by the user. Details: access_denied" + assert str(error) == expected_message + assert error.error_code == "oauth_cancelled" + assert error.error_details == "access_denied" + assert error.context == {} + + def test_initialization_with_context(self): + """Test cancellation error initialization with context.""" + context = {"port": 3030} + error = OAuthCancellationError("user_cancelled", context) + expected_message = "OAuth authorization was cancelled or denied by the user. Details: user_cancelled" + assert str(error) == expected_message + assert error.error_code == "oauth_cancelled" + assert error.error_details == "user_cancelled" + assert error.context == context + + def test_inheritance(self): + """Test that OAuthCancellationError inherits from OAuthError.""" + error = OAuthCancellationError() + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestOAuthNetworkError: + """Test OAuth network error class.""" + + def test_basic_initialization(self): + """Test basic network error initialization.""" + original_error = ConnectionError("Network unreachable") + error = OAuthNetworkError(original_error) + expected_message = ( + "OAuth authentication failed due to network error: Network unreachable. " + "Check your internet connection and try again." + ) + assert str(error) == expected_message + assert error.error_code == "oauth_network_error" + assert error.original_error is original_error + assert error.context == {} + + def test_initialization_with_context(self): + """Test network error initialization with context.""" + original_error = TimeoutError("Connection timed out") + context = {"host": "example.com", "port": 443} + error = OAuthNetworkError(original_error, context) + expected_message = ( + "OAuth authentication failed due to network error: Connection timed out. " + "Check your internet connection and try again." + ) + assert str(error) == expected_message + assert error.error_code == "oauth_network_error" + assert error.original_error is original_error + assert error.context == context + + def test_inheritance(self): + """Test that OAuthNetworkError inherits from OAuthError.""" + original_error = Exception("Test error") + error = OAuthNetworkError(original_error) + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestOAuthConfigurationError: + """Test OAuth configuration error class.""" + + def test_basic_initialization(self): + """Test basic configuration error initialization.""" + error = OAuthConfigurationError("Invalid client ID") + expected_message = "OAuth configuration error: Invalid client ID" + assert str(error) == expected_message + assert error.error_code == "oauth_config_error" + assert error.config_issue == "Invalid client ID" + assert error.context == {} + + def test_initialization_with_context(self): + """Test configuration error initialization with context.""" + context = {"client_id": "invalid_id"} + error = OAuthConfigurationError("Missing redirect URI", context) + expected_message = "OAuth configuration error: Missing redirect URI" + assert str(error) == expected_message + assert error.error_code == "oauth_config_error" + assert error.config_issue == "Missing redirect URI" + assert error.context == context + + def test_inheritance(self): + """Test that OAuthConfigurationError inherits from OAuthError.""" + error = OAuthConfigurationError("Test config issue") + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestOAuthServerError: + """Test OAuth server error class.""" + + def test_basic_initialization(self): + """Test basic server error initialization.""" + error = OAuthServerError("invalid_request") + expected_message = "OAuth server error: invalid_request" + assert str(error) == expected_message + assert error.error_code == "oauth_server_error" + assert error.server_error == "invalid_request" + assert error.error_description is None + assert error.context == {} + + def test_initialization_with_description(self): + """Test server error initialization with description.""" + error = OAuthServerError( + "invalid_grant", "The provided authorization grant is invalid" + ) + expected_message = "OAuth server error: invalid_grant - The provided authorization grant is invalid" + assert str(error) == expected_message + assert error.error_code == "oauth_server_error" + assert error.server_error == "invalid_grant" + assert error.error_description == "The provided authorization grant is invalid" + assert error.context == {} + + def test_initialization_with_context(self): + """Test server error initialization with context.""" + context = {"endpoint": "/oauth/token"} + error = OAuthServerError("server_error", "Internal server error", context) + expected_message = "OAuth server error: server_error - Internal server error" + assert str(error) == expected_message + assert error.error_code == "oauth_server_error" + assert error.server_error == "server_error" + assert error.error_description == "Internal server error" + assert error.context == context + + def test_inheritance(self): + """Test that OAuthServerError inherits from OAuthError.""" + error = OAuthServerError("test_error") + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestOAuthCallbackError: + """Test OAuth callback error class.""" + + def test_basic_initialization(self): + """Test basic callback error initialization.""" + error = OAuthCallbackError("Port already in use") + expected_message = "OAuth callback error: Port already in use" + assert str(error) == expected_message + assert error.error_code == "oauth_callback_error" + assert error.callback_issue == "Port already in use" + assert error.context == {} + + def test_initialization_with_context(self): + """Test callback error initialization with context.""" + context = {"port": 3030, "pid": 12345} + error = OAuthCallbackError("Failed to start server", context) + expected_message = "OAuth callback error: Failed to start server" + assert str(error) == expected_message + assert error.error_code == "oauth_callback_error" + assert error.callback_issue == "Failed to start server" + assert error.context == context + + def test_inheritance(self): + """Test that OAuthCallbackError inherits from OAuthError.""" + error = OAuthCallbackError("Test callback issue") + assert isinstance(error, OAuthError) + assert isinstance(error, Exception) + + +class TestExceptionContext: + """Test exception context handling across all error types.""" + + def test_context_preservation(self): + """Test that context is preserved across different error types.""" + context = {"request_id": "req_123", "timestamp": 1234567890} + + errors = [ + OAuthError("Test", "code", context), + OAuthTimeoutError(300, context), + OAuthCancellationError("details", context), + OAuthNetworkError(Exception("test"), context), + OAuthConfigurationError("issue", context), + OAuthServerError("error", "desc", context), + OAuthCallbackError("issue", context), + ] + + for error in errors: + assert error.context == context + assert error.context["request_id"] == "req_123" + assert error.context["timestamp"] == 1234567890 + + def test_empty_context_default(self): + """Test that context defaults to empty dict when not provided.""" + errors = [ + OAuthError("Test"), + OAuthTimeoutError(300), + OAuthCancellationError(), + OAuthNetworkError(Exception("test")), + OAuthConfigurationError("issue"), + OAuthServerError("error"), + OAuthCallbackError("issue"), + ] + + for error in errors: + assert error.context == {} + assert isinstance(error.context, dict) diff --git a/tests/auth/test_handlers.py b/tests/auth/test_handlers.py new file mode 100644 index 0000000..6345096 --- /dev/null +++ b/tests/auth/test_handlers.py @@ -0,0 +1,623 @@ +"""Tests for OAuth handler classes.""" + +import pytest +from unittest.mock import Mock, patch +from mcpadapt.auth.handlers import ( + CallbackHandler, + LocalCallbackServer, + LocalBrowserOAuthHandler, +) +from mcpadapt.auth.exceptions import ( + OAuthTimeoutError, + OAuthCancellationError, + OAuthNetworkError, + OAuthCallbackError, + OAuthServerError, +) + + +class TestCallbackHandler: + """Test OAuth callback HTTP handler.""" + + def test_callback_handler_success(self): + """Test successful OAuth callback handling.""" + callback_data = {"authorization_code": None, "state": None, "error": None} + + # Mock request and server components + mock_request = Mock() + mock_client_address = ("127.0.0.1", 12345) + mock_server = Mock() + + # Create handler with callback data + with patch( + "mcpadapt.auth.handlers.BaseHTTPRequestHandler.__init__" + ) as mock_init: + mock_init.return_value = None + handler = CallbackHandler( + mock_request, mock_client_address, mock_server, callback_data + ) + + # Mock the required attributes + handler.path = "/callback?code=test_auth_code&state=test_state" + handler.send_response = Mock() + handler.send_header = Mock() + handler.end_headers = Mock() + handler.wfile = Mock() + handler.wfile.write = Mock() + + # Call do_GET + handler.do_GET() + + # Verify callback data was set + assert callback_data["authorization_code"] == "test_auth_code" + assert callback_data["state"] == "test_state" + assert callback_data["error"] is None + + # Verify HTTP response + handler.send_response.assert_called_once_with(200) + handler.send_header.assert_called_with("Content-type", "text/html") + handler.end_headers.assert_called_once() + handler.wfile.write.assert_called_once() + + def test_callback_handler_error(self): + """Test OAuth callback handling with error.""" + callback_data = {"authorization_code": None, "state": None, "error": None} + + # Mock request and server components + mock_request = Mock() + mock_client_address = ("127.0.0.1", 12345) + mock_server = Mock() + + with patch( + "mcpadapt.auth.handlers.BaseHTTPRequestHandler.__init__" + ) as mock_init: + mock_init.return_value = None + handler = CallbackHandler( + mock_request, mock_client_address, mock_server, callback_data + ) + + # Mock the required attributes + handler.path = "/callback?error=access_denied&error_description=user_denied" + handler.send_response = Mock() + handler.send_header = Mock() + handler.end_headers = Mock() + handler.wfile = Mock() + handler.wfile.write = Mock() + + # Call do_GET + handler.do_GET() + + # Verify callback data was set + assert callback_data["authorization_code"] is None + assert callback_data["state"] is None + assert callback_data["error"] == "access_denied" + + # Verify HTTP error response + handler.send_response.assert_called_once_with(400) + + def test_callback_handler_no_params(self): + """Test callback handler with no parameters.""" + callback_data = {"authorization_code": None, "state": None, "error": None} + + mock_request = Mock() + mock_client_address = ("127.0.0.1", 12345) + mock_server = Mock() + + with patch( + "mcpadapt.auth.handlers.BaseHTTPRequestHandler.__init__" + ) as mock_init: + mock_init.return_value = None + handler = CallbackHandler( + mock_request, mock_client_address, mock_server, callback_data + ) + + handler.path = "/callback" + handler.send_response = Mock() + handler.end_headers = Mock() + + handler.do_GET() + + # Verify no data was set + assert callback_data["authorization_code"] is None + assert callback_data["state"] is None + assert callback_data["error"] is None + + # Verify 404 response + handler.send_response.assert_called_once_with(404) + + def test_callback_handler_log_message_suppressed(self): + """Test that log messages are suppressed.""" + callback_data = {"authorization_code": None, "state": None, "error": None} + + mock_request = Mock() + mock_client_address = ("127.0.0.1", 12345) + mock_server = Mock() + + with patch( + "mcpadapt.auth.handlers.BaseHTTPRequestHandler.__init__" + ) as mock_init: + mock_init.return_value = None + handler = CallbackHandler( + mock_request, mock_client_address, mock_server, callback_data + ) + + # log_message should do nothing (return None) + result = handler.log_message("test %s", "message") + assert result is None + + +class TestLocalCallbackServer: + """Test local OAuth callback server.""" + + def test_initialization(self): + """Test server initialization.""" + server = LocalCallbackServer(port=3030) + assert server.port == 3030 + assert server.server is None + assert server.thread is None + assert server.callback_data == { + "authorization_code": None, + "state": None, + "error": None, + } + + def test_initialization_default_port(self): + """Test server initialization with default port.""" + server = LocalCallbackServer() + assert server.port == 3030 + + def test_start_success(self, mock_http_server, mock_threading): + """Test successful server start.""" + mock_server, mock_server_class = mock_http_server + mock_thread, mock_thread_class = mock_threading + + server = LocalCallbackServer(port=3030) + server.start() + + # Verify HTTPServer creation + mock_server_class.assert_called_once() + args, kwargs = mock_server_class.call_args + assert args[0] == ("localhost", 3030) + + # Verify thread creation and start + mock_thread_class.assert_called_once() + mock_thread.start.assert_called_once() + + assert server.server is mock_server + assert server.thread is mock_thread + + def test_start_port_in_use(self, mock_http_server): + """Test server start with port already in use.""" + mock_server, mock_server_class = mock_http_server + + # Configure mock to raise "Address already in use" error + os_error = OSError() + os_error.errno = 48 # EADDRINUSE + mock_server_class.side_effect = os_error + + server = LocalCallbackServer(port=3030) + + with pytest.raises(OAuthCallbackError) as exc_info: + server.start() + + assert "Port 3030 is already in use" in str(exc_info.value) + assert exc_info.value.context["port"] == 3030 + + def test_start_other_os_error(self, mock_http_server): + """Test server start with other OS error.""" + mock_server, mock_server_class = mock_http_server + + # Configure mock to raise different OS error + os_error = OSError("Permission denied") + os_error.errno = 13 # EACCES + mock_server_class.side_effect = os_error + + server = LocalCallbackServer(port=3030) + + with pytest.raises(OAuthCallbackError) as exc_info: + server.start() + + assert "Failed to start OAuth callback server" in str(exc_info.value) + assert "Permission denied" in str(exc_info.value) + + def test_start_unexpected_error(self, mock_http_server): + """Test server start with unexpected error.""" + mock_server, mock_server_class = mock_http_server + + mock_server_class.side_effect = ValueError("Unexpected error") + + server = LocalCallbackServer(port=3030) + + with pytest.raises(OAuthCallbackError) as exc_info: + server.start() + + assert "Unexpected error starting OAuth callback server" in str(exc_info.value) + + def test_stop(self, mock_http_server, mock_threading): + """Test server stop.""" + mock_server, mock_server_class = mock_http_server + mock_thread, mock_thread_class = mock_threading + + server = LocalCallbackServer(port=3030) + server.start() + + # Stop the server + server.stop() + + # Verify shutdown calls + mock_server.shutdown.assert_called_once() + mock_server.server_close.assert_called_once() + mock_thread.join.assert_called_once_with(timeout=1) + + def test_stop_no_server(self): + """Test stop when no server is running.""" + server = LocalCallbackServer(port=3030) + + # Should not raise an error + server.stop() + + def test_wait_for_callback_success(self, mock_time): + """Test successful callback waiting.""" + # Configure mock time to return increasing values + mock_time.time.side_effect = [0, 1, 2, 3] # Simulate time progression + + server = LocalCallbackServer(port=3030) + # Type ignore needed for test - callback_data values can change during runtime + server.callback_data["authorization_code"] = "test_code" # type: ignore[assignment] + + result = server.wait_for_callback(timeout=300) + + assert result == "test_code" + + def test_wait_for_callback_timeout(self, mock_time): + """Test callback waiting timeout.""" + # Configure mock time to simulate timeout + mock_time.time.side_effect = [0, 100, 200, 300, 400] # Exceed timeout + + server = LocalCallbackServer(port=3030) + + with pytest.raises(OAuthTimeoutError) as exc_info: + server.wait_for_callback(timeout=300) + + assert exc_info.value.timeout_seconds == 300 + assert exc_info.value.context["port"] == 3030 + + def test_wait_for_callback_access_denied(self, mock_time): + """Test callback waiting with access denied error.""" + mock_time.time.side_effect = [0, 1] + + server = LocalCallbackServer(port=3030) + server.callback_data["error"] = "access_denied" # type: ignore[assignment] + + with pytest.raises(OAuthCancellationError) as exc_info: + server.wait_for_callback(timeout=300) + + assert exc_info.value.error_details == "access_denied" + + def test_wait_for_callback_server_error(self, mock_time): + """Test callback waiting with server error.""" + mock_time.time.side_effect = [0, 1] + + server = LocalCallbackServer(port=3030) + server.callback_data["error"] = "invalid_request" # type: ignore[assignment] + + with pytest.raises(OAuthServerError) as exc_info: + server.wait_for_callback(timeout=300) + + assert exc_info.value.server_error == "invalid_request" + + def test_get_state(self): + """Test getting OAuth state parameter.""" + server = LocalCallbackServer(port=3030) + + # Initially None + assert server.get_state() is None + + # Set state + server.callback_data["state"] = "test_state" # type: ignore[assignment] + assert server.get_state() == "test_state" + + +class TestLocalBrowserOAuthHandler: + """Test local browser OAuth handler.""" + + def test_initialization(self): + """Test handler initialization.""" + handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + assert handler.callback_port == 3030 + assert handler.timeout == 300 + assert handler.callback_server is None + + def test_initialization_defaults(self): + """Test handler initialization with defaults.""" + handler = LocalBrowserOAuthHandler() + assert handler.callback_port == 3030 + assert handler.timeout == 300 + + @pytest.mark.asyncio + async def test_handle_redirect_success(self, mock_webbrowser): + """Test successful redirect handling.""" + mock_webbrowser.open.return_value = True + + handler = LocalBrowserOAuthHandler() + + # Should not raise + await handler.handle_redirect("https://example.com/oauth/authorize") + + mock_webbrowser.open.assert_called_once_with( + "https://example.com/oauth/authorize" + ) + + @pytest.mark.asyncio + async def test_handle_redirect_browser_fail(self, mock_webbrowser): + """Test redirect handling when browser fails to open.""" + mock_webbrowser.open.return_value = False + + handler = LocalBrowserOAuthHandler() + + with pytest.raises(OAuthNetworkError) as exc_info: + await handler.handle_redirect("https://example.com/oauth/authorize") + + assert "Failed to open browser" in str(exc_info.value.original_error) + + @pytest.mark.asyncio + async def test_handle_redirect_browser_exception(self, mock_webbrowser): + """Test redirect handling when browser raises exception.""" + mock_webbrowser.open.side_effect = Exception("Browser error") + + handler = LocalBrowserOAuthHandler() + + with pytest.raises(OAuthNetworkError) as exc_info: + await handler.handle_redirect("https://example.com/oauth/authorize") + + assert "Browser error" in str(exc_info.value.original_error) + + @pytest.mark.asyncio + async def test_handle_callback_success(self): + """Test successful callback handling.""" + handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + + # Mock LocalCallbackServer + mock_server = Mock() + mock_server.wait_for_callback.return_value = "test_auth_code" + mock_server.get_state.return_value = "test_state" + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + auth_code, state = await handler.handle_callback() + + assert auth_code == "test_auth_code" + assert state == "test_state" + + # Verify server lifecycle + mock_server_class.assert_called_once_with(port=3030) + mock_server.start.assert_called_once() + mock_server.wait_for_callback.assert_called_once_with(timeout=300) + mock_server.get_state.assert_called_once() + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_callback_timeout(self): + """Test callback handling with timeout.""" + handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=60) + + mock_server = Mock() + mock_server.wait_for_callback.side_effect = OAuthTimeoutError(60) + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + with pytest.raises(OAuthTimeoutError): + await handler.handle_callback() + + # Verify cleanup + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_callback_cancellation(self): + """Test callback handling with user cancellation.""" + handler = LocalBrowserOAuthHandler() + + mock_server = Mock() + mock_server.wait_for_callback.side_effect = OAuthCancellationError( + "access_denied" + ) + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + with pytest.raises(OAuthCancellationError): + await handler.handle_callback() + + # Verify cleanup + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_callback_server_error(self): + """Test callback handling with server error.""" + handler = LocalBrowserOAuthHandler() + + mock_server = Mock() + mock_server.wait_for_callback.side_effect = OAuthServerError("invalid_request") + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + with pytest.raises(OAuthServerError): + await handler.handle_callback() + + # Verify cleanup + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_callback_unexpected_error(self): + """Test callback handling with unexpected error.""" + handler = LocalBrowserOAuthHandler() + + mock_server = Mock() + mock_server.start.side_effect = ValueError("Unexpected error") + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + with pytest.raises(OAuthCallbackError) as exc_info: + await handler.handle_callback() + + assert "Unexpected error during OAuth callback handling" in str( + exc_info.value + ) + assert "Unexpected error" in str(exc_info.value) + + # Verify cleanup attempt + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_callback_cleanup_on_success(self): + """Test that cleanup always happens even on success.""" + handler = LocalBrowserOAuthHandler() + + mock_server = Mock() + mock_server.wait_for_callback.return_value = "auth_code" + mock_server.get_state.return_value = "state" + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + await handler.handle_callback() + + # Verify cleanup always happens + mock_server.stop.assert_called_once() + + +class TestLocalBrowserOAuthHandlerIntegration: + """Test integration scenarios for LocalBrowserOAuthHandler.""" + + @pytest.mark.asyncio + async def test_full_oauth_flow_simulation(self, mock_webbrowser): + """Test complete OAuth flow simulation.""" + mock_webbrowser.open.return_value = True + + handler = LocalBrowserOAuthHandler(callback_port=4040, timeout=120) + + # Mock server for callback + mock_server = Mock() + mock_server.wait_for_callback.return_value = "final_auth_code" + mock_server.get_state.return_value = "final_state" + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + # Step 1: Handle redirect + await handler.handle_redirect( + "https://oauth.example.com/authorize?client_id=123" + ) + + # Step 2: Handle callback + auth_code, state = await handler.handle_callback() + + # Verify results + assert auth_code == "final_auth_code" + assert state == "final_state" + + # Verify browser was opened + mock_webbrowser.open.assert_called_once_with( + "https://oauth.example.com/authorize?client_id=123" + ) + + # Verify server operations + mock_server_class.assert_called_once_with(port=4040) + mock_server.start.assert_called_once() + mock_server.wait_for_callback.assert_called_once_with(timeout=120) + mock_server.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_multiple_handlers_independent(self, mock_webbrowser): + """Test that multiple handler instances are independent.""" + mock_webbrowser.open.return_value = True + + handler1 = LocalBrowserOAuthHandler(callback_port=3030, timeout=100) + handler2 = LocalBrowserOAuthHandler(callback_port=4040, timeout=200) + + # Mock servers + mock_server1 = Mock() + mock_server1.wait_for_callback.return_value = "code1" + mock_server1.get_state.return_value = "state1" + + mock_server2 = Mock() + mock_server2.wait_for_callback.return_value = "code2" + mock_server2.get_state.return_value = "state2" + + def server_factory(port): + if port == 3030: + return mock_server1 + else: + return mock_server2 + + with patch( + "mcpadapt.auth.handlers.LocalCallbackServer", side_effect=server_factory + ): + # Use handlers independently + code1, state1 = await handler1.handle_callback() + code2, state2 = await handler2.handle_callback() + + # Verify independence + assert code1 == "code1" + assert state1 == "state1" + assert code2 == "code2" + assert state2 == "state2" + + # Verify both servers were used + mock_server1.start.assert_called_once() + mock_server1.stop.assert_called_once() + mock_server2.start.assert_called_once() + mock_server2.stop.assert_called_once() + + +class TestHandlerErrorScenarios: + """Test comprehensive error scenarios across handlers.""" + + @pytest.mark.asyncio + async def test_network_failure_during_redirect(self, mock_webbrowser): + """Test network failure during redirect.""" + mock_webbrowser.open.side_effect = ConnectionError("Network unreachable") + + handler = LocalBrowserOAuthHandler() + + with pytest.raises(OAuthNetworkError) as exc_info: + await handler.handle_redirect("https://example.com/oauth") + + assert isinstance(exc_info.value.original_error, ConnectionError) + assert "Network unreachable" in str(exc_info.value.original_error) + + def test_callback_server_edge_cases(self): + """Test callback server edge cases.""" + server = LocalCallbackServer(port=0) # Invalid port + + # Should still initialize + assert server.port == 0 + assert server.callback_data is not None + + @pytest.mark.asyncio + async def test_handler_state_consistency(self): + """Test handler state remains consistent across operations.""" + handler = LocalBrowserOAuthHandler(callback_port=5050, timeout=150) + + # Verify initial state + assert handler.callback_port == 5050 + assert handler.timeout == 150 + assert handler.callback_server is None + + # Mock callback operation + mock_server = Mock() + mock_server.wait_for_callback.return_value = "test_code" + mock_server.get_state.return_value = None + + with patch("mcpadapt.auth.handlers.LocalCallbackServer") as mock_server_class: + mock_server_class.return_value = mock_server + + await handler.handle_callback() + + # Verify state is still consistent + assert handler.callback_port == 5050 + assert handler.timeout == 150 diff --git a/tests/auth/test_oauth.py b/tests/auth/test_oauth.py new file mode 100644 index 0000000..9ddc35b --- /dev/null +++ b/tests/auth/test_oauth.py @@ -0,0 +1,524 @@ +"""Tests for OAuth token storage implementation.""" + +import pytest +from mcpadapt.auth.oauth import InMemoryTokenStorage + + +class TestInMemoryTokenStorage: + """Test in-memory token storage implementation.""" + + def test_initialization(self): + """Test basic initialization.""" + storage = InMemoryTokenStorage() + assert storage._tokens is None + assert storage._client_info is None + + def test_initialization_with_client_info(self, mock_oauth_client_info): + """Test initialization with pre-configured client info.""" + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + assert storage._tokens is None + assert storage._client_info is not None + assert storage._client_info is mock_oauth_client_info + + @pytest.mark.asyncio + async def test_get_tokens_initially_none(self): + """Test that tokens are initially None.""" + storage = InMemoryTokenStorage() + tokens = await storage.get_tokens() + assert tokens is None + + @pytest.mark.asyncio + async def test_get_client_info_initially_none(self): + """Test that client info is initially None.""" + storage = InMemoryTokenStorage() + client_info = await storage.get_client_info() + assert client_info is None + + @pytest.mark.asyncio + async def test_get_preconfigured_client_info(self, mock_oauth_client_info): + """Test getting pre-configured client info.""" + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Get client info + retrieved_info = await storage.get_client_info() + assert retrieved_info is not None + assert retrieved_info is mock_oauth_client_info + assert retrieved_info.client_id == "test_client_id" + assert retrieved_info.client_secret == "test_client_secret" + assert retrieved_info.client_id_issued_at == 1234567890 + assert retrieved_info.client_secret_expires_at == 0 + + @pytest.mark.asyncio + async def test_preconfigured_client_info_with_tokens( + self, mock_oauth_client_info, mock_oauth_token + ): + """Test that pre-configured client info works alongside token operations.""" + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Initially has client info but no tokens + assert await storage.get_client_info() is not None + assert await storage.get_tokens() is None + + # Set tokens + await storage.set_tokens(mock_oauth_token) + + # Should have both client info and tokens + assert await storage.get_client_info() is mock_oauth_client_info + assert await storage.get_tokens() is mock_oauth_token + + @pytest.mark.asyncio + async def test_preconfigured_client_info_overwrite(self, mock_oauth_client_info): + """Test overwriting pre-configured client info.""" + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import HttpUrl + + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Verify initial client info + initial_info = await storage.get_client_info() + assert initial_info is not None + assert initial_info is mock_oauth_client_info + assert initial_info.client_id == "test_client_id" + + # Create new client info + new_info = OAuthClientInformationFull( + client_id="overwrite_client_id", + client_secret="overwrite_client_secret", + redirect_uris=[HttpUrl("http://localhost:5050/callback")], + ) + + # Overwrite client info + await storage.set_client_info(new_info) + updated_info = await storage.get_client_info() + + assert updated_info is not None + assert updated_info is new_info + assert updated_info.client_id == "overwrite_client_id" + assert updated_info.client_secret == "overwrite_client_secret" + assert updated_info is not mock_oauth_client_info + + @pytest.mark.asyncio + async def test_set_and_get_tokens(self, mock_oauth_token): + """Test setting and getting tokens.""" + storage = InMemoryTokenStorage() + + # Set tokens + await storage.set_tokens(mock_oauth_token) + + # Get tokens + retrieved_tokens = await storage.get_tokens() + assert retrieved_tokens is not None + assert retrieved_tokens is mock_oauth_token + assert retrieved_tokens.access_token == "test_access_token" + assert retrieved_tokens.token_type == "Bearer" + assert retrieved_tokens.expires_in == 3600 + assert retrieved_tokens.refresh_token == "test_refresh_token" + assert retrieved_tokens.scope == "read write" + + @pytest.mark.asyncio + async def test_set_and_get_client_info(self, mock_oauth_client_info): + """Test setting and getting client info.""" + storage = InMemoryTokenStorage() + + # Set client info + await storage.set_client_info(mock_oauth_client_info) + + # Get client info + retrieved_info = await storage.get_client_info() + assert retrieved_info is not None + assert retrieved_info is mock_oauth_client_info + assert retrieved_info.client_id == "test_client_id" + assert retrieved_info.client_secret == "test_client_secret" + assert retrieved_info.client_id_issued_at == 1234567890 + assert retrieved_info.client_secret_expires_at == 0 + + @pytest.mark.asyncio + async def test_overwrite_tokens(self, mock_oauth_token): + """Test overwriting existing tokens.""" + from mcp.shared.auth import OAuthToken + + storage = InMemoryTokenStorage() + + # Set initial tokens + await storage.set_tokens(mock_oauth_token) + initial_tokens = await storage.get_tokens() + assert initial_tokens is not None + assert initial_tokens.access_token == "test_access_token" + + # Create new tokens + new_token = OAuthToken( + access_token="new_access_token", + token_type="Bearer", + expires_in=7200, + refresh_token="new_refresh_token", + scope="read write delete", + ) + + # Overwrite tokens + await storage.set_tokens(new_token) + updated_tokens = await storage.get_tokens() + + assert updated_tokens is not None + assert updated_tokens is new_token + assert updated_tokens.access_token == "new_access_token" + assert updated_tokens.expires_in == 7200 + assert updated_tokens.refresh_token == "new_refresh_token" + assert updated_tokens.scope == "read write delete" + + @pytest.mark.asyncio + async def test_overwrite_client_info(self, mock_oauth_client_info): + """Test overwriting existing client info.""" + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import HttpUrl + + storage = InMemoryTokenStorage() + + # Set initial client info + await storage.set_client_info(mock_oauth_client_info) + initial_info = await storage.get_client_info() + assert initial_info is not None + assert initial_info.client_id == "test_client_id" + + # Create new client info + new_info = OAuthClientInformationFull( + client_id="new_client_id", + client_secret="new_client_secret", + client_id_issued_at=9876543210, + client_secret_expires_at=9999999999, + redirect_uris=[HttpUrl("http://localhost:4040/callback")], + ) + + # Overwrite client info + await storage.set_client_info(new_info) + updated_info = await storage.get_client_info() + + assert updated_info is not None + assert updated_info is new_info + assert updated_info.client_id == "new_client_id" + assert updated_info.client_secret == "new_client_secret" + assert updated_info.client_id_issued_at == 9876543210 + assert updated_info.client_secret_expires_at == 9999999999 + + @pytest.mark.asyncio + async def test_independent_storage(self, mock_oauth_token, mock_oauth_client_info): + """Test that tokens and client info are stored independently.""" + storage = InMemoryTokenStorage() + + # Set only tokens + await storage.set_tokens(mock_oauth_token) + assert await storage.get_tokens() is not None + assert await storage.get_client_info() is None + + # Set only client info + storage2 = InMemoryTokenStorage() + await storage2.set_client_info(mock_oauth_client_info) + assert await storage2.get_tokens() is None + assert await storage2.get_client_info() is not None + + # Set both in same storage + await storage.set_client_info(mock_oauth_client_info) + assert await storage.get_tokens() is not None + assert await storage.get_client_info() is not None + + @pytest.mark.asyncio + async def test_multiple_instances_independent( + self, mock_oauth_token, mock_oauth_client_info + ): + """Test that multiple storage instances are independent.""" + storage1 = InMemoryTokenStorage() + storage2 = InMemoryTokenStorage() + + # Set tokens in first storage + await storage1.set_tokens(mock_oauth_token) + + # Verify second storage unaffected + assert await storage1.get_tokens() is not None + assert await storage2.get_tokens() is None + + # Set different data in second storage + from mcp.shared.auth import OAuthToken + + different_token = OAuthToken( + access_token="different_token", token_type="Bearer", expires_in=1800 + ) + await storage2.set_tokens(different_token) + + # Verify independence + tokens1 = await storage1.get_tokens() + tokens2 = await storage2.get_tokens() + + assert tokens1 is not None + assert tokens2 is not None + assert tokens1.access_token == "test_access_token" + assert tokens2.access_token == "different_token" + assert tokens1 is not tokens2 + + @pytest.mark.asyncio + async def test_storage_interface_compliance(self): + """Test that InMemoryTokenStorage implements required interface methods.""" + storage = InMemoryTokenStorage() + + # Verify required methods exist and are callable + assert hasattr(storage, "get_tokens") + assert hasattr(storage, "set_tokens") + assert hasattr(storage, "get_client_info") + assert hasattr(storage, "set_client_info") + + assert callable(storage.get_tokens) + assert callable(storage.set_tokens) + assert callable(storage.get_client_info) + assert callable(storage.set_client_info) + + # Verify methods work as expected (interface compliance) + tokens = await storage.get_tokens() + client_info = await storage.get_client_info() + assert tokens is None + assert client_info is None + + +class TestInMemoryTokenStorageEdgeCases: + """Test edge cases and error scenarios for InMemoryTokenStorage.""" + + @pytest.mark.asyncio + async def test_set_tokens_none(self): + """Test setting tokens to None (should be allowed).""" + from mcp.shared.auth import OAuthToken + + storage = InMemoryTokenStorage() + + # Set some tokens first + token = OAuthToken(access_token="test", token_type="Bearer") + await storage.set_tokens(token) + assert await storage.get_tokens() is not None + + @pytest.mark.asyncio + async def test_concurrent_access_simulation(self, mock_oauth_token): + """Test simulated concurrent access to storage.""" + import asyncio + + storage = InMemoryTokenStorage() + results = [] + + async def set_and_get(token_suffix): + from mcp.shared.auth import OAuthToken + + token = OAuthToken( + access_token=f"token_{token_suffix}", token_type="Bearer" + ) + await storage.set_tokens(token) + retrieved = await storage.get_tokens() + assert retrieved is not None + results.append(retrieved.access_token) + + # Run multiple coroutines + await asyncio.gather(set_and_get("1"), set_and_get("2"), set_and_get("3")) + + # Verify we got results (order may vary due to concurrency) + assert len(results) == 3 + assert all(result.startswith("token_") for result in results) + + # Final stored token should be one of the set tokens + final_token = await storage.get_tokens() + assert final_token is not None + assert final_token.access_token in results + + @pytest.mark.asyncio + async def test_token_object_reference_integrity(self, mock_oauth_token): + """Test that stored objects maintain reference integrity.""" + storage = InMemoryTokenStorage() + + # Store token + await storage.set_tokens(mock_oauth_token) + + # Get token multiple times + token1 = await storage.get_tokens() + token2 = await storage.get_tokens() + + # All references should be to the same object + assert token1 is mock_oauth_token + assert token2 is mock_oauth_token + assert token1 is token2 + + @pytest.mark.asyncio + async def test_memory_cleanup_behavior(self): + """Test memory cleanup behavior when overwriting data.""" + from mcp.shared.auth import OAuthToken + + storage = InMemoryTokenStorage() + + # Create and store first token + token1 = OAuthToken(access_token="token1", token_type="Bearer") + await storage.set_tokens(token1) + + # Store reference to verify cleanup + token1_ref = await storage.get_tokens() + assert token1_ref is token1 + + # Overwrite with new token + token2 = OAuthToken(access_token="token2", token_type="Bearer") + await storage.set_tokens(token2) + + # Verify new token is stored and old reference is no longer accessible through storage + current_token = await storage.get_tokens() + assert current_token is not None + assert current_token is token2 + assert current_token.access_token == "token2" + + +class TestInMemoryTokenStorageAsyncPatterns: + """Test async patterns and coroutine behavior.""" + + @pytest.mark.asyncio + async def test_method_chaining_async( + self, mock_oauth_token, mock_oauth_client_info + ): + """Test async method chaining patterns.""" + storage = InMemoryTokenStorage() + + # Chain operations + await storage.set_tokens(mock_oauth_token) + await storage.set_client_info(mock_oauth_client_info) + + # Verify both operations completed + tokens = await storage.get_tokens() + client_info = await storage.get_client_info() + + assert tokens is mock_oauth_token + assert client_info is mock_oauth_client_info + + +class TestInMemoryTokenStoragePreConfiguredCredentials: + """Test pre-configured OAuth credentials functionality for skipping DCR.""" + + def test_preconfigured_credentials_skip_dcr_scenario(self, mock_oauth_client_info): + """Test the main use case: pre-configured credentials to skip DCR.""" + # This simulates the scenario where a user has existing OAuth app credentials + # and wants to skip Dynamic Client Registration + + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Verify storage is properly initialized with client info + assert storage._client_info is not None + assert storage._client_info.client_id == "test_client_id" + assert storage._client_info.client_secret == "test_client_secret" + assert storage._tokens is None # No tokens initially + + @pytest.mark.asyncio + async def test_preconfigured_credentials_workflow( + self, mock_oauth_client_info, mock_oauth_token + ): + """Test complete workflow with pre-configured credentials.""" + # Initialize with pre-configured client credentials + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Step 1: Storage already has client info (skips DCR) + client_info = await storage.get_client_info() + assert client_info is not None + assert client_info.client_id == "test_client_id" + assert client_info.client_secret == "test_client_secret" + + # Step 2: No tokens initially + tokens = await storage.get_tokens() + assert tokens is None + + # Step 3: After OAuth flow, tokens are stored + await storage.set_tokens(mock_oauth_token) + + # Step 4: Now we have both client info and tokens + final_client_info = await storage.get_client_info() + final_tokens = await storage.get_tokens() + + assert final_client_info is mock_oauth_client_info + assert final_tokens is mock_oauth_token + assert final_tokens.access_token == "test_access_token" + + def test_multiple_preconfigured_storages_independent(self): + """Test that multiple pre-configured storages are independent.""" + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import HttpUrl + + # Create two different client infos + client_info_1 = OAuthClientInformationFull( + client_id="client_1", + client_secret="secret_1", + redirect_uris=[HttpUrl("http://localhost:3030/callback")], + ) + + client_info_2 = OAuthClientInformationFull( + client_id="client_2", + client_secret="secret_2", + redirect_uris=[HttpUrl("http://localhost:4040/callback")], + ) + + # Create storages with different pre-configured credentials + storage_1 = InMemoryTokenStorage(client_info=client_info_1) + storage_2 = InMemoryTokenStorage(client_info=client_info_2) + + # Verify independence + assert storage_1._client_info is not None + assert storage_2._client_info is not None + assert storage_1._client_info is not storage_2._client_info + assert storage_1._client_info.client_id == "client_1" + assert storage_2._client_info.client_id == "client_2" + assert storage_1._client_info.client_secret == "secret_1" + assert storage_2._client_info.client_secret == "secret_2" + + @pytest.mark.asyncio + async def test_preconfigured_vs_regular_storage_behavior( + self, mock_oauth_client_info + ): + """Test difference between pre-configured and regular storage behavior.""" + # Regular storage (for DCR) + regular_storage = InMemoryTokenStorage() + + # Pre-configured storage (skips DCR) + preconfigured_storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Regular storage has no client info initially + regular_client_info = await regular_storage.get_client_info() + assert regular_client_info is None + + # Pre-configured storage has client info immediately + preconfigured_client_info = await preconfigured_storage.get_client_info() + assert preconfigured_client_info is not None + assert preconfigured_client_info is mock_oauth_client_info + assert preconfigured_client_info.client_id == "test_client_id" + + # Both should have no tokens initially + assert await regular_storage.get_tokens() is None + assert await preconfigured_storage.get_tokens() is None + + @pytest.mark.asyncio + async def test_preconfigured_storage_interface_methods( + self, mock_oauth_client_info + ): + """Test that pre-configured storage works with all interface methods.""" + from mcp.shared.auth import OAuthToken, OAuthClientInformationFull + from pydantic import HttpUrl + + storage = InMemoryTokenStorage(client_info=mock_oauth_client_info) + + # Test get_client_info returns pre-configured info + client_info = await storage.get_client_info() + assert client_info is mock_oauth_client_info + + # Test set_client_info can override pre-configured info + new_client_info = OAuthClientInformationFull( + client_id="new_client", + client_secret="new_secret", + redirect_uris=[HttpUrl("http://localhost:5050/callback")], + ) + await storage.set_client_info(new_client_info) + + updated_client_info = await storage.get_client_info() + assert updated_client_info is new_client_info + assert updated_client_info.client_id == "new_client" + assert updated_client_info is not mock_oauth_client_info + + # Test tokens still work normally + token = OAuthToken(access_token="test_token", token_type="Bearer") + await storage.set_tokens(token) + + retrieved_token = await storage.get_tokens() + assert retrieved_token is token + assert retrieved_token.access_token == "test_token" diff --git a/tests/auth/test_providers.py b/tests/auth/test_providers.py new file mode 100644 index 0000000..9ee115a --- /dev/null +++ b/tests/auth/test_providers.py @@ -0,0 +1,297 @@ +"""Tests for authentication provider classes.""" + +from unittest.mock import Mock +from mcpadapt.auth.providers import ( + ApiKeyAuthProvider, + BearerAuthProvider, + get_auth_headers, +) + + +class TestApiKeyAuthProvider: + """Test API key authentication provider.""" + + def test_initialization(self): + """Test basic initialization.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") + assert provider.header_name == "X-API-Key" + assert provider.header_value == "test-key-123" + + def test_initialization_different_header(self): + """Test initialization with different header name.""" + provider = ApiKeyAuthProvider("Authorization", "ApiKey test-key-456") + assert provider.header_name == "Authorization" + assert provider.header_value == "ApiKey test-key-456" + + def test_initialization_empty_values(self): + """Test initialization with empty values.""" + provider = ApiKeyAuthProvider("", "") + assert provider.header_name == "" + assert provider.header_value == "" + + def test_get_headers_basic(self): + """Test get_headers method.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") + headers = provider.get_headers() + + assert isinstance(headers, dict) + assert headers == {"X-API-Key": "test-key-123"} + + def test_get_headers_different_header(self): + """Test get_headers with different header name.""" + provider = ApiKeyAuthProvider("Custom-Auth", "custom-value") + headers = provider.get_headers() + + assert isinstance(headers, dict) + assert headers == {"Custom-Auth": "custom-value"} + + def test_get_headers_returns_new_dict(self): + """Test that get_headers returns a new dict instance each time.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key") + headers1 = provider.get_headers() + headers2 = provider.get_headers() + + assert headers1 == headers2 + assert headers1 is not headers2 # Different instances + + def test_get_headers_multiple_calls(self): + """Test multiple calls to get_headers return consistent results.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key") + + for _ in range(5): + headers = provider.get_headers() + assert headers == {"X-API-Key": "test-key"} + + +class TestBearerAuthProvider: + """Test Bearer token authentication provider.""" + + def test_initialization(self): + """Test basic initialization.""" + provider = BearerAuthProvider("test-token-123") + assert provider.token == "test-token-123" + + def test_initialization_empty_token(self): + """Test initialization with empty token.""" + provider = BearerAuthProvider("") + assert provider.token == "" + + def test_initialization_complex_token(self): + """Test initialization with complex token.""" + complex_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + provider = BearerAuthProvider(complex_token) + assert provider.token == complex_token + + def test_get_headers_basic(self): + """Test get_headers method.""" + provider = BearerAuthProvider("test-token-123") + headers = provider.get_headers() + + assert isinstance(headers, dict) + assert headers == {"Authorization": "Bearer test-token-123"} + + def test_get_headers_empty_token(self): + """Test get_headers with empty token.""" + provider = BearerAuthProvider("") + headers = provider.get_headers() + + assert isinstance(headers, dict) + assert headers == {"Authorization": "Bearer "} + + def test_get_headers_returns_new_dict(self): + """Test that get_headers returns a new dict instance each time.""" + provider = BearerAuthProvider("test-token") + headers1 = provider.get_headers() + headers2 = provider.get_headers() + + assert headers1 == headers2 + assert headers1 is not headers2 # Different instances + + def test_get_headers_multiple_calls(self): + """Test multiple calls to get_headers return consistent results.""" + provider = BearerAuthProvider("test-token") + + for _ in range(5): + headers = provider.get_headers() + assert headers == {"Authorization": "Bearer test-token"} + + def test_bearer_format_consistency(self): + """Test that Bearer format is consistent.""" + tokens = [ + "simple", + "complex.token.here", + "token-with-dashes", + "token_with_underscores", + ] + + for token in tokens: + provider = BearerAuthProvider(token) + headers = provider.get_headers() + assert headers["Authorization"] == f"Bearer {token}" + assert headers["Authorization"].startswith("Bearer ") + + +class TestGetAuthHeaders: + """Test get_auth_headers utility function.""" + + def test_with_api_key_provider(self): + """Test with ApiKeyAuthProvider.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key") + headers = get_auth_headers(provider) + + assert isinstance(headers, dict) + assert headers == {"X-API-Key": "test-key"} + + def test_with_bearer_provider(self): + """Test with BearerAuthProvider.""" + provider = BearerAuthProvider("test-token") + headers = get_auth_headers(provider) + + assert isinstance(headers, dict) + assert headers == {"Authorization": "Bearer test-token"} + + def test_with_unknown_provider(self): + """Test with unknown provider type.""" + unknown_provider = Mock() + headers = get_auth_headers(unknown_provider) + + assert isinstance(headers, dict) + assert headers == {} + + def test_with_none_provider(self): + """Test with None provider.""" + headers = get_auth_headers(None) + + assert isinstance(headers, dict) + assert headers == {} + + def test_with_provider_without_get_headers(self): + """Test with object that doesn't have get_headers method.""" + fake_provider = object() + headers = get_auth_headers(fake_provider) + + assert isinstance(headers, dict) + assert headers == {} + + def test_with_string_provider(self): + """Test with string instead of provider object.""" + headers = get_auth_headers("not-a-provider") + + assert isinstance(headers, dict) + assert headers == {} + + def test_with_dict_provider(self): + """Test with dict instead of provider object.""" + headers = get_auth_headers({"key": "value"}) + + assert isinstance(headers, dict) + assert headers == {} + + def test_multiple_provider_types(self): + """Test with multiple different provider types in sequence.""" + api_key_provider = ApiKeyAuthProvider("X-API-Key", "api-key") + bearer_provider = BearerAuthProvider("bearer-token") + + api_headers = get_auth_headers(api_key_provider) + bearer_headers = get_auth_headers(bearer_provider) + none_headers = get_auth_headers(None) + + assert api_headers == {"X-API-Key": "api-key"} + assert bearer_headers == {"Authorization": "Bearer bearer-token"} + assert none_headers == {} + + def test_provider_inheritance_check(self): + """Test that the function properly checks instance types.""" + + # Create a class that has get_headers but isn't a known provider + class FakeProvider: + def get_headers(self): + return {"Fake": "header"} + + fake = FakeProvider() + headers = get_auth_headers(fake) + + # Should return empty dict since it's not an ApiKeyAuthProvider or BearerAuthProvider + assert headers == {} + + +class TestProviderIntegration: + """Test provider integration scenarios.""" + + def test_api_key_provider_real_world_headers(self): + """Test API key provider with real-world header names.""" + test_cases = [ + ("X-API-Key", "sk-1234567890abcdef"), + ("Authorization", "ApiKey sk-abcdef1234567890"), + ("X-RapidAPI-Key", "rapidapi-key-here"), + ("Ocp-Apim-Subscription-Key", "azure-key"), + ("x-api-key", "lowercase-header"), + ] + + for header_name, header_value in test_cases: + provider = ApiKeyAuthProvider(header_name, header_value) + headers = provider.get_headers() + assert headers == {header_name: header_value} + + def test_bearer_provider_real_world_tokens(self): + """Test Bearer provider with real-world token formats.""" + test_tokens = [ + "simple_token", + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", # JWT-like + "ghp_1234567890abcdef1234567890abcdef12345678", # GitHub token-like + "sk-1234567890abcdef1234567890abcdef1234567890abcdef1234567890", # OpenAI-like + "xoxb-1234567890-1234567890-abcdefghijklmnopqrstuvwx", # Slack-like + ] + + for token in test_tokens: + provider = BearerAuthProvider(token) + headers = provider.get_headers() + assert headers == {"Authorization": f"Bearer {token}"} + + def test_providers_with_special_characters(self): + """Test providers with special characters in values.""" + # API Key with special characters + api_provider = ApiKeyAuthProvider( + "X-API-Key", "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" + ) + api_headers = api_provider.get_headers() + assert "X-API-Key" in api_headers + assert api_headers["X-API-Key"] == "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" + + # Bearer token with special characters + bearer_provider = BearerAuthProvider("token!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`") + bearer_headers = bearer_provider.get_headers() + assert ( + bearer_headers["Authorization"] + == "Bearer token!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" + ) + + def test_providers_immutability(self): + """Test that providers don't modify their internal state.""" + # Test API Key provider + api_provider = ApiKeyAuthProvider("X-API-Key", "original-key") + original_name = api_provider.header_name + original_value = api_provider.header_value + + # Get headers multiple times + for _ in range(3): + headers = api_provider.get_headers() + headers["X-API-Key"] = "modified-key" # Try to modify returned dict + + # Verify original values unchanged + assert api_provider.header_name == original_name + assert api_provider.header_value == original_value + + # Test Bearer provider + bearer_provider = BearerAuthProvider("original-token") + original_token = bearer_provider.token + + # Get headers multiple times + for _ in range(3): + headers = bearer_provider.get_headers() + headers["Authorization"] = ( + "Bearer modified-token" # Try to modify returned dict + ) + + # Verify original token unchanged + assert bearer_provider.token == original_token diff --git a/uv.lock b/uv.lock index 0db43d4..2770b01 100644 --- a/uv.lock +++ b/uv.lock @@ -681,6 +681,96 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180, upload-time = "2024-03-12T16:53:39.226Z" }, ] +[[package]] +name = "coverage" +version = "7.10.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/14/70/025b179c993f019105b79575ac6edb5e084fb0f0e63f15cdebef4e454fb5/coverage-7.10.6.tar.gz", hash = "sha256:f644a3ae5933a552a29dbb9aa2f90c677a875f80ebea028e5a52a4f429044b90", size = 823736, upload-time = "2025-08-29T15:35:16.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/1d/2e64b43d978b5bd184e0756a41415597dfef30fcbd90b747474bd749d45f/coverage-7.10.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70e7bfbd57126b5554aa482691145f798d7df77489a177a6bef80de78860a356", size = 217025, upload-time = "2025-08-29T15:32:57.169Z" }, + { url = "https://files.pythonhosted.org/packages/23/62/b1e0f513417c02cc10ef735c3ee5186df55f190f70498b3702d516aad06f/coverage-7.10.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e41be6f0f19da64af13403e52f2dec38bbc2937af54df8ecef10850ff8d35301", size = 217419, upload-time = "2025-08-29T15:32:59.908Z" }, + { url = "https://files.pythonhosted.org/packages/e7/16/b800640b7a43e7c538429e4d7223e0a94fd72453a1a048f70bf766f12e96/coverage-7.10.6-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c61fc91ab80b23f5fddbee342d19662f3d3328173229caded831aa0bd7595460", size = 244180, upload-time = "2025-08-29T15:33:01.608Z" }, + { url = "https://files.pythonhosted.org/packages/fb/6f/5e03631c3305cad187eaf76af0b559fff88af9a0b0c180d006fb02413d7a/coverage-7.10.6-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10356fdd33a7cc06e8051413140bbdc6f972137508a3572e3f59f805cd2832fd", size = 245992, upload-time = "2025-08-29T15:33:03.239Z" }, + { url = "https://files.pythonhosted.org/packages/eb/a1/f30ea0fb400b080730125b490771ec62b3375789f90af0bb68bfb8a921d7/coverage-7.10.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80b1695cf7c5ebe7b44bf2521221b9bb8cdf69b1f24231149a7e3eb1ae5fa2fb", size = 247851, upload-time = "2025-08-29T15:33:04.603Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/cfa8fee8e8ef9a6bb76c7bef039f3302f44e615d2194161a21d3d83ac2e9/coverage-7.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2e4c33e6378b9d52d3454bd08847a8651f4ed23ddbb4a0520227bd346382bbc6", size = 245891, upload-time = "2025-08-29T15:33:06.176Z" }, + { url = "https://files.pythonhosted.org/packages/93/a9/51be09b75c55c4f6c16d8d73a6a1d46ad764acca0eab48fa2ffaef5958fe/coverage-7.10.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c8a3ec16e34ef980a46f60dc6ad86ec60f763c3f2fa0db6d261e6e754f72e945", size = 243909, upload-time = "2025-08-29T15:33:07.74Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a6/ba188b376529ce36483b2d585ca7bdac64aacbe5aa10da5978029a9c94db/coverage-7.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7d79dabc0a56f5af990cc6da9ad1e40766e82773c075f09cc571e2076fef882e", size = 244786, upload-time = "2025-08-29T15:33:08.965Z" }, + { url = "https://files.pythonhosted.org/packages/d0/4c/37ed872374a21813e0d3215256180c9a382c3f5ced6f2e5da0102fc2fd3e/coverage-7.10.6-cp310-cp310-win32.whl", hash = "sha256:86b9b59f2b16e981906e9d6383eb6446d5b46c278460ae2c36487667717eccf1", size = 219521, upload-time = "2025-08-29T15:33:10.599Z" }, + { url = "https://files.pythonhosted.org/packages/8e/36/9311352fdc551dec5b973b61f4e453227ce482985a9368305880af4f85dd/coverage-7.10.6-cp310-cp310-win_amd64.whl", hash = "sha256:e132b9152749bd33534e5bd8565c7576f135f157b4029b975e15ee184325f528", size = 220417, upload-time = "2025-08-29T15:33:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/d4/16/2bea27e212c4980753d6d563a0803c150edeaaddb0771a50d2afc410a261/coverage-7.10.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c706db3cabb7ceef779de68270150665e710b46d56372455cd741184f3868d8f", size = 217129, upload-time = "2025-08-29T15:33:13.575Z" }, + { url = "https://files.pythonhosted.org/packages/2a/51/e7159e068831ab37e31aac0969d47b8c5ee25b7d307b51e310ec34869315/coverage-7.10.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e0c38dc289e0508ef68ec95834cb5d2e96fdbe792eaccaa1bccac3966bbadcc", size = 217532, upload-time = "2025-08-29T15:33:14.872Z" }, + { url = "https://files.pythonhosted.org/packages/e7/c0/246ccbea53d6099325d25cd208df94ea435cd55f0db38099dd721efc7a1f/coverage-7.10.6-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:752a3005a1ded28f2f3a6e8787e24f28d6abe176ca64677bcd8d53d6fe2ec08a", size = 247931, upload-time = "2025-08-29T15:33:16.142Z" }, + { url = "https://files.pythonhosted.org/packages/7d/fb/7435ef8ab9b2594a6e3f58505cc30e98ae8b33265d844007737946c59389/coverage-7.10.6-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:689920ecfd60f992cafca4f5477d55720466ad2c7fa29bb56ac8d44a1ac2b47a", size = 249864, upload-time = "2025-08-29T15:33:17.434Z" }, + { url = "https://files.pythonhosted.org/packages/51/f8/d9d64e8da7bcddb094d511154824038833c81e3a039020a9d6539bf303e9/coverage-7.10.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec98435796d2624d6905820a42f82149ee9fc4f2d45c2c5bc5a44481cc50db62", size = 251969, upload-time = "2025-08-29T15:33:18.822Z" }, + { url = "https://files.pythonhosted.org/packages/43/28/c43ba0ef19f446d6463c751315140d8f2a521e04c3e79e5c5fe211bfa430/coverage-7.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b37201ce4a458c7a758ecc4efa92fa8ed783c66e0fa3c42ae19fc454a0792153", size = 249659, upload-time = "2025-08-29T15:33:20.407Z" }, + { url = "https://files.pythonhosted.org/packages/79/3e/53635bd0b72beaacf265784508a0b386defc9ab7fad99ff95f79ce9db555/coverage-7.10.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:2904271c80898663c810a6b067920a61dd8d38341244a3605bd31ab55250dad5", size = 247714, upload-time = "2025-08-29T15:33:21.751Z" }, + { url = "https://files.pythonhosted.org/packages/4c/55/0964aa87126624e8c159e32b0bc4e84edef78c89a1a4b924d28dd8265625/coverage-7.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5aea98383463d6e1fa4e95416d8de66f2d0cb588774ee20ae1b28df826bcb619", size = 248351, upload-time = "2025-08-29T15:33:23.105Z" }, + { url = "https://files.pythonhosted.org/packages/eb/ab/6cfa9dc518c6c8e14a691c54e53a9433ba67336c760607e299bfcf520cb1/coverage-7.10.6-cp311-cp311-win32.whl", hash = "sha256:e3fb1fa01d3598002777dd259c0c2e6d9d5e10e7222976fc8e03992f972a2cba", size = 219562, upload-time = "2025-08-29T15:33:24.717Z" }, + { url = "https://files.pythonhosted.org/packages/5b/18/99b25346690cbc55922e7cfef06d755d4abee803ef335baff0014268eff4/coverage-7.10.6-cp311-cp311-win_amd64.whl", hash = "sha256:f35ed9d945bece26553d5b4c8630453169672bea0050a564456eb88bdffd927e", size = 220453, upload-time = "2025-08-29T15:33:26.482Z" }, + { url = "https://files.pythonhosted.org/packages/d8/ed/81d86648a07ccb124a5cf1f1a7788712b8d7216b593562683cd5c9b0d2c1/coverage-7.10.6-cp311-cp311-win_arm64.whl", hash = "sha256:99e1a305c7765631d74b98bf7dbf54eeea931f975e80f115437d23848ee8c27c", size = 219127, upload-time = "2025-08-29T15:33:27.777Z" }, + { url = "https://files.pythonhosted.org/packages/26/06/263f3305c97ad78aab066d116b52250dd316e74fcc20c197b61e07eb391a/coverage-7.10.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5b2dd6059938063a2c9fee1af729d4f2af28fd1a545e9b7652861f0d752ebcea", size = 217324, upload-time = "2025-08-29T15:33:29.06Z" }, + { url = "https://files.pythonhosted.org/packages/e9/60/1e1ded9a4fe80d843d7d53b3e395c1db3ff32d6c301e501f393b2e6c1c1f/coverage-7.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:388d80e56191bf846c485c14ae2bc8898aa3124d9d35903fef7d907780477634", size = 217560, upload-time = "2025-08-29T15:33:30.748Z" }, + { url = "https://files.pythonhosted.org/packages/b8/25/52136173c14e26dfed8b106ed725811bb53c30b896d04d28d74cb64318b3/coverage-7.10.6-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:90cb5b1a4670662719591aa92d0095bb41714970c0b065b02a2610172dbf0af6", size = 249053, upload-time = "2025-08-29T15:33:32.041Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1d/ae25a7dc58fcce8b172d42ffe5313fc267afe61c97fa872b80ee72d9515a/coverage-7.10.6-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:961834e2f2b863a0e14260a9a273aff07ff7818ab6e66d2addf5628590c628f9", size = 251802, upload-time = "2025-08-29T15:33:33.625Z" }, + { url = "https://files.pythonhosted.org/packages/f5/7a/1f561d47743710fe996957ed7c124b421320f150f1d38523d8d9102d3e2a/coverage-7.10.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf9a19f5012dab774628491659646335b1928cfc931bf8d97b0d5918dd58033c", size = 252935, upload-time = "2025-08-29T15:33:34.909Z" }, + { url = "https://files.pythonhosted.org/packages/6c/ad/8b97cd5d28aecdfde792dcbf646bac141167a5cacae2cd775998b45fabb5/coverage-7.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99c4283e2a0e147b9c9cc6bc9c96124de9419d6044837e9799763a0e29a7321a", size = 250855, upload-time = "2025-08-29T15:33:36.922Z" }, + { url = "https://files.pythonhosted.org/packages/33/6a/95c32b558d9a61858ff9d79580d3877df3eb5bc9eed0941b1f187c89e143/coverage-7.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:282b1b20f45df57cc508c1e033403f02283adfb67d4c9c35a90281d81e5c52c5", size = 248974, upload-time = "2025-08-29T15:33:38.175Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9c/8ce95dee640a38e760d5b747c10913e7a06554704d60b41e73fdea6a1ffd/coverage-7.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cdbe264f11afd69841bd8c0d83ca10b5b32853263ee62e6ac6a0ab63895f972", size = 250409, upload-time = "2025-08-29T15:33:39.447Z" }, + { url = "https://files.pythonhosted.org/packages/04/12/7a55b0bdde78a98e2eb2356771fd2dcddb96579e8342bb52aa5bc52e96f0/coverage-7.10.6-cp312-cp312-win32.whl", hash = "sha256:a517feaf3a0a3eca1ee985d8373135cfdedfbba3882a5eab4362bda7c7cf518d", size = 219724, upload-time = "2025-08-29T15:33:41.172Z" }, + { url = "https://files.pythonhosted.org/packages/36/4a/32b185b8b8e327802c9efce3d3108d2fe2d9d31f153a0f7ecfd59c773705/coverage-7.10.6-cp312-cp312-win_amd64.whl", hash = "sha256:856986eadf41f52b214176d894a7de05331117f6035a28ac0016c0f63d887629", size = 220536, upload-time = "2025-08-29T15:33:42.524Z" }, + { url = "https://files.pythonhosted.org/packages/08/3a/d5d8dc703e4998038c3099eaf77adddb00536a3cec08c8dcd556a36a3eb4/coverage-7.10.6-cp312-cp312-win_arm64.whl", hash = "sha256:acf36b8268785aad739443fa2780c16260ee3fa09d12b3a70f772ef100939d80", size = 219171, upload-time = "2025-08-29T15:33:43.974Z" }, + { url = "https://files.pythonhosted.org/packages/bd/e7/917e5953ea29a28c1057729c1d5af9084ab6d9c66217523fd0e10f14d8f6/coverage-7.10.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ffea0575345e9ee0144dfe5701aa17f3ba546f8c3bb48db62ae101afb740e7d6", size = 217351, upload-time = "2025-08-29T15:33:45.438Z" }, + { url = "https://files.pythonhosted.org/packages/eb/86/2e161b93a4f11d0ea93f9bebb6a53f113d5d6e416d7561ca41bb0a29996b/coverage-7.10.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d91d7317cde40a1c249d6b7382750b7e6d86fad9d8eaf4fa3f8f44cf171e80", size = 217600, upload-time = "2025-08-29T15:33:47.269Z" }, + { url = "https://files.pythonhosted.org/packages/0e/66/d03348fdd8df262b3a7fb4ee5727e6e4936e39e2f3a842e803196946f200/coverage-7.10.6-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3e23dd5408fe71a356b41baa82892772a4cefcf758f2ca3383d2aa39e1b7a003", size = 248600, upload-time = "2025-08-29T15:33:48.953Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/508420fb47d09d904d962f123221bc249f64b5e56aa93d5f5f7603be475f/coverage-7.10.6-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0f3f56e4cb573755e96a16501a98bf211f100463d70275759e73f3cbc00d4f27", size = 251206, upload-time = "2025-08-29T15:33:50.697Z" }, + { url = "https://files.pythonhosted.org/packages/e9/1f/9020135734184f439da85c70ea78194c2730e56c2d18aee6e8ff1719d50d/coverage-7.10.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db4a1d897bbbe7339946ffa2fe60c10cc81c43fab8b062d3fcb84188688174a4", size = 252478, upload-time = "2025-08-29T15:33:52.303Z" }, + { url = "https://files.pythonhosted.org/packages/a4/a4/3d228f3942bb5a2051fde28c136eea23a761177dc4ff4ef54533164ce255/coverage-7.10.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d8fd7879082953c156d5b13c74aa6cca37f6a6f4747b39538504c3f9c63d043d", size = 250637, upload-time = "2025-08-29T15:33:53.67Z" }, + { url = "https://files.pythonhosted.org/packages/36/e3/293dce8cdb9a83de971637afc59b7190faad60603b40e32635cbd15fbf61/coverage-7.10.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:28395ca3f71cd103b8c116333fa9db867f3a3e1ad6a084aa3725ae002b6583bc", size = 248529, upload-time = "2025-08-29T15:33:55.022Z" }, + { url = "https://files.pythonhosted.org/packages/90/26/64eecfa214e80dd1d101e420cab2901827de0e49631d666543d0e53cf597/coverage-7.10.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:61c950fc33d29c91b9e18540e1aed7d9f6787cc870a3e4032493bbbe641d12fc", size = 250143, upload-time = "2025-08-29T15:33:56.386Z" }, + { url = "https://files.pythonhosted.org/packages/3e/70/bd80588338f65ea5b0d97e424b820fb4068b9cfb9597fbd91963086e004b/coverage-7.10.6-cp313-cp313-win32.whl", hash = "sha256:160c00a5e6b6bdf4e5984b0ef21fc860bc94416c41b7df4d63f536d17c38902e", size = 219770, upload-time = "2025-08-29T15:33:58.063Z" }, + { url = "https://files.pythonhosted.org/packages/a7/14/0b831122305abcc1060c008f6c97bbdc0a913ab47d65070a01dc50293c2b/coverage-7.10.6-cp313-cp313-win_amd64.whl", hash = "sha256:628055297f3e2aa181464c3808402887643405573eb3d9de060d81531fa79d32", size = 220566, upload-time = "2025-08-29T15:33:59.766Z" }, + { url = "https://files.pythonhosted.org/packages/83/c6/81a83778c1f83f1a4a168ed6673eeedc205afb562d8500175292ca64b94e/coverage-7.10.6-cp313-cp313-win_arm64.whl", hash = "sha256:df4ec1f8540b0bcbe26ca7dd0f541847cc8a108b35596f9f91f59f0c060bfdd2", size = 219195, upload-time = "2025-08-29T15:34:01.191Z" }, + { url = "https://files.pythonhosted.org/packages/d7/1c/ccccf4bf116f9517275fa85047495515add43e41dfe8e0bef6e333c6b344/coverage-7.10.6-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c9a8b7a34a4de3ed987f636f71881cd3b8339f61118b1aa311fbda12741bff0b", size = 218059, upload-time = "2025-08-29T15:34:02.91Z" }, + { url = "https://files.pythonhosted.org/packages/92/97/8a3ceff833d27c7492af4f39d5da6761e9ff624831db9e9f25b3886ddbca/coverage-7.10.6-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8dd5af36092430c2b075cee966719898f2ae87b636cefb85a653f1d0ba5d5393", size = 218287, upload-time = "2025-08-29T15:34:05.106Z" }, + { url = "https://files.pythonhosted.org/packages/92/d8/50b4a32580cf41ff0423777a2791aaf3269ab60c840b62009aec12d3970d/coverage-7.10.6-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b0353b0f0850d49ada66fdd7d0c7cdb0f86b900bb9e367024fd14a60cecc1e27", size = 259625, upload-time = "2025-08-29T15:34:06.575Z" }, + { url = "https://files.pythonhosted.org/packages/7e/7e/6a7df5a6fb440a0179d94a348eb6616ed4745e7df26bf2a02bc4db72c421/coverage-7.10.6-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6b9ae13d5d3e8aeca9ca94198aa7b3ebbc5acfada557d724f2a1f03d2c0b0df", size = 261801, upload-time = "2025-08-29T15:34:08.006Z" }, + { url = "https://files.pythonhosted.org/packages/3a/4c/a270a414f4ed5d196b9d3d67922968e768cd971d1b251e1b4f75e9362f75/coverage-7.10.6-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:675824a363cc05781b1527b39dc2587b8984965834a748177ee3c37b64ffeafb", size = 264027, upload-time = "2025-08-29T15:34:09.806Z" }, + { url = "https://files.pythonhosted.org/packages/9c/8b/3210d663d594926c12f373c5370bf1e7c5c3a427519a8afa65b561b9a55c/coverage-7.10.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:692d70ea725f471a547c305f0d0fc6a73480c62fb0da726370c088ab21aed282", size = 261576, upload-time = "2025-08-29T15:34:11.585Z" }, + { url = "https://files.pythonhosted.org/packages/72/d0/e1961eff67e9e1dba3fc5eb7a4caf726b35a5b03776892da8d79ec895775/coverage-7.10.6-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:851430a9a361c7a8484a36126d1d0ff8d529d97385eacc8dfdc9bfc8c2d2cbe4", size = 259341, upload-time = "2025-08-29T15:34:13.159Z" }, + { url = "https://files.pythonhosted.org/packages/3a/06/d6478d152cd189b33eac691cba27a40704990ba95de49771285f34a5861e/coverage-7.10.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d9369a23186d189b2fc95cc08b8160ba242057e887d766864f7adf3c46b2df21", size = 260468, upload-time = "2025-08-29T15:34:14.571Z" }, + { url = "https://files.pythonhosted.org/packages/ed/73/737440247c914a332f0b47f7598535b29965bf305e19bbc22d4c39615d2b/coverage-7.10.6-cp313-cp313t-win32.whl", hash = "sha256:92be86fcb125e9bda0da7806afd29a3fd33fdf58fba5d60318399adf40bf37d0", size = 220429, upload-time = "2025-08-29T15:34:16.394Z" }, + { url = "https://files.pythonhosted.org/packages/bd/76/b92d3214740f2357ef4a27c75a526eb6c28f79c402e9f20a922c295c05e2/coverage-7.10.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6b3039e2ca459a70c79523d39347d83b73f2f06af5624905eba7ec34d64d80b5", size = 221493, upload-time = "2025-08-29T15:34:17.835Z" }, + { url = "https://files.pythonhosted.org/packages/fc/8e/6dcb29c599c8a1f654ec6cb68d76644fe635513af16e932d2d4ad1e5ac6e/coverage-7.10.6-cp313-cp313t-win_arm64.whl", hash = "sha256:3fb99d0786fe17b228eab663d16bee2288e8724d26a199c29325aac4b0319b9b", size = 219757, upload-time = "2025-08-29T15:34:19.248Z" }, + { url = "https://files.pythonhosted.org/packages/d3/aa/76cf0b5ec00619ef208da4689281d48b57f2c7fde883d14bf9441b74d59f/coverage-7.10.6-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:6008a021907be8c4c02f37cdc3ffb258493bdebfeaf9a839f9e71dfdc47b018e", size = 217331, upload-time = "2025-08-29T15:34:20.846Z" }, + { url = "https://files.pythonhosted.org/packages/65/91/8e41b8c7c505d398d7730206f3cbb4a875a35ca1041efc518051bfce0f6b/coverage-7.10.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5e75e37f23eb144e78940b40395b42f2321951206a4f50e23cfd6e8a198d3ceb", size = 217607, upload-time = "2025-08-29T15:34:22.433Z" }, + { url = "https://files.pythonhosted.org/packages/87/7f/f718e732a423d442e6616580a951b8d1ec3575ea48bcd0e2228386805e79/coverage-7.10.6-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0f7cb359a448e043c576f0da00aa8bfd796a01b06aa610ca453d4dde09cc1034", size = 248663, upload-time = "2025-08-29T15:34:24.425Z" }, + { url = "https://files.pythonhosted.org/packages/e6/52/c1106120e6d801ac03e12b5285e971e758e925b6f82ee9b86db3aa10045d/coverage-7.10.6-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c68018e4fc4e14b5668f1353b41ccf4bc83ba355f0e1b3836861c6f042d89ac1", size = 251197, upload-time = "2025-08-29T15:34:25.906Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ec/3a8645b1bb40e36acde9c0609f08942852a4af91a937fe2c129a38f2d3f5/coverage-7.10.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cd4b2b0707fc55afa160cd5fc33b27ccbf75ca11d81f4ec9863d5793fc6df56a", size = 252551, upload-time = "2025-08-29T15:34:27.337Z" }, + { url = "https://files.pythonhosted.org/packages/a1/70/09ecb68eeb1155b28a1d16525fd3a9b65fbe75337311a99830df935d62b6/coverage-7.10.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cec13817a651f8804a86e4f79d815b3b28472c910e099e4d5a0e8a3b6a1d4cb", size = 250553, upload-time = "2025-08-29T15:34:29.065Z" }, + { url = "https://files.pythonhosted.org/packages/c6/80/47df374b893fa812e953b5bc93dcb1427a7b3d7a1a7d2db33043d17f74b9/coverage-7.10.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f2a6a8e06bbda06f78739f40bfb56c45d14eb8249d0f0ea6d4b3d48e1f7c695d", size = 248486, upload-time = "2025-08-29T15:34:30.897Z" }, + { url = "https://files.pythonhosted.org/packages/4a/65/9f98640979ecee1b0d1a7164b589de720ddf8100d1747d9bbdb84be0c0fb/coverage-7.10.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:081b98395ced0d9bcf60ada7661a0b75f36b78b9d7e39ea0790bb4ed8da14747", size = 249981, upload-time = "2025-08-29T15:34:32.365Z" }, + { url = "https://files.pythonhosted.org/packages/1f/55/eeb6603371e6629037f47bd25bef300387257ed53a3c5fdb159b7ac8c651/coverage-7.10.6-cp314-cp314-win32.whl", hash = "sha256:6937347c5d7d069ee776b2bf4e1212f912a9f1f141a429c475e6089462fcecc5", size = 220054, upload-time = "2025-08-29T15:34:34.124Z" }, + { url = "https://files.pythonhosted.org/packages/15/d1/a0912b7611bc35412e919a2cd59ae98e7ea3b475e562668040a43fb27897/coverage-7.10.6-cp314-cp314-win_amd64.whl", hash = "sha256:adec1d980fa07e60b6ef865f9e5410ba760e4e1d26f60f7e5772c73b9a5b0713", size = 220851, upload-time = "2025-08-29T15:34:35.651Z" }, + { url = "https://files.pythonhosted.org/packages/ef/2d/11880bb8ef80a45338e0b3e0725e4c2d73ffbb4822c29d987078224fd6a5/coverage-7.10.6-cp314-cp314-win_arm64.whl", hash = "sha256:a80f7aef9535442bdcf562e5a0d5a5538ce8abe6bb209cfbf170c462ac2c2a32", size = 219429, upload-time = "2025-08-29T15:34:37.16Z" }, + { url = "https://files.pythonhosted.org/packages/83/c0/1f00caad775c03a700146f55536ecd097a881ff08d310a58b353a1421be0/coverage-7.10.6-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:0de434f4fbbe5af4fa7989521c655c8c779afb61c53ab561b64dcee6149e4c65", size = 218080, upload-time = "2025-08-29T15:34:38.919Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c4/b1c5d2bd7cc412cbeb035e257fd06ed4e3e139ac871d16a07434e145d18d/coverage-7.10.6-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6e31b8155150c57e5ac43ccd289d079eb3f825187d7c66e755a055d2c85794c6", size = 218293, upload-time = "2025-08-29T15:34:40.425Z" }, + { url = "https://files.pythonhosted.org/packages/3f/07/4468d37c94724bf6ec354e4ec2f205fda194343e3e85fd2e59cec57e6a54/coverage-7.10.6-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:98cede73eb83c31e2118ae8d379c12e3e42736903a8afcca92a7218e1f2903b0", size = 259800, upload-time = "2025-08-29T15:34:41.996Z" }, + { url = "https://files.pythonhosted.org/packages/82/d8/f8fb351be5fee31690cd8da768fd62f1cfab33c31d9f7baba6cd8960f6b8/coverage-7.10.6-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f863c08f4ff6b64fa8045b1e3da480f5374779ef187f07b82e0538c68cb4ff8e", size = 261965, upload-time = "2025-08-29T15:34:43.61Z" }, + { url = "https://files.pythonhosted.org/packages/e8/70/65d4d7cfc75c5c6eb2fed3ee5cdf420fd8ae09c4808723a89a81d5b1b9c3/coverage-7.10.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b38261034fda87be356f2c3f42221fdb4171c3ce7658066ae449241485390d5", size = 264220, upload-time = "2025-08-29T15:34:45.387Z" }, + { url = "https://files.pythonhosted.org/packages/98/3c/069df106d19024324cde10e4ec379fe2fb978017d25e97ebee23002fbadf/coverage-7.10.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e93b1476b79eae849dc3872faeb0bf7948fd9ea34869590bc16a2a00b9c82a7", size = 261660, upload-time = "2025-08-29T15:34:47.288Z" }, + { url = "https://files.pythonhosted.org/packages/fc/8a/2974d53904080c5dc91af798b3a54a4ccb99a45595cc0dcec6eb9616a57d/coverage-7.10.6-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ff8a991f70f4c0cf53088abf1e3886edcc87d53004c7bb94e78650b4d3dac3b5", size = 259417, upload-time = "2025-08-29T15:34:48.779Z" }, + { url = "https://files.pythonhosted.org/packages/30/38/9616a6b49c686394b318974d7f6e08f38b8af2270ce7488e879888d1e5db/coverage-7.10.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ac765b026c9f33044419cbba1da913cfb82cca1b60598ac1c7a5ed6aac4621a0", size = 260567, upload-time = "2025-08-29T15:34:50.718Z" }, + { url = "https://files.pythonhosted.org/packages/76/16/3ed2d6312b371a8cf804abf4e14895b70e4c3491c6e53536d63fd0958a8d/coverage-7.10.6-cp314-cp314t-win32.whl", hash = "sha256:441c357d55f4936875636ef2cfb3bee36e466dcf50df9afbd398ce79dba1ebb7", size = 220831, upload-time = "2025-08-29T15:34:52.653Z" }, + { url = "https://files.pythonhosted.org/packages/d5/e5/d38d0cb830abede2adb8b147770d2a3d0e7fecc7228245b9b1ae6c24930a/coverage-7.10.6-cp314-cp314t-win_amd64.whl", hash = "sha256:073711de3181b2e204e4870ac83a7c4853115b42e9cd4d145f2231e12d670930", size = 221950, upload-time = "2025-08-29T15:34:54.212Z" }, + { url = "https://files.pythonhosted.org/packages/f4/51/e48e550f6279349895b0ffcd6d2a690e3131ba3a7f4eafccc141966d4dea/coverage-7.10.6-cp314-cp314t-win_arm64.whl", hash = "sha256:137921f2bac5559334ba66122b753db6dc5d1cf01eb7b64eb412bb0d064ef35b", size = 219969, upload-time = "2025-08-29T15:34:55.83Z" }, + { url = "https://files.pythonhosted.org/packages/44/0c/50db5379b615854b5cf89146f8f5bd1d5a9693d7f3a987e269693521c404/coverage-7.10.6-py3-none-any.whl", hash = "sha256:92c4ecf6bf11b2e85fd4d8204814dc26e6a19f0c9d938c207c5cb0eadfcabbe3", size = 208986, upload-time = "2025-08-29T15:35:14.506Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = "crewai" version = "0.134.0" @@ -2504,6 +2594,7 @@ smolagents = [ { name = "smolagents" }, ] test = [ + { name = "coverage", extra = ["toml"] }, { name = "crewai" }, { name = "google-genai" }, { name = "langchain" }, @@ -2511,7 +2602,9 @@ test = [ { name = "langgraph" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-cov" }, { name = "pytest-datadir" }, + { name = "pytest-mock" }, { name = "smolagents" }, { name = "soundfile" }, { name = "torchaudio" }, @@ -2530,6 +2623,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "coverage", extras = ["toml"], marker = "extra == 'test'", specifier = ">=7.0.0" }, { name = "crewai", marker = "extra == 'crewai'", specifier = ">=0.108.0" }, { name = "crewai", marker = "extra == 'test'", specifier = ">=0.108.0" }, { name = "google-genai", marker = "extra == 'google-genai'", specifier = ">=1.2.0" }, @@ -2546,7 +2640,9 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.10.6" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.4" }, { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.25.2" }, + { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" }, { name = "pytest-datadir", marker = "extra == 'test'", specifier = ">=1.7.2" }, + { name = "pytest-mock", marker = "extra == 'test'", specifier = ">=3.10.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "smolagents", marker = "extra == 'smolagents'", specifier = ">=1.2.2" }, { name = "smolagents", marker = "extra == 'test'", specifier = ">=1.2.2" }, @@ -4367,6 +4463,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" }, ] +[[package]] +name = "pytest-cov" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/99/668cade231f434aaa59bbfbf49469068d2ddd945000621d3d165d2e7dd7b/pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2", size = 69432, upload-time = "2025-06-12T10:47:47.684Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, +] + [[package]] name = "pytest-datadir" version = "1.7.2" @@ -4379,6 +4489,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d3/61/a06f3406852534e186413c75f544c90251db00fd8eb9625ee3ac239499f3/pytest_datadir-1.7.2-py3-none-any.whl", hash = "sha256:8392ba0e9eaf37030e663dcd91cc5123dec99c44300f0c5eac44f35f13f0e086", size = 6273, upload-time = "2025-06-06T11:24:16.388Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/99/3323ee5c16b3637b4d941c362182d3e749c11e400bea31018c42219f3a98/pytest_mock-3.15.0.tar.gz", hash = "sha256:ab896bd190316b9d5d87b277569dfcdf718b2d049a2ccff5f7aca279c002a1cf", size = 33838, upload-time = "2025-09-04T20:57:48.679Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/b3/7fefc43fb706380144bcd293cc6e446e6f637ddfa8b83f48d1734156b529/pytest_mock-3.15.0-py3-none-any.whl", hash = "sha256:ef2219485fb1bd256b00e7ad7466ce26729b30eadfc7cbcdb4fa9a92ca68db6f", size = 10050, upload-time = "2025-09-04T20:57:47.274Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" From 366a5a9653cd4bdfced70866fa2185e6a34b5a21 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sun, 7 Sep 2025 03:55:47 +0530 Subject: [PATCH 4/7] Remove sample from python-mcp --- impl_sample/mcp_sample.py | 363 -------------------------------------- 1 file changed, 363 deletions(-) delete mode 100644 impl_sample/mcp_sample.py diff --git a/impl_sample/mcp_sample.py b/impl_sample/mcp_sample.py deleted file mode 100644 index 7a9e322..0000000 --- a/impl_sample/mcp_sample.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple MCP client example with OAuth authentication support. - -This client connects to an MCP server using streamable HTTP transport with OAuth. - -""" - -import asyncio -import os -import threading -import time -import webbrowser -from datetime import timedelta -from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any -from urllib.parse import parse_qs, urlparse - -from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamablehttp_client -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken - - -class InMemoryTokenStorage(TokenStorage): - """Simple in-memory token storage implementation.""" - - def __init__(self): - self._tokens: OAuthToken | None = None - self._client_info: OAuthClientInformationFull | None = None - - async def get_tokens(self) -> OAuthToken | None: - return self._tokens - - async def set_tokens(self, tokens: OAuthToken) -> None: - self._tokens = tokens - - async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info - - async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - self._client_info = client_info - - -class CallbackHandler(BaseHTTPRequestHandler): - """Simple HTTP handler to capture OAuth callback.""" - - def __init__(self, request, client_address, server, callback_data): - """Initialize with callback data storage.""" - self.callback_data = callback_data - super().__init__(request, client_address, server) - - def do_GET(self): - """Handle GET request from OAuth redirect.""" - parsed = urlparse(self.path) - query_params = parse_qs(parsed.query) - - if "code" in query_params: - self.callback_data["authorization_code"] = query_params["code"][0] - self.callback_data["state"] = query_params.get("state", [None])[0] - self.send_response(200) - self.send_header("Content-type", "text/html") - self.end_headers() - self.wfile.write(b""" - - -

Authorization Successful!

-

You can close this window and return to the terminal.

- - - - """) - elif "error" in query_params: - self.callback_data["error"] = query_params["error"][0] - self.send_response(400) - self.send_header("Content-type", "text/html") - self.end_headers() - self.wfile.write( - f""" - - -

Authorization Failed

-

Error: {query_params["error"][0]}

-

You can close this window and return to the terminal.

- - - """.encode() - ) - else: - self.send_response(404) - self.end_headers() - - def log_message(self, format, *args): - """Suppress default logging.""" - pass - - -class CallbackServer: - """Simple server to handle OAuth callbacks.""" - - def __init__(self, port=3000): - self.port = port - self.server = None - self.thread = None - self.callback_data = {"authorization_code": None, "state": None, "error": None} - - def _create_handler_with_data(self): - """Create a handler class with access to callback data.""" - callback_data = self.callback_data - - class DataCallbackHandler(CallbackHandler): - def __init__(self, request, client_address, server): - super().__init__(request, client_address, server, callback_data) - - return DataCallbackHandler - - def start(self): - """Start the callback server in a background thread.""" - handler_class = self._create_handler_with_data() - self.server = HTTPServer(("localhost", self.port), handler_class) - self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) - self.thread.start() - print(f"🖥️ Started callback server on http://localhost:{self.port}") - - def stop(self): - """Stop the callback server.""" - if self.server: - self.server.shutdown() - self.server.server_close() - if self.thread: - self.thread.join(timeout=1) - - def wait_for_callback(self, timeout=300): - """Wait for OAuth callback with timeout.""" - start_time = time.time() - while time.time() - start_time < timeout: - if self.callback_data["authorization_code"]: - return self.callback_data["authorization_code"] - elif self.callback_data["error"]: - raise Exception(f"OAuth error: {self.callback_data['error']}") - time.sleep(0.1) - raise Exception("Timeout waiting for OAuth callback") - - def get_state(self): - """Get the received state parameter.""" - return self.callback_data["state"] - - -class SimpleAuthClient: - """Simple MCP client with auth support.""" - - def __init__(self, server_url: str, transport_type: str = "streamable_http"): - self.server_url = server_url - self.transport_type = transport_type - self.session: ClientSession | None = None - - async def connect(self): - """Connect to the MCP server.""" - print(f"🔗 Attempting to connect to {self.server_url}...") - - try: - callback_server = CallbackServer(port=3030) - callback_server.start() - - async def callback_handler() -> tuple[str, str | None]: - """Wait for OAuth callback and return auth code and state.""" - print("⏳ Waiting for authorization callback...") - try: - auth_code = callback_server.wait_for_callback(timeout=300) - return auth_code, callback_server.get_state() - finally: - callback_server.stop() - - client_metadata_dict = { - "client_name": "Simple Auth Client", - "redirect_uris": ["http://localhost:3030/callback"], - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "token_endpoint_auth_method": "client_secret_post", - } - - async def _default_redirect_handler(authorization_url: str) -> None: - """Default redirect handler that opens the URL in a browser.""" - print(f"Opening browser for authorization: {authorization_url}") - webbrowser.open(authorization_url) - - # Create OAuth authentication handler using the new interface - oauth_auth = OAuthClientProvider( - server_url=self.server_url.replace("/mcp", ""), - client_metadata=OAuthClientMetadata.model_validate( - client_metadata_dict - ), - storage=InMemoryTokenStorage(), - redirect_handler=_default_redirect_handler, - callback_handler=callback_handler, - ) - - # Create transport with auth handler based on transport type - if self.transport_type == "sse": - print("📡 Opening SSE transport connection with auth...") - async with sse_client( - url=self.server_url, - auth=oauth_auth, - timeout=60, - ) as (read_stream, write_stream): - await self._run_session(read_stream, write_stream, None) - else: - print("📡 Opening StreamableHTTP transport connection with auth...") - async with streamablehttp_client( - url=self.server_url, - auth=oauth_auth, - timeout=timedelta(seconds=60), - ) as (read_stream, write_stream, get_session_id): - await self._run_session(read_stream, write_stream, get_session_id) - - except Exception as e: - print(f"❌ Failed to connect: {e}") - import traceback - - traceback.print_exc() - - async def _run_session(self, read_stream, write_stream, get_session_id): - """Run the MCP session with the given streams.""" - print("🤝 Initializing MCP session...") - async with ClientSession(read_stream, write_stream) as session: - self.session = session - print("⚡ Starting session initialization...") - await session.initialize() - print("✨ Session initialization complete!") - - print(f"\n✅ Connected to MCP server at {self.server_url}") - if get_session_id: - session_id = get_session_id() - if session_id: - print(f"Session ID: {session_id}") - - # Run interactive loop - await self.interactive_loop() - - async def list_tools(self): - """List available tools from the server.""" - if not self.session: - print("❌ Not connected to server") - return - - try: - result = await self.session.list_tools() - if hasattr(result, "tools") and result.tools: - print("\n📋 Available tools:") - for i, tool in enumerate(result.tools, 1): - print(f"{i}. {tool.name}") - if tool.description: - print(f" Description: {tool.description}") - print() - else: - print("No tools available") - except Exception as e: - print(f"❌ Failed to list tools: {e}") - - async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): - """Call a specific tool.""" - if not self.session: - print("❌ Not connected to server") - return - - try: - result = await self.session.call_tool(tool_name, arguments or {}) - print(f"\n🔧 Tool '{tool_name}' result:") - if hasattr(result, "content"): - for content in result.content: - if content.type == "text": - print(content.text) - else: - print(content) - else: - print(result) - except Exception as e: - print(f"❌ Failed to call tool '{tool_name}': {e}") - - async def interactive_loop(self): - """Run interactive command loop.""" - print("\n🎯 Interactive MCP Client") - print("Commands:") - print(" list - List available tools") - print(" call [args] - Call a tool") - print(" quit - Exit the client") - print() - - while True: - try: - command = input("mcp> ").strip() - - if not command: - continue - - if command == "quit": - break - - elif command == "list": - await self.list_tools() - - elif command.startswith("call "): - parts = command.split(maxsplit=2) - tool_name = parts[1] if len(parts) > 1 else "" - - if not tool_name: - print("❌ Please specify a tool name") - continue - - # Parse arguments (simple JSON-like format) - arguments = {} - if len(parts) > 2: - import json - - try: - arguments = json.loads(parts[2]) - except json.JSONDecodeError: - print("❌ Invalid arguments format (expected JSON)") - continue - - await self.call_tool(tool_name, arguments) - - else: - print( - "❌ Unknown command. Try 'list', 'call ', or 'quit'" - ) - - except KeyboardInterrupt: - print("\n\n👋 Goodbye!") - break - except EOFError: - break - - -async def main(): - """Main entry point.""" - # Default server URL - can be overridden with environment variable - # Most MCP streamable HTTP servers use /mcp as the endpoint - server_url = os.getenv("MCP_SERVER_PORT", 8000) - transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http") - server_url = ( - f"http://localhost:{server_url}/mcp" - if transport_type == "streamable_http" - else f"http://localhost:{server_url}/sse" - ) - - print("🚀 Simple MCP Auth Client") - print(f"Connecting to: {server_url}") - print(f"Transport type: {transport_type}") - - # Start connection flow - OAuth will be handled automatically - client = SimpleAuthClient(server_url, transport_type) - await client.connect() - - -def cli(): - """CLI entry point for uv script.""" - asyncio.run(main()) - - -if __name__ == "__main__": - cli() From 3426d820b67d69521679a6518bebb52976529816 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sun, 7 Sep 2025 14:52:11 +0530 Subject: [PATCH 5/7] nicer interface for OAuth that avoids unnecessary repetition of parameters --- .gitignore | 3 + docs/auth/api-key.md | 43 ------ docs/auth/custom-handlers.md | 60 ++++++--- docs/auth/oauth.md | 149 ++++++++++----------- docs/auth/overview.md | 12 +- docs/auth/quickstart.md | 29 ++-- examples/canva_oauth_example.py | 21 ++- examples/oauth_with_credentials_example.py | 15 ++- src/mcpadapt/auth/__init__.py | 10 +- src/mcpadapt/auth/handlers.py | 52 ++++++- src/mcpadapt/auth/oauth.py | 19 +++ src/mcpadapt/auth/providers.py | 26 ++++ tests/auth/test_handlers.py | 75 ++++++----- tests/auth/test_providers.py | 120 ++++++++++++++++- 14 files changed, 415 insertions(+), 219 deletions(-) diff --git a/.gitignore b/.gitignore index 1b449d4..57ac625 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,6 @@ notebooks/ # Claude files **/.claude/settings.local.json + +.coverage +coverage.xml diff --git a/docs/auth/api-key.md b/docs/auth/api-key.md index ea48e19..81a0027 100644 --- a/docs/auth/api-key.md +++ b/docs/auth/api-key.md @@ -96,14 +96,6 @@ with MCPAdapt( ```python ApiKeyAuthProvider("X-API-Key", "abc123def456") ``` - -### Prefixed API Key - -```python -ApiKeyAuthProvider("X-API-Key", "Bearer abc123def456") -ApiKeyAuthProvider("Authorization", "API-Key abc123def456") -``` - ### Base64 Encoded Credentials ```python @@ -112,38 +104,3 @@ import base64 credentials = base64.b64encode(b"username:password").decode() ApiKeyAuthProvider("Authorization", f"Basic {credentials}") ``` - -## Best Practices - -### Security -- Never hard-code API keys in source code -- Use environment variables or secure configuration files -- Rotate API keys regularly -- Use the principle of least privilege for API key permissions - -### Configuration -- Use descriptive environment variable names -- Document required API keys in your README -- Provide clear error messages for missing keys -- Validate API key format before using - -## Integration Examples - -### With Different Frameworks - -```python -# SmolAgents -from mcpadapt.smolagents_adapter import SmolAgentsAdapter -adapter = SmolAgentsAdapter() - -# CrewAI -from mcpadapt.crewai_adapter import CrewAIAdapter -adapter = CrewAIAdapter() - -# LangChain -from mcpadapt.langchain_adapter import LangChainAdapter -adapter = LangChainAdapter() - -# All use the same API key provider -api_key_provider = ApiKeyAuthProvider("X-API-Key", os.getenv("API_KEY")) -``` diff --git a/docs/auth/custom-handlers.md b/docs/auth/custom-handlers.md index 5b64892..b27da4e 100644 --- a/docs/auth/custom-handlers.md +++ b/docs/auth/custom-handlers.md @@ -1,15 +1,21 @@ # Creating Custom Handlers -Custom OAuth handlers allow you to implement specialized authentication flows for different environments and use cases. +Custom OAuth handlers allow you to implement specialized authentication flows that are specific to how your application handles OAuth Redirects. ## BaseOAuthHandler Interface -All custom OAuth handlers must extend the `BaseOAuthHandler` abstract class: +All custom OAuth handlers must extend the `BaseOAuthHandler` abstract class and implement the required methods: ```python -from mcpadapt.auth import BaseOAuthHandler +from mcpadapt.auth import BaseOAuthHandler, OAuthClientMetadata +from typing import List +from pydantic import AnyUrl class CustomOAuthHandler(BaseOAuthHandler): + def get_redirect_uris(self) -> List[AnyUrl]: + """Return redirect URIs that this handler can process.""" + return [AnyUrl("http://localhost:8080/callback")] + async def handle_redirect(self, authorization_url: str) -> None: """Handle OAuth redirect to authorization URL.""" # Your custom redirect logic here @@ -26,11 +32,17 @@ class CustomOAuthHandler(BaseOAuthHandler): For server environments without a browser: ```python -from mcpadapt.auth import BaseOAuthHandler +from mcpadapt.auth import BaseOAuthHandler, OAuthClientMetadata +from typing import List +from pydantic import AnyUrl class HeadlessOAuthHandler(BaseOAuthHandler): """OAuth handler for headless environments.""" + def get_redirect_uris(self) -> List[AnyUrl]: + """Return out-of-band redirect for headless environments.""" + return [AnyUrl("urn:ietf:wg:oauth:2.0:oob")] + async def handle_redirect(self, authorization_url: str) -> None: print(f"Please open this URL in your browser:") print(f"{authorization_url}") @@ -47,17 +59,24 @@ class HeadlessOAuthHandler(BaseOAuthHandler): For applications with existing web servers: ```python -from mcpadapt.auth import BaseOAuthHandler +from mcpadapt.auth import BaseOAuthHandler, OAuthClientMetadata +from typing import List +from pydantic import AnyUrl import asyncio class CustomCallbackHandler(BaseOAuthHandler): """OAuth handler that integrates with existing web application.""" - def __init__(self, callback_url: str): + def __init__(self, client_metadata: OAuthClientMetadata, callback_url: str): + super().__init__(client_metadata) self.callback_url = callback_url self.callback_data = {} self.callback_received = asyncio.Event() + def get_redirect_uris(self) -> List[AnyUrl]: + """Return custom callback URL.""" + return [AnyUrl(self.callback_url)] + async def handle_redirect(self, authorization_url: str) -> None: # In a real app, you might redirect the user's current request print(f"Redirecting to: {authorization_url}") @@ -86,16 +105,24 @@ class CustomCallbackHandler(BaseOAuthHandler): For command-line applications: ```python -from mcpadapt.auth import BaseOAuthHandler +from mcpadapt.auth import BaseOAuthHandler, OAuthClientMetadata +from typing import List +from pydantic import AnyUrl import webbrowser import urllib.parse class CLIHandler(BaseOAuthHandler): """OAuth handler optimized for CLI applications.""" - def __init__(self, auto_open_browser: bool = True): + def __init__(self, client_metadata: OAuthClientMetadata, callback_port: int = 3030, auto_open_browser: bool = True): + super().__init__(client_metadata) + self.callback_port = callback_port self.auto_open_browser = auto_open_browser + def get_redirect_uris(self) -> List[AnyUrl]: + """Return localhost callback URL.""" + return [AnyUrl(f"http://localhost:{self.callback_port}/callback")] + async def handle_redirect(self, authorization_url: str) -> None: if self.auto_open_browser: try: @@ -128,28 +155,25 @@ class CLIHandler(BaseOAuthHandler): ## Using Custom Handlers ```python -from mcpadapt.auth import OAuthClientProvider, OAuthClientMetadata, InMemoryTokenStorage +from mcpadapt.auth import OAuthProvider, OAuthClientMetadata, InMemoryTokenStorage from mcpadapt.core import MCPAdapt from mcpadapt.smolagents_adapter import SmolAgentsAdapter -from pydantic import HttpUrl - -# Use your custom handler -custom_handler = HeadlessOAuthHandler() +# Create client metadata client_metadata = OAuthClientMetadata( client_name="My Application", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) -oauth_provider = OAuthClientProvider( +# Use your custom handler with the new interface +custom_handler = HeadlessOAuthHandler(client_metadata) + +oauth_provider = OAuthProvider( server_url="https://oauth-server.com", - client_metadata=client_metadata, + oauth_handler=custom_handler, storage=InMemoryTokenStorage(), - redirect_handler=custom_handler.handle_redirect, - callback_handler=custom_handler.handle_callback, ) with MCPAdapt( diff --git a/docs/auth/oauth.md b/docs/auth/oauth.md index cfaf698..02399e8 100644 --- a/docs/auth/oauth.md +++ b/docs/auth/oauth.md @@ -15,10 +15,8 @@ The built in provider helps you perform the following sequence: ## Basic OAuth Setup ```python -from pydantic import HttpUrl - from mcpadapt.auth import ( - OAuthClientProvider, + OAuthProvider, OAuthClientMetadata, InMemoryTokenStorage, LocalBrowserOAuthHandler, @@ -26,26 +24,29 @@ from mcpadapt.auth import ( from mcpadapt.core import MCPAdapt from mcpadapt.smolagents_adapter import SmolAgentsAdapter -# Configure client metadata +# Configure client metadata (no need to specify redirect_uris - handled automatically) client_metadata = OAuthClientMetadata( client_name="My Application", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) -# Set up OAuth components -oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +# Create OAuth handler with metadata +oauth_handler = LocalBrowserOAuthHandler( + client_metadata=client_metadata, + callback_port=3030, + timeout=300 +) + +# Set up token storage token_storage = InMemoryTokenStorage() -# Create OAuth provider -oauth_provider = OAuthClientProvider( +# Create simplified OAuth provider +oauth_provider = OAuthProvider( server_url="https://oauth-server.com", - client_metadata=client_metadata, + oauth_handler=oauth_handler, storage=token_storage, - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, ) # Use with MCPAdapt @@ -62,7 +63,7 @@ with MCPAdapt( ### OAuthClientMetadata -Configure your application's OAuth settings: +Configure your application's OAuth settings ```python from pydantic import HttpUrl @@ -70,7 +71,6 @@ from mcpadapt.auth import OAuthClientMetadata client_metadata = OAuthClientMetadata( client_name="Your App Name", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", @@ -83,16 +83,25 @@ client_metadata = OAuthClientMetadata( ### LocalBrowserOAuthHandler -Handles the OAuth flow using the user's browser: +Handles the OAuth flow using the user's browser (now requires client metadata): ```python -from mcpadapt.auth import LocalBrowserOAuthHandler +from mcpadapt.auth import LocalBrowserOAuthHandler, OAuthClientMetadata + +# Create client metadata first +client_metadata = OAuthClientMetadata( + client_name="Your App Name", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) # Default configuration -oauth_handler = LocalBrowserOAuthHandler() +oauth_handler = LocalBrowserOAuthHandler(client_metadata) # Custom configuration oauth_handler = LocalBrowserOAuthHandler( + client_metadata, callback_port=8080, # Custom port timeout=600, # 10 minute timeout ) @@ -120,7 +129,7 @@ from pydantic import HttpUrl from mcp.shared.auth import OAuthClientInformationFull from mcpadapt.auth import ( - OAuthClientProvider, + OAuthProvider, OAuthClientMetadata, InMemoryTokenStorage, LocalBrowserOAuthHandler, @@ -144,22 +153,23 @@ token_storage = InMemoryTokenStorage(client_info=client_info) # Configure client metadata (still needed for OAuth flow) client_metadata = OAuthClientMetadata( client_name="My Pre-configured App", - redirect_uris=[HttpUrl(REDIRECT_URI)], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) -# Set up OAuth handler -oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +# Create OAuth handler with metadata +oauth_handler = LocalBrowserOAuthHandler( + client_metadata=client_metadata, + callback_port=3030, + timeout=300 +) -# Create OAuth provider -oauth_provider = OAuthClientProvider( +# Create simplified OAuth provider +oauth_provider = OAuthProvider( server_url="https://oauth-server.com", - client_metadata=client_metadata, + oauth_handler=oauth_handler, storage=token_storage, # Contains pre-configured credentials - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, ) # Use with MCPAdapt - DCR will be skipped @@ -179,36 +189,6 @@ Use pre-configured credentials when: - **Existing OAuth app**: You already have a registered OAuth application with client credentials - **Compliance requirements**: Your organization requires using specific pre-registered applications -### Environment Variables for Credentials - -Store your OAuth credentials securely using environment variables: - -```bash -# Set environment variables -export OAUTH_CLIENT_ID="your-actual-client-id" -export OAUTH_CLIENT_SECRET="your-actual-client-secret" -export OAUTH_REDIRECT_URI="http://localhost:3030/callback" -``` - -Then reference them in your code: - -```python -import os -from pydantic import HttpUrl -from mcp.shared.auth import OAuthClientInformationFull -from mcpadapt.auth import InMemoryTokenStorage - -# Load from environment -client_info = OAuthClientInformationFull( - client_id=os.getenv("OAUTH_CLIENT_ID"), - client_secret=os.getenv("OAUTH_CLIENT_SECRET"), - redirect_uris=[HttpUrl(os.getenv("OAUTH_REDIRECT_URI", "http://localhost:3030/callback"))] -) - -# Create storage with pre-configured credentials -token_storage = InMemoryTokenStorage(client_info=client_info) -``` - ### Complete Example See `examples/oauth_with_credentials_example.py` for a complete working example of using pre-configured OAuth credentials. @@ -218,26 +198,46 @@ See `examples/oauth_with_credentials_example.py` for a complete working example Create custom OAuth handlers for production environments or when you are integrating into a larger app: ```python -from mcpadapt.auth import BaseOAuthHandler +from mcpadapt.auth import BaseOAuthHandler, OAuthProvider, OAuthClientMetadata +from typing import List +from pydantic import HttpUrl -class HeadlessOAuthHandler(BaseOAuthHandler): - """OAuth handler for headless environments.""" +class CustomOAuthHandler(BaseOAuthHandler): + """Custom OAuth handler with different callback port.""" + + def __init__(self, client_metadata: OAuthClientMetadata, callback_port: int = 8080): + super().__init__(client_metadata) + self.callback_port = callback_port + + def get_redirect_uris(self) -> List[HttpUrl]: + """Return redirect URIs for this handler.""" + return [HttpUrl(f"http://localhost:{self.callback_port}/oauth/callback")] async def handle_redirect(self, authorization_url: str) -> None: - print(f"Open this URL in your browser: {authorization_url}") + print(f"Please open this URL in your browser: {authorization_url}") + # Custom logging or integration logic here async def handle_callback(self) -> tuple[str, str | None]: - auth_code = input("Enter the authorization code: ") + # Custom callback handling logic + print("Waiting for OAuth callback...") + # In a real implementation, you'd set up your own server or integration + auth_code = input("Enter the authorization code from the callback: ") return auth_code, None -# Use custom handler -custom_handler = HeadlessOAuthHandler() -oauth_provider = OAuthClientProvider( +# Create client metadata +client_metadata = OAuthClientMetadata( + client_name="Custom OAuth App", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", +) + +# Use custom handler with the new interface +custom_handler = CustomOAuthHandler(client_metadata, callback_port=8080) +oauth_provider = OAuthProvider( server_url="https://oauth-server.com", - client_metadata=client_metadata, + oauth_handler=custom_handler, storage=token_storage, - redirect_handler=custom_handler.handle_redirect, - callback_handler=custom_handler.handle_callback, ) ``` @@ -344,17 +344,4 @@ except OAuthError as e: print(f"General OAuth error: {e}") if e.context: print(f"Context: {e.context}") -``` - -## Configuration Tips - -### Port Selection -- Default port is 3030 -- Ensure the port is available and not blocked by firewalls -- Use different ports for different applications - -## Security Considerations - -- Store tokens securely in production environments, you may want to use something like Vault -- Use HTTPS for all OAuth flows in production -- Use appropriate OAuth scopes (minimal permissions) +``` \ No newline at end of file diff --git a/docs/auth/overview.md b/docs/auth/overview.md index 3a68bd1..c150929 100644 --- a/docs/auth/overview.md +++ b/docs/auth/overview.md @@ -16,7 +16,7 @@ Standard Bearer token authentication for JWT-based systems and modern API patter ## Core Components **Authentication Providers:** -- `OAuthClientProvider` - OAuth 2.0 authentication +- `OAuthProvider` - OAuth 2.0 authentication - `ApiKeyAuthProvider` - API key authentication - `BearerAuthProvider` - Bearer token authentication @@ -49,13 +49,3 @@ with MCPAdapt( # Authentication handled automatically result = tools[0]({"param": "value"}) ``` - -## Next Steps - -- [Quick Start Guide](quickstart.md) - Get started immediately -- [OAuth 2.0 Guide](oauth.md) - Complete OAuth implementation -- [API Key Guide](api-key.md) - Key-based authentication -- [Bearer Token Guide](bearer-token.md) - Token authentication -- [Custom Handlers](custom-handlers.md) - Custom authentication flows -- [Error Handling](error-handling.md) - Handle authentication errors -- [Examples](examples.md) - Real-world examples diff --git a/docs/auth/quickstart.md b/docs/auth/quickstart.md index 0840196..254ead7 100644 --- a/docs/auth/quickstart.md +++ b/docs/auth/quickstart.md @@ -5,32 +5,37 @@ Get authentication working quickly with these minimal examples. ## OAuth 2.0 with a provider like Canva ```python -from mcp.client.auth import OAuthClientProvider -from mcp.shared.auth import OAuthClientMetadata -from pydantic import HttpUrl - -from mcpadapt.auth import InMemoryTokenStorage, LocalBrowserOAuthHandler +from mcpadapt.auth import ( + OAuthProvider, + OAuthClientMetadata, + InMemoryTokenStorage, + LocalBrowserOAuthHandler, +) from mcpadapt.core import MCPAdapt from mcpadapt.smolagents_adapter import SmolAgentsAdapter -# Configure OAuth +# Configure OAuth (no need to specify redirect_uris - handled automatically) client_metadata = OAuthClientMetadata( client_name="My App", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) -oauth_handler = LocalBrowserOAuthHandler(callback_port=3030) +# Create OAuth handler with metadata +oauth_handler = LocalBrowserOAuthHandler( + client_metadata=client_metadata, + callback_port=3030, + timeout=300 +) + token_storage = InMemoryTokenStorage() -oauth_provider = OAuthClientProvider( +# Create simplified OAuth provider +oauth_provider = OAuthProvider( server_url="https://mcp.canva.com", - client_metadata=client_metadata, + oauth_handler=oauth_handler, storage=token_storage, - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, ) # Use with MCPAdapt diff --git a/examples/canva_oauth_example.py b/examples/canva_oauth_example.py index c0546be..b47af34 100644 --- a/examples/canva_oauth_example.py +++ b/examples/canva_oauth_example.py @@ -7,8 +7,6 @@ It fully complies with OAuth 2.0 Dynamic Client Registration """ -from pydantic import HttpUrl - from mcpadapt.auth import ( InMemoryTokenStorage, LocalBrowserOAuthHandler, @@ -16,7 +14,7 @@ OAuthCancellationError, OAuthNetworkError, OAuthConfigurationError, - OAuthClientProvider, + OAuthProvider, OAuthClientMetadata, ) from mcpadapt.core import MCPAdapt @@ -31,23 +29,24 @@ def main(): # Create OAuth client metadata client_metadata = OAuthClientMetadata( client_name="MCPAdapt Canva Example", - redirect_uris=[HttpUrl("http://localhost:3030/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) - # Create OAuth handler and token storage - oauth_handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + # Create OAuth handler and token storage + oauth_handler = LocalBrowserOAuthHandler( + client_metadata=client_metadata, + callback_port=3030, + timeout=300 + ) token_storage = InMemoryTokenStorage() - # Create OAuth provider directly - oauth_provider = OAuthClientProvider( + # Create OAuth provider + oauth_provider = OAuthProvider( server_url="https://mcp.canva.com", - client_metadata=client_metadata, + oauth_handler=oauth_handler, storage=token_storage, - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, ) # Server configuration for Canva MCP diff --git a/examples/oauth_with_credentials_example.py b/examples/oauth_with_credentials_example.py index 1b50b8b..15639c4 100644 --- a/examples/oauth_with_credentials_example.py +++ b/examples/oauth_with_credentials_example.py @@ -17,7 +17,7 @@ OAuthCancellationError, OAuthNetworkError, OAuthConfigurationError, - OAuthClientProvider, + OAuthProvider, OAuthClientMetadata, ) from mcpadapt.core import MCPAdapt @@ -49,26 +49,27 @@ def main(): # Create OAuth client metadata (still needed for the OAuth flow) client_metadata = OAuthClientMetadata( client_name="MCPAdapt Pre-configured OAuth Example", - redirect_uris=[HttpUrl(REDIRECT_URI)], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) # Create OAuth handler - oauth_handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + oauth_handler = LocalBrowserOAuthHandler( + client_metadata=client_metadata, + callback_port=3030, + timeout=300 + ) # Create token storage WITH pre-configured client information # This is the key difference - we pass the client_info object token_storage = InMemoryTokenStorage(client_info=client_info) # Create OAuth provider - oauth_provider = OAuthClientProvider( + oauth_provider = OAuthProvider( server_url="https://api.example.com", - client_metadata=client_metadata, + oauth_handler=oauth_handler, storage=token_storage, # Storage contains pre-configured credentials - redirect_handler=oauth_handler.handle_redirect, - callback_handler=oauth_handler.handle_callback, ) # Server configuration diff --git a/src/mcpadapt/auth/__init__.py b/src/mcpadapt/auth/__init__.py index d1f783d..dc410e3 100644 --- a/src/mcpadapt/auth/__init__.py +++ b/src/mcpadapt/auth/__init__.py @@ -4,12 +4,12 @@ for MCP servers. """ -from .oauth import InMemoryTokenStorage +from .oauth import InMemoryTokenStorage, OAuthClientMetadata from mcp.shared.auth import ( OAuthClientInformationFull, OAuthToken, InvalidScopeError, - OAuthClientMetadata, + OAuthClientMetadata as MCPOAuthClientMetadata, InvalidRedirectUriError, OAuthMetadata, ProtectedResourceMetadata, @@ -23,6 +23,7 @@ from .providers import ( ApiKeyAuthProvider, BearerAuthProvider, + OAuthProvider, get_auth_headers, ) from .exceptions import ( @@ -43,6 +44,7 @@ # Provider classes "ApiKeyAuthProvider", "BearerAuthProvider", + "OAuthProvider", # Default implementations "InMemoryTokenStorage", # Provider functions @@ -62,8 +64,10 @@ "OAuthClientInformationFull", "OAuthToken", "InvalidScopeError", - "OAuthClientMetadata", "InvalidRedirectUriError", "OAuthMetadata", "ProtectedResourceMetadata", + # Our OAuth classes + "OAuthClientMetadata", + "MCPOAuthClientMetadata", ] diff --git a/src/mcpadapt/auth/handlers.py b/src/mcpadapt/auth/handlers.py index 1efd7e8..159e4e1 100644 --- a/src/mcpadapt/auth/handlers.py +++ b/src/mcpadapt/auth/handlers.py @@ -6,7 +6,11 @@ from abc import ABC, abstractmethod from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import parse_qs, urlparse +from typing import List, Tuple, Optional +from pydantic import AnyUrl +from mcp.shared.auth import OAuthClientMetadata as MCPOAuthClientMetadata +from .oauth import OAuthClientMetadata from .exceptions import ( OAuthCallbackError, OAuthCancellationError, @@ -178,6 +182,42 @@ class BaseOAuthHandler(ABC): Subclasses should implement both the redirect flow (opening authorization URL) and callback flow (receiving authorization code). """ + + def __init__(self, client_metadata: OAuthClientMetadata): + """Initialize handler with OAuth client metadata. + + Args: + client_metadata: OAuth client metadata configuration + """ + self.client_metadata = client_metadata + + @abstractmethod + def get_redirect_uris(self) -> List[AnyUrl]: + """Get redirect URIs for this handler. + + Returns: + List of redirect URIs that this handler can process + """ + pass + + def get_client_metadata(self) -> MCPOAuthClientMetadata: + """Get complete OAuth client metadata with redirect URIs populated. + + Returns: + Complete OAuth client metadata for MCP usage + """ + return MCPOAuthClientMetadata( + client_name=self.client_metadata.client_name, + redirect_uris=self.get_redirect_uris(), + grant_types=self.client_metadata.grant_types, + response_types=self.client_metadata.response_types, + token_endpoint_auth_method=self.client_metadata.token_endpoint_auth_method, + scope=self.client_metadata.scope, + client_uri=self.client_metadata.client_uri, + logo_uri=self.client_metadata.logo_uri, + tos_uri=self.client_metadata.tos_uri, + policy_uri=self.client_metadata.policy_uri, + ) @abstractmethod async def handle_redirect(self, authorization_url: str) -> None: @@ -209,16 +249,26 @@ class LocalBrowserOAuthHandler(BaseOAuthHandler): approach for desktop applications. """ - def __init__(self, callback_port: int = 3030, timeout: int = 300): + def __init__(self, client_metadata: OAuthClientMetadata, callback_port: int = 3030, timeout: int = 300): """Initialize the local browser OAuth handler. Args: + client_metadata: OAuth client metadata configuration callback_port: Port to run the local callback server on timeout: Maximum time to wait for OAuth callback in seconds """ + super().__init__(client_metadata) self.callback_port = callback_port self.timeout = timeout self.callback_server: LocalCallbackServer | None = None + + def get_redirect_uris(self) -> List[AnyUrl]: + """Get redirect URIs for this handler. + + Returns: + List of redirect URIs based on callback port + """ + return [AnyUrl(f"http://localhost:{self.callback_port}/callback")] async def handle_redirect(self, authorization_url: str) -> None: """Open authorization URL in the user's default browser. diff --git a/src/mcpadapt/auth/oauth.py b/src/mcpadapt/auth/oauth.py index 7a2b20a..bbb8411 100644 --- a/src/mcpadapt/auth/oauth.py +++ b/src/mcpadapt/auth/oauth.py @@ -1,9 +1,28 @@ """OAuth token storage and utility implementations.""" +from typing import List, Optional, Literal +from pydantic import BaseModel, AnyUrl, AnyHttpUrl from mcp.client.auth import TokenStorage from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +class OAuthClientMetadata(BaseModel): + """OAuth client metadata without required redirect_uris. + + This is our custom version that allows handlers to manage redirect URIs internally. + """ + client_name: str + grant_types: List[Literal["authorization_code", "refresh_token"]] = ["authorization_code", "refresh_token"] + response_types: List[Literal["code"]] = ["code"] + token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" + redirect_uris: Optional[List[AnyUrl]] = None + scope: Optional[str] = None + client_uri: Optional[AnyHttpUrl] = None + logo_uri: Optional[AnyHttpUrl] = None + tos_uri: Optional[AnyHttpUrl] = None + policy_uri: Optional[AnyHttpUrl] = None + + class InMemoryTokenStorage(TokenStorage): """Simple in-memory token storage implementation.""" diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py index 5c82a90..8e8c4bc 100644 --- a/src/mcpadapt/auth/providers.py +++ b/src/mcpadapt/auth/providers.py @@ -1,6 +1,8 @@ """Authentication provider classes for MCPAdapt.""" from typing import Any +from mcp.client.auth import OAuthClientProvider, TokenStorage +from .handlers import BaseOAuthHandler class ApiKeyAuthProvider: @@ -45,6 +47,30 @@ def get_headers(self) -> dict[str, str]: return {"Authorization": f"Bearer {self.token}"} +class OAuthProvider(OAuthClientProvider): + """OAuth provider that accepts a handler directly. + + This class simplifies OAuth configuration by taking an OAuthHandler + and internally extracting the client metadata and callback handlers. + """ + + def __init__(self, server_url: str, oauth_handler: BaseOAuthHandler, storage: TokenStorage): + """Initialize OAuth provider with handler. + + Args: + server_url: MCP server URL + oauth_handler: OAuth handler containing all configuration + storage: Token storage implementation + """ + super().__init__( + server_url=server_url, + client_metadata=oauth_handler.get_client_metadata(), + storage=storage, + redirect_handler=oauth_handler.handle_redirect, + callback_handler=oauth_handler.handle_callback, + ) + + def get_auth_headers(auth_provider: Any) -> dict[str, str]: """Get authentication headers from provider. diff --git a/tests/auth/test_handlers.py b/tests/auth/test_handlers.py index 6345096..c45ed4f 100644 --- a/tests/auth/test_handlers.py +++ b/tests/auth/test_handlers.py @@ -7,6 +7,7 @@ LocalCallbackServer, LocalBrowserOAuthHandler, ) +from mcpadapt.auth.oauth import OAuthClientMetadata from mcpadapt.auth.exceptions import ( OAuthTimeoutError, OAuthCancellationError, @@ -14,6 +15,18 @@ OAuthCallbackError, OAuthServerError, ) +from pydantic import AnyUrl + + +@pytest.fixture +def sample_client_metadata(): + """Create sample OAuth client metadata for tests.""" + return OAuthClientMetadata( + client_name="Test OAuth Client", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) class TestCallbackHandler: @@ -321,25 +334,25 @@ def test_get_state(self): class TestLocalBrowserOAuthHandler: """Test local browser OAuth handler.""" - def test_initialization(self): + def test_initialization(self, sample_client_metadata): """Test handler initialization.""" - handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=300) assert handler.callback_port == 3030 assert handler.timeout == 300 assert handler.callback_server is None - def test_initialization_defaults(self): + def test_initialization_defaults(self, sample_client_metadata): """Test handler initialization with defaults.""" - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) assert handler.callback_port == 3030 assert handler.timeout == 300 @pytest.mark.asyncio - async def test_handle_redirect_success(self, mock_webbrowser): + async def test_handle_redirect_success(self, mock_webbrowser, sample_client_metadata): """Test successful redirect handling.""" mock_webbrowser.open.return_value = True - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) # Should not raise await handler.handle_redirect("https://example.com/oauth/authorize") @@ -349,11 +362,11 @@ async def test_handle_redirect_success(self, mock_webbrowser): ) @pytest.mark.asyncio - async def test_handle_redirect_browser_fail(self, mock_webbrowser): + async def test_handle_redirect_browser_fail(self, mock_webbrowser, sample_client_metadata): """Test redirect handling when browser fails to open.""" mock_webbrowser.open.return_value = False - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) with pytest.raises(OAuthNetworkError) as exc_info: await handler.handle_redirect("https://example.com/oauth/authorize") @@ -361,11 +374,11 @@ async def test_handle_redirect_browser_fail(self, mock_webbrowser): assert "Failed to open browser" in str(exc_info.value.original_error) @pytest.mark.asyncio - async def test_handle_redirect_browser_exception(self, mock_webbrowser): + async def test_handle_redirect_browser_exception(self, mock_webbrowser, sample_client_metadata): """Test redirect handling when browser raises exception.""" mock_webbrowser.open.side_effect = Exception("Browser error") - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) with pytest.raises(OAuthNetworkError) as exc_info: await handler.handle_redirect("https://example.com/oauth/authorize") @@ -373,9 +386,9 @@ async def test_handle_redirect_browser_exception(self, mock_webbrowser): assert "Browser error" in str(exc_info.value.original_error) @pytest.mark.asyncio - async def test_handle_callback_success(self): + async def test_handle_callback_success(self, sample_client_metadata): """Test successful callback handling.""" - handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=300) + handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=300) # Mock LocalCallbackServer mock_server = Mock() @@ -398,9 +411,9 @@ async def test_handle_callback_success(self): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_handle_callback_timeout(self): + async def test_handle_callback_timeout(self, sample_client_metadata): """Test callback handling with timeout.""" - handler = LocalBrowserOAuthHandler(callback_port=3030, timeout=60) + handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=60) mock_server = Mock() mock_server.wait_for_callback.side_effect = OAuthTimeoutError(60) @@ -415,9 +428,9 @@ async def test_handle_callback_timeout(self): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_handle_callback_cancellation(self): + async def test_handle_callback_cancellation(self, sample_client_metadata): """Test callback handling with user cancellation.""" - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) mock_server = Mock() mock_server.wait_for_callback.side_effect = OAuthCancellationError( @@ -434,9 +447,9 @@ async def test_handle_callback_cancellation(self): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_handle_callback_server_error(self): + async def test_handle_callback_server_error(self, sample_client_metadata): """Test callback handling with server error.""" - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) mock_server = Mock() mock_server.wait_for_callback.side_effect = OAuthServerError("invalid_request") @@ -451,9 +464,9 @@ async def test_handle_callback_server_error(self): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_handle_callback_unexpected_error(self): + async def test_handle_callback_unexpected_error(self, sample_client_metadata): """Test callback handling with unexpected error.""" - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) mock_server = Mock() mock_server.start.side_effect = ValueError("Unexpected error") @@ -473,9 +486,9 @@ async def test_handle_callback_unexpected_error(self): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_handle_callback_cleanup_on_success(self): + async def test_handle_callback_cleanup_on_success(self, sample_client_metadata): """Test that cleanup always happens even on success.""" - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) mock_server = Mock() mock_server.wait_for_callback.return_value = "auth_code" @@ -494,11 +507,11 @@ class TestLocalBrowserOAuthHandlerIntegration: """Test integration scenarios for LocalBrowserOAuthHandler.""" @pytest.mark.asyncio - async def test_full_oauth_flow_simulation(self, mock_webbrowser): + async def test_full_oauth_flow_simulation(self, mock_webbrowser, sample_client_metadata): """Test complete OAuth flow simulation.""" mock_webbrowser.open.return_value = True - handler = LocalBrowserOAuthHandler(callback_port=4040, timeout=120) + handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=4040, timeout=120) # Mock server for callback mock_server = Mock() @@ -532,12 +545,12 @@ async def test_full_oauth_flow_simulation(self, mock_webbrowser): mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_multiple_handlers_independent(self, mock_webbrowser): + async def test_multiple_handlers_independent(self, mock_webbrowser, sample_client_metadata): """Test that multiple handler instances are independent.""" mock_webbrowser.open.return_value = True - handler1 = LocalBrowserOAuthHandler(callback_port=3030, timeout=100) - handler2 = LocalBrowserOAuthHandler(callback_port=4040, timeout=200) + handler1 = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=100) + handler2 = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=4040, timeout=200) # Mock servers mock_server1 = Mock() @@ -578,11 +591,11 @@ class TestHandlerErrorScenarios: """Test comprehensive error scenarios across handlers.""" @pytest.mark.asyncio - async def test_network_failure_during_redirect(self, mock_webbrowser): + async def test_network_failure_during_redirect(self, mock_webbrowser, sample_client_metadata): """Test network failure during redirect.""" mock_webbrowser.open.side_effect = ConnectionError("Network unreachable") - handler = LocalBrowserOAuthHandler() + handler = LocalBrowserOAuthHandler(sample_client_metadata) with pytest.raises(OAuthNetworkError) as exc_info: await handler.handle_redirect("https://example.com/oauth") @@ -599,9 +612,9 @@ def test_callback_server_edge_cases(self): assert server.callback_data is not None @pytest.mark.asyncio - async def test_handler_state_consistency(self): + async def test_handler_state_consistency(self, sample_client_metadata): """Test handler state remains consistent across operations.""" - handler = LocalBrowserOAuthHandler(callback_port=5050, timeout=150) + handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=5050, timeout=150) # Verify initial state assert handler.callback_port == 5050 diff --git a/tests/auth/test_providers.py b/tests/auth/test_providers.py index 9ee115a..1331be1 100644 --- a/tests/auth/test_providers.py +++ b/tests/auth/test_providers.py @@ -1,11 +1,30 @@ """Tests for authentication provider classes.""" -from unittest.mock import Mock +from unittest.mock import Mock, AsyncMock, patch from mcpadapt.auth.providers import ( ApiKeyAuthProvider, BearerAuthProvider, + OAuthProvider, get_auth_headers, ) +from mcpadapt.auth.handlers import BaseOAuthHandler +from mcpadapt.auth.oauth import OAuthClientMetadata +from mcpadapt.auth import InMemoryTokenStorage +from pydantic import AnyUrl +from typing import List + + +class MockOAuthHandler(BaseOAuthHandler): + """Mock OAuth handler for testing.""" + + def get_redirect_uris(self) -> List[AnyUrl]: + return [AnyUrl("http://localhost:3030/callback")] + + async def handle_redirect(self, authorization_url: str) -> None: + pass + + async def handle_callback(self) -> tuple[str, str | None]: + return "test_code", "test_state" class TestApiKeyAuthProvider: @@ -295,3 +314,102 @@ def test_providers_immutability(self): # Verify original token unchanged assert bearer_provider.token == original_token + + +class TestOAuthProvider: + """Test OAuthProvider class.""" + + def test_initialization(self): + """Test OAuthProvider initialization.""" + client_metadata = OAuthClientMetadata( + client_name="Test App", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + oauth_handler = MockOAuthHandler(client_metadata) + storage = InMemoryTokenStorage() + + provider = OAuthProvider( + server_url="https://test-server.com", + oauth_handler=oauth_handler, + storage=storage, + ) + + # Verify the provider was created successfully + assert provider is not None + + def test_oauth_provider_extracts_metadata(self): + """Test that OAuthProvider properly extracts metadata from handler.""" + client_metadata = OAuthClientMetadata( + client_name="Test App", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + oauth_handler = MockOAuthHandler(client_metadata) + storage = InMemoryTokenStorage() + + with patch('mcp.client.auth.OAuthClientProvider.__init__') as mock_init: + mock_init.return_value = None + + OAuthProvider( + server_url="https://test-server.com", + oauth_handler=oauth_handler, + storage=storage, + ) + + # Verify parent constructor was called with correct parameters + mock_init.assert_called_once() + call_args = mock_init.call_args + + assert call_args[1]['server_url'] == "https://test-server.com" + assert call_args[1]['storage'] == storage + assert call_args[1]['redirect_handler'] == oauth_handler.handle_redirect + assert call_args[1]['callback_handler'] == oauth_handler.handle_callback + + # Verify client_metadata was properly constructed + client_metadata_arg = call_args[1]['client_metadata'] + assert client_metadata_arg.client_name == "Test App" + assert len(client_metadata_arg.redirect_uris) == 1 + assert str(client_metadata_arg.redirect_uris[0]) == "http://localhost:3030/callback" + + def test_oauth_provider_with_custom_handler(self): + """Test OAuthProvider with custom handler that has different redirect URIs.""" + + class CustomTestHandler(BaseOAuthHandler): + def get_redirect_uris(self) -> List[AnyUrl]: + return [AnyUrl("http://localhost:8080/auth/callback")] + + async def handle_redirect(self, authorization_url: str) -> None: + pass + + async def handle_callback(self) -> tuple[str, str | None]: + return "custom_code", "custom_state" + + client_metadata = OAuthClientMetadata( + client_name="Custom Test App", + grant_types=["authorization_code"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + custom_handler = CustomTestHandler(client_metadata) + storage = InMemoryTokenStorage() + + with patch('mcp.client.auth.OAuthClientProvider.__init__') as mock_init: + mock_init.return_value = None + + OAuthProvider( + server_url="https://custom-server.com", + oauth_handler=custom_handler, + storage=storage, + ) + + # Verify the custom redirect URI was used + call_args = mock_init.call_args + client_metadata_arg = call_args[1]['client_metadata'] + assert len(client_metadata_arg.redirect_uris) == 1 + assert str(client_metadata_arg.redirect_uris[0]) == "http://localhost:8080/auth/callback" From f51e5037ac2a9217930a2235d3ed5865faae7f4f Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sun, 7 Sep 2025 15:08:05 +0530 Subject: [PATCH 6/7] Run format and lint --- examples/canva_oauth_example.py | 8 +-- examples/oauth_with_credentials_example.py | 4 +- src/mcpadapt/auth/handlers.py | 23 +++++--- src/mcpadapt/auth/oauth.py | 12 +++- src/mcpadapt/auth/providers.py | 10 ++-- tests/auth/test_handlers.py | 53 ++++++++++++----- tests/auth/test_providers.py | 68 ++++++++++++---------- 7 files changed, 109 insertions(+), 69 deletions(-) diff --git a/examples/canva_oauth_example.py b/examples/canva_oauth_example.py index b47af34..4124b43 100644 --- a/examples/canva_oauth_example.py +++ b/examples/canva_oauth_example.py @@ -34,15 +34,13 @@ def main(): token_endpoint_auth_method="client_secret_post", ) - # Create OAuth handler and token storage + # Create OAuth handler and token storage oauth_handler = LocalBrowserOAuthHandler( - client_metadata=client_metadata, - callback_port=3030, - timeout=300 + client_metadata=client_metadata, callback_port=3030, timeout=300 ) token_storage = InMemoryTokenStorage() - # Create OAuth provider + # Create OAuth provider oauth_provider = OAuthProvider( server_url="https://mcp.canva.com", oauth_handler=oauth_handler, diff --git a/examples/oauth_with_credentials_example.py b/examples/oauth_with_credentials_example.py index 15639c4..01a7fab 100644 --- a/examples/oauth_with_credentials_example.py +++ b/examples/oauth_with_credentials_example.py @@ -56,9 +56,7 @@ def main(): # Create OAuth handler oauth_handler = LocalBrowserOAuthHandler( - client_metadata=client_metadata, - callback_port=3030, - timeout=300 + client_metadata=client_metadata, callback_port=3030, timeout=300 ) # Create token storage WITH pre-configured client information diff --git a/src/mcpadapt/auth/handlers.py b/src/mcpadapt/auth/handlers.py index 159e4e1..eb1256c 100644 --- a/src/mcpadapt/auth/handlers.py +++ b/src/mcpadapt/auth/handlers.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import parse_qs, urlparse -from typing import List, Tuple, Optional +from typing import List from pydantic import AnyUrl from mcp.shared.auth import OAuthClientMetadata as MCPOAuthClientMetadata @@ -182,10 +182,10 @@ class BaseOAuthHandler(ABC): Subclasses should implement both the redirect flow (opening authorization URL) and callback flow (receiving authorization code). """ - + def __init__(self, client_metadata: OAuthClientMetadata): """Initialize handler with OAuth client metadata. - + Args: client_metadata: OAuth client metadata configuration """ @@ -194,15 +194,15 @@ def __init__(self, client_metadata: OAuthClientMetadata): @abstractmethod def get_redirect_uris(self) -> List[AnyUrl]: """Get redirect URIs for this handler. - + Returns: List of redirect URIs that this handler can process """ pass - + def get_client_metadata(self) -> MCPOAuthClientMetadata: """Get complete OAuth client metadata with redirect URIs populated. - + Returns: Complete OAuth client metadata for MCP usage """ @@ -249,7 +249,12 @@ class LocalBrowserOAuthHandler(BaseOAuthHandler): approach for desktop applications. """ - def __init__(self, client_metadata: OAuthClientMetadata, callback_port: int = 3030, timeout: int = 300): + def __init__( + self, + client_metadata: OAuthClientMetadata, + callback_port: int = 3030, + timeout: int = 300, + ): """Initialize the local browser OAuth handler. Args: @@ -261,10 +266,10 @@ def __init__(self, client_metadata: OAuthClientMetadata, callback_port: int = 30 self.callback_port = callback_port self.timeout = timeout self.callback_server: LocalCallbackServer | None = None - + def get_redirect_uris(self) -> List[AnyUrl]: """Get redirect URIs for this handler. - + Returns: List of redirect URIs based on callback port """ diff --git a/src/mcpadapt/auth/oauth.py b/src/mcpadapt/auth/oauth.py index bbb8411..6e36a43 100644 --- a/src/mcpadapt/auth/oauth.py +++ b/src/mcpadapt/auth/oauth.py @@ -8,13 +8,19 @@ class OAuthClientMetadata(BaseModel): """OAuth client metadata without required redirect_uris. - + This is our custom version that allows handlers to manage redirect URIs internally. """ + client_name: str - grant_types: List[Literal["authorization_code", "refresh_token"]] = ["authorization_code", "refresh_token"] + grant_types: List[Literal["authorization_code", "refresh_token"]] = [ + "authorization_code", + "refresh_token", + ] response_types: List[Literal["code"]] = ["code"] - token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" + token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( + "client_secret_post" + ) redirect_uris: Optional[List[AnyUrl]] = None scope: Optional[str] = None client_uri: Optional[AnyHttpUrl] = None diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py index 8e8c4bc..26487e1 100644 --- a/src/mcpadapt/auth/providers.py +++ b/src/mcpadapt/auth/providers.py @@ -49,14 +49,16 @@ def get_headers(self) -> dict[str, str]: class OAuthProvider(OAuthClientProvider): """OAuth provider that accepts a handler directly. - + This class simplifies OAuth configuration by taking an OAuthHandler and internally extracting the client metadata and callback handlers. """ - - def __init__(self, server_url: str, oauth_handler: BaseOAuthHandler, storage: TokenStorage): + + def __init__( + self, server_url: str, oauth_handler: BaseOAuthHandler, storage: TokenStorage + ): """Initialize OAuth provider with handler. - + Args: server_url: MCP server URL oauth_handler: OAuth handler containing all configuration diff --git a/tests/auth/test_handlers.py b/tests/auth/test_handlers.py index c45ed4f..7f380a2 100644 --- a/tests/auth/test_handlers.py +++ b/tests/auth/test_handlers.py @@ -15,7 +15,6 @@ OAuthCallbackError, OAuthServerError, ) -from pydantic import AnyUrl @pytest.fixture @@ -336,7 +335,9 @@ class TestLocalBrowserOAuthHandler: def test_initialization(self, sample_client_metadata): """Test handler initialization.""" - handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=300) + handler = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=3030, timeout=300 + ) assert handler.callback_port == 3030 assert handler.timeout == 300 assert handler.callback_server is None @@ -348,7 +349,9 @@ def test_initialization_defaults(self, sample_client_metadata): assert handler.timeout == 300 @pytest.mark.asyncio - async def test_handle_redirect_success(self, mock_webbrowser, sample_client_metadata): + async def test_handle_redirect_success( + self, mock_webbrowser, sample_client_metadata + ): """Test successful redirect handling.""" mock_webbrowser.open.return_value = True @@ -362,7 +365,9 @@ async def test_handle_redirect_success(self, mock_webbrowser, sample_client_meta ) @pytest.mark.asyncio - async def test_handle_redirect_browser_fail(self, mock_webbrowser, sample_client_metadata): + async def test_handle_redirect_browser_fail( + self, mock_webbrowser, sample_client_metadata + ): """Test redirect handling when browser fails to open.""" mock_webbrowser.open.return_value = False @@ -374,7 +379,9 @@ async def test_handle_redirect_browser_fail(self, mock_webbrowser, sample_client assert "Failed to open browser" in str(exc_info.value.original_error) @pytest.mark.asyncio - async def test_handle_redirect_browser_exception(self, mock_webbrowser, sample_client_metadata): + async def test_handle_redirect_browser_exception( + self, mock_webbrowser, sample_client_metadata + ): """Test redirect handling when browser raises exception.""" mock_webbrowser.open.side_effect = Exception("Browser error") @@ -388,7 +395,9 @@ async def test_handle_redirect_browser_exception(self, mock_webbrowser, sample_c @pytest.mark.asyncio async def test_handle_callback_success(self, sample_client_metadata): """Test successful callback handling.""" - handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=300) + handler = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=3030, timeout=300 + ) # Mock LocalCallbackServer mock_server = Mock() @@ -413,7 +422,9 @@ async def test_handle_callback_success(self, sample_client_metadata): @pytest.mark.asyncio async def test_handle_callback_timeout(self, sample_client_metadata): """Test callback handling with timeout.""" - handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=60) + handler = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=3030, timeout=60 + ) mock_server = Mock() mock_server.wait_for_callback.side_effect = OAuthTimeoutError(60) @@ -507,11 +518,15 @@ class TestLocalBrowserOAuthHandlerIntegration: """Test integration scenarios for LocalBrowserOAuthHandler.""" @pytest.mark.asyncio - async def test_full_oauth_flow_simulation(self, mock_webbrowser, sample_client_metadata): + async def test_full_oauth_flow_simulation( + self, mock_webbrowser, sample_client_metadata + ): """Test complete OAuth flow simulation.""" mock_webbrowser.open.return_value = True - handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=4040, timeout=120) + handler = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=4040, timeout=120 + ) # Mock server for callback mock_server = Mock() @@ -545,12 +560,18 @@ async def test_full_oauth_flow_simulation(self, mock_webbrowser, sample_client_m mock_server.stop.assert_called_once() @pytest.mark.asyncio - async def test_multiple_handlers_independent(self, mock_webbrowser, sample_client_metadata): + async def test_multiple_handlers_independent( + self, mock_webbrowser, sample_client_metadata + ): """Test that multiple handler instances are independent.""" mock_webbrowser.open.return_value = True - handler1 = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=3030, timeout=100) - handler2 = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=4040, timeout=200) + handler1 = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=3030, timeout=100 + ) + handler2 = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=4040, timeout=200 + ) # Mock servers mock_server1 = Mock() @@ -591,7 +612,9 @@ class TestHandlerErrorScenarios: """Test comprehensive error scenarios across handlers.""" @pytest.mark.asyncio - async def test_network_failure_during_redirect(self, mock_webbrowser, sample_client_metadata): + async def test_network_failure_during_redirect( + self, mock_webbrowser, sample_client_metadata + ): """Test network failure during redirect.""" mock_webbrowser.open.side_effect = ConnectionError("Network unreachable") @@ -614,7 +637,9 @@ def test_callback_server_edge_cases(self): @pytest.mark.asyncio async def test_handler_state_consistency(self, sample_client_metadata): """Test handler state remains consistent across operations.""" - handler = LocalBrowserOAuthHandler(sample_client_metadata, callback_port=5050, timeout=150) + handler = LocalBrowserOAuthHandler( + sample_client_metadata, callback_port=5050, timeout=150 + ) # Verify initial state assert handler.callback_port == 5050 diff --git a/tests/auth/test_providers.py b/tests/auth/test_providers.py index 1331be1..7ad825b 100644 --- a/tests/auth/test_providers.py +++ b/tests/auth/test_providers.py @@ -1,6 +1,6 @@ """Tests for authentication provider classes.""" -from unittest.mock import Mock, AsyncMock, patch +from unittest.mock import Mock, patch from mcpadapt.auth.providers import ( ApiKeyAuthProvider, BearerAuthProvider, @@ -16,13 +16,13 @@ class MockOAuthHandler(BaseOAuthHandler): """Mock OAuth handler for testing.""" - + def get_redirect_uris(self) -> List[AnyUrl]: return [AnyUrl("http://localhost:3030/callback")] - + async def handle_redirect(self, authorization_url: str) -> None: pass - + async def handle_callback(self) -> tuple[str, str | None]: return "test_code", "test_state" @@ -327,16 +327,16 @@ def test_initialization(self): response_types=["code"], token_endpoint_auth_method="client_secret_post", ) - + oauth_handler = MockOAuthHandler(client_metadata) storage = InMemoryTokenStorage() - + provider = OAuthProvider( server_url="https://test-server.com", oauth_handler=oauth_handler, storage=storage, ) - + # Verify the provider was created successfully assert provider is not None @@ -348,68 +348,74 @@ def test_oauth_provider_extracts_metadata(self): response_types=["code"], token_endpoint_auth_method="client_secret_post", ) - + oauth_handler = MockOAuthHandler(client_metadata) storage = InMemoryTokenStorage() - - with patch('mcp.client.auth.OAuthClientProvider.__init__') as mock_init: + + with patch("mcp.client.auth.OAuthClientProvider.__init__") as mock_init: mock_init.return_value = None - + OAuthProvider( server_url="https://test-server.com", oauth_handler=oauth_handler, storage=storage, ) - + # Verify parent constructor was called with correct parameters mock_init.assert_called_once() call_args = mock_init.call_args - - assert call_args[1]['server_url'] == "https://test-server.com" - assert call_args[1]['storage'] == storage - assert call_args[1]['redirect_handler'] == oauth_handler.handle_redirect - assert call_args[1]['callback_handler'] == oauth_handler.handle_callback - + + assert call_args[1]["server_url"] == "https://test-server.com" + assert call_args[1]["storage"] == storage + assert call_args[1]["redirect_handler"] == oauth_handler.handle_redirect + assert call_args[1]["callback_handler"] == oauth_handler.handle_callback + # Verify client_metadata was properly constructed - client_metadata_arg = call_args[1]['client_metadata'] + client_metadata_arg = call_args[1]["client_metadata"] assert client_metadata_arg.client_name == "Test App" assert len(client_metadata_arg.redirect_uris) == 1 - assert str(client_metadata_arg.redirect_uris[0]) == "http://localhost:3030/callback" + assert ( + str(client_metadata_arg.redirect_uris[0]) + == "http://localhost:3030/callback" + ) def test_oauth_provider_with_custom_handler(self): """Test OAuthProvider with custom handler that has different redirect URIs.""" - + class CustomTestHandler(BaseOAuthHandler): def get_redirect_uris(self) -> List[AnyUrl]: return [AnyUrl("http://localhost:8080/auth/callback")] - + async def handle_redirect(self, authorization_url: str) -> None: pass - + async def handle_callback(self) -> tuple[str, str | None]: return "custom_code", "custom_state" - + client_metadata = OAuthClientMetadata( client_name="Custom Test App", grant_types=["authorization_code"], response_types=["code"], token_endpoint_auth_method="client_secret_post", ) - + custom_handler = CustomTestHandler(client_metadata) storage = InMemoryTokenStorage() - - with patch('mcp.client.auth.OAuthClientProvider.__init__') as mock_init: + + with patch("mcp.client.auth.OAuthClientProvider.__init__") as mock_init: mock_init.return_value = None - + OAuthProvider( server_url="https://custom-server.com", oauth_handler=custom_handler, storage=storage, ) - + # Verify the custom redirect URI was used call_args = mock_init.call_args - client_metadata_arg = call_args[1]['client_metadata'] + client_metadata_arg = call_args[1]["client_metadata"] assert len(client_metadata_arg.redirect_uris) == 1 - assert str(client_metadata_arg.redirect_uris[0]) == "http://localhost:8080/auth/callback" + assert ( + str(client_metadata_arg.redirect_uris[0]) + == "http://localhost:8080/auth/callback" + ) From 86e532d67e999fb45855293dcab700f6312688a7 Mon Sep 17 00:00:00 2001 From: Amith K K Date: Sat, 13 Sep 2025 21:01:35 +0530 Subject: [PATCH 7/7] have other auth providers also inherit from httpx.Auth, allow passing in functions to BearerAuthProvider. --- docs/auth/custom-handlers.md | 36 ++- src/mcpadapt/auth/__init__.py | 3 - src/mcpadapt/auth/providers.py | 71 ++--- src/mcpadapt/core.py | 14 +- tests/auth/test_providers.py | 458 ++++++++++++++++----------------- 5 files changed, 295 insertions(+), 287 deletions(-) diff --git a/docs/auth/custom-handlers.md b/docs/auth/custom-handlers.md index b27da4e..75d23b7 100644 --- a/docs/auth/custom-handlers.md +++ b/docs/auth/custom-handlers.md @@ -152,7 +152,41 @@ class CLIHandler(BaseOAuthHandler): return auth_code, state ``` -## Using Custom Handlers +## Using BearerAuthProvider with External Token Sources + +You can pass a function to `BearerAuthProvider` to retrieve tokens from external systems dynamically: + +```python +from mcpadapt.auth import BearerAuthProvider +from mcpadapt.core import MCPAdapt +from mcpadapt.smolagents_adapter import SmolAgentsAdapter +import requests + +def get_token_from_external_service(): + """Retrieve authentication token from an external token service.""" + # Example: Get token from a token management service + response = requests.post('https://token-service.example.com/api/token', + headers={'X-Service-Key': 'your-service-key'}, + json={'service': 'mcp-client'}) + + if response.status_code == 200: + return response.json()['access_token'] + else: + raise Exception(f"Failed to get token: {response.status_code}") + +# Create BearerAuthProvider with external token function +bearer_auth = BearerAuthProvider(get_token_from_external_service) + +# Use with MCPAdapt - token will be fetched fresh on each request +with MCPAdapt( + serverparams={"url": "https://api.example.com/mcp", "transport": "sse"}, + adapter=SmolAgentsAdapter(), + auth_provider=bearer_auth, +) as tools: + print(f"Connected with external token auth: {len(tools)} tools") +``` + +## Using Custom OAuth Handlers ```python from mcpadapt.auth import OAuthProvider, OAuthClientMetadata, InMemoryTokenStorage diff --git a/src/mcpadapt/auth/__init__.py b/src/mcpadapt/auth/__init__.py index dc410e3..122745d 100644 --- a/src/mcpadapt/auth/__init__.py +++ b/src/mcpadapt/auth/__init__.py @@ -24,7 +24,6 @@ ApiKeyAuthProvider, BearerAuthProvider, OAuthProvider, - get_auth_headers, ) from .exceptions import ( OAuthError, @@ -47,8 +46,6 @@ "OAuthProvider", # Default implementations "InMemoryTokenStorage", - # Provider functions - "get_auth_headers", # Exception classes "OAuthError", "OAuthTimeoutError", diff --git a/src/mcpadapt/auth/providers.py b/src/mcpadapt/auth/providers.py index 26487e1..17f0a22 100644 --- a/src/mcpadapt/auth/providers.py +++ b/src/mcpadapt/auth/providers.py @@ -1,11 +1,12 @@ """Authentication provider classes for MCPAdapt.""" -from typing import Any +from typing import Any, Callable, Generator, Union +import httpx from mcp.client.auth import OAuthClientProvider, TokenStorage from .handlers import BaseOAuthHandler -class ApiKeyAuthProvider: +class ApiKeyAuthProvider(httpx.Auth): """Simple API key authentication provider.""" def __init__(self, header_name: str, header_value: str): @@ -18,33 +19,55 @@ def __init__(self, header_name: str, header_value: str): self.header_name = header_name self.header_value = header_value - def get_headers(self) -> dict[str, str]: - """Get authentication headers. + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + """Execute the authentication flow by adding the API key header. - Returns: - Dictionary of headers to add to requests + Args: + request: The request to authenticate + + Yields: + The authenticated request """ - return {self.header_name: self.header_value} + request.headers[self.header_name] = self.header_value + yield request -class BearerAuthProvider: - """Simple Bearer token authentication provider.""" - def __init__(self, token: str): +class BearerAuthProvider(httpx.Auth): + """Simple Bearer token authentication provider. + + Supports both static tokens (strings) and dynamic tokens (callables that return strings). + """ + + def __init__(self, token: Union[str, Callable[[], str]]): """Initialize with Bearer token configuration. Args: - token: The bearer token + token: The bearer token (string) or a callable that returns the token """ - self.token = token - - def get_headers(self) -> dict[str, str]: - """Get authentication headers. + self._token = token + def _get_token_value(self) -> str: + """Get the current token value. + Returns: - Dictionary of headers to add to requests + The token value, calling the token if it's callable """ - return {"Authorization": f"Bearer {self.token}"} + return self._token() if callable(self._token) else self._token + + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + """Execute the authentication flow by adding the Bearer token header. + + Args: + request: The request to authenticate + + Yields: + The authenticated request + """ + token_value = self._get_token_value() + request.headers["Authorization"] = f"Bearer {token_value}" + yield request + class OAuthProvider(OAuthClientProvider): @@ -71,17 +94,3 @@ def __init__( redirect_handler=oauth_handler.handle_redirect, callback_handler=oauth_handler.handle_callback, ) - - -def get_auth_headers(auth_provider: Any) -> dict[str, str]: - """Get authentication headers from provider. - - Args: - auth_provider: Authentication provider instance - - Returns: - Dictionary of headers to add to requests - """ - if isinstance(auth_provider, (ApiKeyAuthProvider, BearerAuthProvider)): - return auth_provider.get_headers() - return {} diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index d3246b5..4f09155 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -20,7 +20,7 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider, get_auth_headers +from .auth.providers import ApiKeyAuthProvider, BearerAuthProvider class ToolAdapter(ABC): @@ -78,7 +78,7 @@ def async_adapt( async def mcptools( serverparams: StdioServerParameters | dict[str, Any], client_session_timeout_seconds: float | timedelta | None = 5, - auth_provider: Any = None, + auth_provider: OAuthClientProvider | ApiKeyAuthProvider | BearerAuthProvider | None = None, ) -> AsyncGenerator[tuple[ClientSession, list[mcp.types.Tool]], None]: """Async context manager that yields tools from an MCP server. @@ -108,15 +108,7 @@ async def mcptools( # Add authentication if provided if auth_provider is not None: - if isinstance(auth_provider, OAuthClientProvider): - client_params["auth"] = auth_provider - elif isinstance(auth_provider, (ApiKeyAuthProvider, BearerAuthProvider)): - # Add custom headers for API Key and Bearer auth - headers = get_auth_headers(auth_provider) - if "headers" in client_params: - client_params["headers"].update(headers) - else: - client_params["headers"] = headers + client_params["auth"] = auth_provider if transport == "sse": client = sse_client(**client_params) diff --git a/tests/auth/test_providers.py b/tests/auth/test_providers.py index 7ad825b..f30ebe0 100644 --- a/tests/auth/test_providers.py +++ b/tests/auth/test_providers.py @@ -1,11 +1,11 @@ """Tests for authentication provider classes.""" +import httpx from unittest.mock import Mock, patch from mcpadapt.auth.providers import ( ApiKeyAuthProvider, BearerAuthProvider, OAuthProvider, - get_auth_headers, ) from mcpadapt.auth.handlers import BaseOAuthHandler from mcpadapt.auth.oauth import OAuthClientMetadata @@ -48,272 +48,248 @@ def test_initialization_empty_values(self): assert provider.header_name == "" assert provider.header_value == "" - def test_get_headers_basic(self): - """Test get_headers method.""" - provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") - headers = provider.get_headers() - - assert isinstance(headers, dict) - assert headers == {"X-API-Key": "test-key-123"} + def test_httpx_auth_inheritance(self): + """Test that ApiKeyAuthProvider inherits from httpx.Auth.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key") + assert isinstance(provider, httpx.Auth) - def test_get_headers_different_header(self): - """Test get_headers with different header name.""" + def test_auth_flow_basic(self): + """Test auth_flow method.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["X-API-Key"] == "test-key-123" + + def test_auth_flow_different_header(self): + """Test auth_flow with different header name.""" provider = ApiKeyAuthProvider("Custom-Auth", "custom-value") - headers = provider.get_headers() - - assert isinstance(headers, dict) - assert headers == {"Custom-Auth": "custom-value"} - - def test_get_headers_returns_new_dict(self): - """Test that get_headers returns a new dict instance each time.""" + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["Custom-Auth"] == "custom-value" + + def test_auth_flow_preserves_existing_headers(self): + """Test that auth_flow preserves existing headers.""" provider = ApiKeyAuthProvider("X-API-Key", "test-key") - headers1 = provider.get_headers() - headers2 = provider.get_headers() - - assert headers1 == headers2 - assert headers1 is not headers2 # Different instances - - def test_get_headers_multiple_calls(self): - """Test multiple calls to get_headers return consistent results.""" + request = httpx.Request("GET", "https://example.com", headers={"User-Agent": "test"}) + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["X-API-Key"] == "test-key" + assert authenticated_request.headers["User-Agent"] == "test" + + def test_auth_flow_multiple_calls(self): + """Test multiple calls to auth_flow return consistent results.""" provider = ApiKeyAuthProvider("X-API-Key", "test-key") for _ in range(5): - headers = provider.get_headers() - assert headers == {"X-API-Key": "test-key"} + request = httpx.Request("GET", "https://example.com") + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + assert authenticated_request.headers["X-API-Key"] == "test-key" + + def test_auth_flow_with_special_characters(self): + """Test auth_flow with special characters in values.""" + provider = ApiKeyAuthProvider("X-API-Key", "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`") + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["X-API-Key"] == "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" class TestBearerAuthProvider: """Test Bearer token authentication provider.""" - def test_initialization(self): - """Test basic initialization.""" - provider = BearerAuthProvider("test-token-123") - assert provider.token == "test-token-123" - - def test_initialization_empty_token(self): - """Test initialization with empty token.""" - provider = BearerAuthProvider("") - assert provider.token == "" - - def test_initialization_complex_token(self): - """Test initialization with complex token.""" - complex_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - provider = BearerAuthProvider(complex_token) - assert provider.token == complex_token - - def test_get_headers_basic(self): - """Test get_headers method.""" + def test_initialization_with_string_token(self): + """Test basic initialization with string token.""" provider = BearerAuthProvider("test-token-123") - headers = provider.get_headers() - - assert isinstance(headers, dict) - assert headers == {"Authorization": "Bearer test-token-123"} - - def test_get_headers_empty_token(self): - """Test get_headers with empty token.""" - provider = BearerAuthProvider("") - headers = provider.get_headers() - - assert isinstance(headers, dict) - assert headers == {"Authorization": "Bearer "} - - def test_get_headers_returns_new_dict(self): - """Test that get_headers returns a new dict instance each time.""" - provider = BearerAuthProvider("test-token") - headers1 = provider.get_headers() - headers2 = provider.get_headers() - - assert headers1 == headers2 - assert headers1 is not headers2 # Different instances - - def test_get_headers_multiple_calls(self): - """Test multiple calls to get_headers return consistent results.""" + assert provider._token == "test-token-123" + + def test_initialization_with_callable_token(self): + """Test initialization with callable token.""" + def get_token(): + return "dynamic-token" + + provider = BearerAuthProvider(get_token) + assert provider._token == get_token + assert callable(provider._token) + + def test_httpx_auth_inheritance(self): + """Test that BearerAuthProvider inherits from httpx.Auth.""" provider = BearerAuthProvider("test-token") + assert isinstance(provider, httpx.Auth) - for _ in range(5): - headers = provider.get_headers() - assert headers == {"Authorization": "Bearer test-token"} - - def test_bearer_format_consistency(self): - """Test that Bearer format is consistent.""" - tokens = [ - "simple", - "complex.token.here", - "token-with-dashes", - "token_with_underscores", - ] - - for token in tokens: - provider = BearerAuthProvider(token) - headers = provider.get_headers() - assert headers["Authorization"] == f"Bearer {token}" - assert headers["Authorization"].startswith("Bearer ") - - -class TestGetAuthHeaders: - """Test get_auth_headers utility function.""" - - def test_with_api_key_provider(self): - """Test with ApiKeyAuthProvider.""" - provider = ApiKeyAuthProvider("X-API-Key", "test-key") - headers = get_auth_headers(provider) - - assert isinstance(headers, dict) - assert headers == {"X-API-Key": "test-key"} - - def test_with_bearer_provider(self): - """Test with BearerAuthProvider.""" + def test_auth_flow_with_string_token(self): + """Test auth_flow with string token.""" + provider = BearerAuthProvider("test-token-123") + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["Authorization"] == "Bearer test-token-123" + + def test_auth_flow_with_callable_token(self): + """Test auth_flow with callable token.""" + call_count = 0 + + def get_token(): + nonlocal call_count + call_count += 1 + return f"dynamic-token-{call_count}" + + provider = BearerAuthProvider(get_token) + + # First call + request1 = httpx.Request("GET", "https://example.com") + auth_gen1 = provider.auth_flow(request1) + authenticated_request1 = next(auth_gen1) + assert authenticated_request1.headers["Authorization"] == "Bearer dynamic-token-1" + + # Second call should get a new token + request2 = httpx.Request("GET", "https://example.com") + auth_gen2 = provider.auth_flow(request2) + authenticated_request2 = next(auth_gen2) + assert authenticated_request2.headers["Authorization"] == "Bearer dynamic-token-2" + + def test_auth_flow_preserves_existing_headers(self): + """Test that auth_flow preserves existing headers.""" provider = BearerAuthProvider("test-token") - headers = get_auth_headers(provider) - - assert isinstance(headers, dict) - assert headers == {"Authorization": "Bearer test-token"} - - def test_with_unknown_provider(self): - """Test with unknown provider type.""" - unknown_provider = Mock() - headers = get_auth_headers(unknown_provider) - - assert isinstance(headers, dict) - assert headers == {} - - def test_with_none_provider(self): - """Test with None provider.""" - headers = get_auth_headers(None) - - assert isinstance(headers, dict) - assert headers == {} - - def test_with_provider_without_get_headers(self): - """Test with object that doesn't have get_headers method.""" - fake_provider = object() - headers = get_auth_headers(fake_provider) - - assert isinstance(headers, dict) - assert headers == {} - - def test_with_string_provider(self): - """Test with string instead of provider object.""" - headers = get_auth_headers("not-a-provider") - - assert isinstance(headers, dict) - assert headers == {} - - def test_with_dict_provider(self): - """Test with dict instead of provider object.""" - headers = get_auth_headers({"key": "value"}) - - assert isinstance(headers, dict) - assert headers == {} - - def test_multiple_provider_types(self): - """Test with multiple different provider types in sequence.""" - api_key_provider = ApiKeyAuthProvider("X-API-Key", "api-key") - bearer_provider = BearerAuthProvider("bearer-token") - - api_headers = get_auth_headers(api_key_provider) - bearer_headers = get_auth_headers(bearer_provider) - none_headers = get_auth_headers(None) - - assert api_headers == {"X-API-Key": "api-key"} - assert bearer_headers == {"Authorization": "Bearer bearer-token"} - assert none_headers == {} - - def test_provider_inheritance_check(self): - """Test that the function properly checks instance types.""" - - # Create a class that has get_headers but isn't a known provider - class FakeProvider: - def get_headers(self): - return {"Fake": "header"} - - fake = FakeProvider() - headers = get_auth_headers(fake) - - # Should return empty dict since it's not an ApiKeyAuthProvider or BearerAuthProvider - assert headers == {} - - -class TestProviderIntegration: - """Test provider integration scenarios.""" - - def test_api_key_provider_real_world_headers(self): - """Test API key provider with real-world header names.""" - test_cases = [ - ("X-API-Key", "sk-1234567890abcdef"), - ("Authorization", "ApiKey sk-abcdef1234567890"), - ("X-RapidAPI-Key", "rapidapi-key-here"), - ("Ocp-Apim-Subscription-Key", "azure-key"), - ("x-api-key", "lowercase-header"), - ] - - for header_name, header_value in test_cases: - provider = ApiKeyAuthProvider(header_name, header_value) - headers = provider.get_headers() - assert headers == {header_name: header_value} - - def test_bearer_provider_real_world_tokens(self): - """Test Bearer provider with real-world token formats.""" + request = httpx.Request("GET", "https://example.com", headers={"User-Agent": "test"}) + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["Authorization"] == "Bearer test-token" + assert authenticated_request.headers["User-Agent"] == "test" + + def test_get_token_value_with_string(self): + """Test _get_token_value with string token.""" + provider = BearerAuthProvider("static-token") + assert provider._get_token_value() == "static-token" + + def test_get_token_value_with_callable(self): + """Test _get_token_value with callable token.""" + def get_token(): + return "callable-token" + + provider = BearerAuthProvider(get_token) + assert provider._get_token_value() == "callable-token" + + def test_auth_flow_with_complex_tokens(self): + """Test auth_flow with complex token formats.""" test_tokens = [ "simple_token", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", # JWT-like "ghp_1234567890abcdef1234567890abcdef12345678", # GitHub token-like "sk-1234567890abcdef1234567890abcdef1234567890abcdef1234567890", # OpenAI-like - "xoxb-1234567890-1234567890-abcdefghijklmnopqrstuvwx", # Slack-like ] for token in test_tokens: provider = BearerAuthProvider(token) - headers = provider.get_headers() - assert headers == {"Authorization": f"Bearer {token}"} - - def test_providers_with_special_characters(self): - """Test providers with special characters in values.""" - # API Key with special characters - api_provider = ApiKeyAuthProvider( - "X-API-Key", "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" - ) - api_headers = api_provider.get_headers() - assert "X-API-Key" in api_headers - assert api_headers["X-API-Key"] == "key!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" - - # Bearer token with special characters - bearer_provider = BearerAuthProvider("token!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`") - bearer_headers = bearer_provider.get_headers() - assert ( - bearer_headers["Authorization"] - == "Bearer token!@#$%^&*()_+-={}[]|\\:;\"'<>,.?/~`" - ) + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["Authorization"] == f"Bearer {token}" + + def test_auth_flow_with_empty_token(self): + """Test auth_flow with empty token.""" + provider = BearerAuthProvider("") + request = httpx.Request("GET", "https://example.com") + + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + + assert authenticated_request.headers["Authorization"] == "Bearer " + + def test_callable_token_exception_handling(self): + """Test that exceptions in callable tokens are not caught by provider.""" + def failing_token(): + raise ValueError("Token generation failed") + + provider = BearerAuthProvider(failing_token) + request = httpx.Request("GET", "https://example.com") + + # Should raise the exception from the callable + try: + auth_gen = provider.auth_flow(request) + next(auth_gen) + assert False, "Expected ValueError to be raised" + except ValueError as e: + assert str(e) == "Token generation failed" - def test_providers_immutability(self): - """Test that providers don't modify their internal state.""" - # Test API Key provider - api_provider = ApiKeyAuthProvider("X-API-Key", "original-key") - original_name = api_provider.header_name - original_value = api_provider.header_value - - # Get headers multiple times - for _ in range(3): - headers = api_provider.get_headers() - headers["X-API-Key"] = "modified-key" # Try to modify returned dict - - # Verify original values unchanged - assert api_provider.header_name == original_name - assert api_provider.header_value == original_value - - # Test Bearer provider - bearer_provider = BearerAuthProvider("original-token") - original_token = bearer_provider.token - - # Get headers multiple times - for _ in range(3): - headers = bearer_provider.get_headers() - headers["Authorization"] = ( - "Bearer modified-token" # Try to modify returned dict - ) - # Verify original token unchanged - assert bearer_provider.token == original_token +class TestProviderIntegration: + """Test provider integration scenarios.""" + + def test_api_key_provider_with_httpx_client(self): + """Test that ApiKeyAuthProvider works with httpx client.""" + provider = ApiKeyAuthProvider("X-API-Key", "test-key-123") + + # This would normally make a real request, but we're just testing the auth setup + client = httpx.Client(auth=provider) + assert client._auth is provider + + def test_bearer_provider_with_httpx_client(self): + """Test that BearerAuthProvider works with httpx client.""" + provider = BearerAuthProvider("test-token-123") + + # This would normally make a real request, but we're just testing the auth setup + client = httpx.Client(auth=provider) + assert client._auth is provider + + def test_both_providers_are_httpx_auth_instances(self): + """Test that both providers are proper httpx.Auth instances.""" + api_provider = ApiKeyAuthProvider("X-API-Key", "key") + bearer_provider = BearerAuthProvider("token") + + assert isinstance(api_provider, httpx.Auth) + assert isinstance(bearer_provider, httpx.Auth) + + def test_providers_with_real_world_scenarios(self): + """Test providers with realistic scenarios.""" + # API Key scenario + api_provider = ApiKeyAuthProvider("X-RapidAPI-Key", "your-rapidapi-key-here") + api_request = httpx.Request("GET", "https://api.example.com/data") + api_auth_gen = api_provider.auth_flow(api_request) + api_authenticated = next(api_auth_gen) + assert api_authenticated.headers["X-RapidAPI-Key"] == "your-rapidapi-key-here" + + # Bearer token scenario + bearer_provider = BearerAuthProvider("your-jwt-token-here") + bearer_request = httpx.Request("POST", "https://api.example.com/users") + bearer_auth_gen = bearer_provider.auth_flow(bearer_request) + bearer_authenticated = next(bearer_auth_gen) + assert bearer_authenticated.headers["Authorization"] == "Bearer your-jwt-token-here" + + def test_callable_token_refresh_scenario(self): + """Test callable token for token refresh scenarios.""" + refresh_count = 0 + + def refresh_token(): + nonlocal refresh_count + refresh_count += 1 + return f"refreshed-token-{refresh_count}" + + provider = BearerAuthProvider(refresh_token) + + # Simulate multiple API calls that would refresh the token + for i in range(3): + request = httpx.Request("GET", "https://api.example.com") + auth_gen = provider.auth_flow(request) + authenticated_request = next(auth_gen) + expected_token = f"refreshed-token-{i + 1}" + assert authenticated_request.headers["Authorization"] == f"Bearer {expected_token}" class TestOAuthProvider: