From f57899e114ac05709063b03371c347d5c3462ac0 Mon Sep 17 00:00:00 2001 From: Juan Hernandez Date: Tue, 24 Jun 2025 08:05:55 +0200 Subject: [PATCH] MGMT-20908: Add support for the OAuth authorization code flow Currently the server requires the offine token in the `OFFLINE_TOKEN` environment variable, or in a request header. But use of offline tokens is deprected, and will be removed in the future. To avoid using them this patch adds an optional mechanism to use the OAuth authorization code flow, intended for use when the server is executed locally by the user. When the server starts it checks if the `USE_AUTHORIZATION_CODE_FLOW` environment variable is set to `true`. If it is then the authorization URL will be opened with the local browser, so that the user can provide the credentials. Then the authorization server will send the authorization code to the `/oauth/callback` endpoint, and the MCP sever will exchange that code for the refresh and access tokens. The refresh token is then used as it was the offline token. Related: https://issues.redhat.com/browse/MGMT-20908 Signed-off-by: Juan Hernandez --- .gitignore | 3 + .vscode/settings.json | 10 ++ server.py | 205 +++++++++++++++++++++---- service_client/assisted_service_api.py | 2 +- 4 files changed, 190 insertions(+), 30 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 505a3b1..e8fc2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ wheels/ # Virtual environments .venv + +# Log files +*.log diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..e25a9c5 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "mcp": { + "servers": { + "assisted-sse": { + "type": "sse", + "url": "http://localhost:8000/sse" + } + } + } +} diff --git a/server.py b/server.py index cf75860..522f036 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,144 @@ from mcp.server.fastmcp import FastMCP +import asyncio +import base64 +import hashlib +import httpx import json import os +import starlette.requests +import starlette.responses +import urllib.parse +import webbrowser from service_client import InventoryClient mcp = FastMCP("AssistedService", host="0.0.0.0") -def get_offline_token(): +class AuthCodeFlowHelper: + """ + Simplifies use of the OAuth authorization code flow. + """ + + # These are the settings to use the 'ocm-cli' client. We should probably consider creating a + # client specifically for our use. + CLIENT_ID = "ocm-cli" + AUTH_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/auth" + TOKEN_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token" + CALLBACK_TIMEOUT = 300 + + def __init__(self): + # Create the async event that will be used to coordinate with the callback: + self._callback_event = asyncio.Event() + self._callback_code: str | None = None + + # We will save the tokens in order to avoid running the flow for each request: + self._refresh_token: str | None = None + self._access_token: str | None = None + + async def run(self) -> tuple[str, str]: + """ + Runs the OAuth authorization code flow. + + This method opens the URL of the authorization server provider in a browser, so that the user + can provide the credentials. If credentials are correct, the identity provider will send a + request to the '/oauth/callback' endpoint, containing the authorization code, which will then + be exchanged for the access and refresh tokens. + + Returns: + (str, str): A tuple containing the refresh and access tokens, in that order. + + Raises: + RuntimeError: If something fails during the process. + """ + # Don't run the flow if we already have the tokens: + if self._refresh_token is not None and self._access_token is not None: + return (self._refresh_token, self._access_token) + + # Generate a random challenge: + challenge_bytes = os.urandom(32) + challenge_verifier = base64.urlsafe_b64encode(challenge_bytes).rstrip(b'=').decode('utf-8') + challenge_hash = hashlib.sha256(challenge_verifier.encode('utf-8')).digest() + challenge_text = base64.urlsafe_b64encode(challenge_hash).rstrip(b'=').decode('utf-8') + + # Build the ahtorization URL and open it in the browser: + callback_url = f"http://127.0.0.1:8000/oauth/callback" + auth_query = urllib.parse.urlencode({ + "response_type": "code", + "client_id": self.CLIENT_ID, + "redirect_uri": callback_url, + "scope": "openid", + "code_challenge": challenge_text, + "code_challenge_method": "S256", + }) + auth_url = f"{self.AUTH_URL}?{auth_query}" + webbrowser.open_new_tab(auth_url) + + # Wait till the code has been received: + try: + await asyncio.wait_for(self._callback_event.wait(), timeout=self.CALLBACK_TIMEOUT) + except TimeoutError: + raise RuntimeError(f"Failed to get the auth code after waiting for {self.CALLBACK_TIMEOUT} seconds") + + # Exchange the code for the tokens: + auth_response = httpx.post( + self.TOKEN_URL, + data={ + "grant_type": "authorization_code", + "code": self._callback_code, + "client_id": self.CLIENT_ID, + "redirect_uri": callback_url, + "code_verifier": challenge_verifier, + }, + ) + auth_response.raise_for_status() + auth_body = auth_response.json() + self._refresh_token = auth_body.get("refresh_token") + if self._refresh_token is None: + raise RuntimeError("Response doesn't contain the refresh token") + self._access_token = auth_body.get("access_token") + if self._access_token is None: + raise RuntimeError("Response doesn't contain the access token") + return (self._refresh_token, self._access_token) + + async def callback(self, request: starlette.requests.Request) -> starlette.responses.Response: + """ + Processes the callback sent by the authorization server. + + Args: + request (starlette.requests.Request): The HTTP request. + + Returns: + response (starlette.responses.Response): The HTTP response. + """ + self._callback_code = request.query_params["code"] + self._callback_event.set() + return starlette.responses.HTMLResponse( + status_code=200, + content="Received the authorization code, you can return to your MCP client now.", + ) + +# Single instance of the authorization code flow helper. It will be created only when the flow +# is enabled. +auth_code_flow_helper: AuthCodeFlowHelper | None = None + +@mcp.custom_route( + path="/oauth/callback", + methods=[ + "GET", + ], +) +async def oauth_callback(request: starlette.requests.Request) -> starlette.responses.Response: + if auth_code_flow_helper is None: + return starlette.responses.Response(status_code=404) + return await auth_code_flow_helper.callback(request) + +async def get_offline_token() -> str: """Retrieve the offline token from environment variables or request headers. - This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) offline token - first from the OFFLINE_TOKEN environment variable, then from the OCM-Offline-Token - request header. The token is required for authenticating with the Red Hat assisted - installer service. + This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) refresh token + first using the authorization code flow, if it is enabled, or trying to get it from + the OFFLINE_TOKEN environment variable, and finally from the OCM-Offline-Token request + header. The token is required for authenticating with the Red Hat assisted installer service. Returns: str: The offline token string used for authentication. @@ -21,6 +147,10 @@ def get_offline_token(): RuntimeError: If no offline token is found in either environment variables or request headers. """ + if auth_code_flow_helper is not None: + (refresh_token, _) = await auth_code_flow_helper.run() + return refresh_token + token = os.environ.get("OFFLINE_TOKEN") if token: return token @@ -32,7 +162,7 @@ def get_offline_token(): raise RuntimeError("No offline token found in environment or request headers") @mcp.tool() -def cluster_info(cluster_id: str) -> str: +async def cluster_info(cluster_id: str) -> str: """Get comprehensive information about a specific assisted installer cluster. Retrieves detailed cluster information including configuration, status, hosts, @@ -49,10 +179,11 @@ def cluster_info(cluster_id: str) -> str: - Network configuration (VIPs, subnets) - Host information and roles """ - return InventoryClient(get_offline_token()).get_cluster(cluster_id=cluster_id).to_str() + token = await get_offline_token() + return InventoryClient(token).get_cluster(cluster_id=cluster_id).to_str() @mcp.tool() -def list_clusters() -> str: +async def list_clusters() -> str: """List all assisted installer clusters for the current user. Retrieves a summary of all clusters associated with the current user's account. @@ -67,12 +198,13 @@ def list_clusters() -> str: - openshift_version (str): The OpenShift version being installed - status (str): Current cluster status (e.g., 'ready', 'installing', 'error') """ - clusters = InventoryClient(get_offline_token()).list_clusters() + token = await get_offline_token() + clusters = InventoryClient(token).list_clusters() resp = [{"name": cluster["name"], "id": cluster["id"], "openshift_version": cluster["openshift_version"], "status": cluster["status"]} for cluster in clusters] return json.dumps(resp) @mcp.tool() -def cluster_events(cluster_id: str) -> str: +async def cluster_events(cluster_id: str) -> str: """Get the events related to a cluster with the given cluster id. Retrieves chronological events related to cluster installation, configuration @@ -86,10 +218,11 @@ def cluster_events(cluster_id: str) -> str: str: A JSON-formatted string containing cluster events with timestamps, event types, and descriptive messages about cluster activities. """ - return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id) + token = await get_offline_token() + return InventoryClient(token).get_events(cluster_id=cluster_id) @mcp.tool() -def host_events(cluster_id: str, host_id: str) -> str: +async def host_events(cluster_id: str, host_id: str) -> str: """Get events specific to a particular host within a cluster. Retrieves events related to a specific host's installation progress, hardware @@ -103,10 +236,11 @@ def host_events(cluster_id: str, host_id: str) -> str: str: A JSON-formatted string containing host-specific events including hardware validation results, installation steps, and error messages. """ - return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id, host_id=host_id) + token = await get_offline_token() + return InventoryClient(token).get_events(cluster_id=cluster_id, host_id=host_id) @mcp.tool() -def infraenv_info(infraenv_id: str) -> str: +async def infraenv_info(infraenv_id: str) -> str: """Get detailed information about an infrastructure environment (InfraEnv). An InfraEnv contains the configuration and resources needed to boot and discover @@ -124,10 +258,11 @@ def infraenv_info(infraenv_id: str) -> str: - Associated cluster information - Static network configuration if applicable """ - return InventoryClient(get_offline_token()).get_infra_env(infraenv_id).to_str() + token = await get_offline_token() + return InventoryClient(token).get_infra_env(infraenv_id).to_str() @mcp.tool() -def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str: +async def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str: """Create a new OpenShift cluster and associated infrastructure environment. Creates both a cluster definition and an InfraEnv for host discovery. The cluster @@ -148,13 +283,14 @@ def create_cluster(name: str, version: str, base_domain: str, single_node: bool) - cluster_id (str): The unique identifier of the created cluster - infraenv_id (str): The unique identifier of the created InfraEnv """ - client = InventoryClient(get_offline_token()) + token = await get_offline_token() + client = InventoryClient(token) cluster = client.create_cluster(name, version, single_node, base_dns_domain=base_domain) infraenv = client.create_infra_env(name, cluster_id=cluster.id, openshift_version=cluster.openshift_version) return json.dumps({'cluster_id': cluster.id, 'infraenv_id': infraenv.id}) @mcp.tool() -def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: +async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: """Configure the virtual IP addresses (VIPs) for cluster API and ingress traffic. Sets the API VIP (for cluster management) and Ingress VIP (for application traffic) @@ -172,10 +308,11 @@ def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: str: A formatted string containing the updated cluster configuration showing the newly set VIP addresses. """ - return InventoryClient(get_offline_token()).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str() + token = await get_offline_token() + return InventoryClient(token).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str() @mcp.tool() -def install_cluster(cluster_id: str) -> str: +async def install_cluster(cluster_id: str) -> str: """Trigger the installation process for a prepared cluster. Initiates the OpenShift installation on all discovered and validated hosts. @@ -195,10 +332,11 @@ def install_cluster(cluster_id: str) -> str: - Network configuration is complete (VIPs set if required) - All cluster validations pass """ - return InventoryClient(get_offline_token()).install_cluster(cluster_id).to_str() + token = await get_offline_token() + return InventoryClient(token).install_cluster(cluster_id).to_str() @mcp.tool() -def list_versions() -> str: +async def list_versions() -> str: """List all available OpenShift versions for installation. Retrieves the complete list of OpenShift versions that can be installed @@ -209,10 +347,11 @@ def list_versions() -> str: str: A JSON string containing available OpenShift versions with metadata including version numbers, release dates, and support status. """ - return json.dumps(InventoryClient(get_offline_token()).get_openshift_versions(True)) + token = await get_offline_token() + return json.dumps(InventoryClient(token).get_openshift_versions(True)) @mcp.tool() -def list_operator_bundles() -> str: +async def list_operator_bundles() -> str: """List available operator bundles for cluster installation. Retrieves operator bundles that can be optionally installed during cluster @@ -223,10 +362,11 @@ def list_operator_bundles() -> str: str: A JSON string containing available operator bundles with metadata including bundle names, descriptions, and operator details. """ - return json.dumps(InventoryClient(get_offline_token()).get_operator_bundles()) + token = await get_offline_token() + return json.dumps(InventoryClient(token).get_operator_bundles()) @mcp.tool() -def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: +async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: """Add an operator bundle to be installed with the cluster. Configures the specified operator bundle to be automatically installed @@ -242,10 +382,11 @@ def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: str: A formatted string containing the updated cluster configuration showing the newly added operator bundle. """ - return InventoryClient(get_offline_token()).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str() + token = await get_offline_token() + return InventoryClient(token).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str() @mcp.tool() -def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: +async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: """Assign a specific role to a discovered host in the cluster. Sets the role for a host that has been discovered through the InfraEnv boot process. @@ -263,7 +404,13 @@ def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: str: A formatted string containing the updated host configuration showing the newly assigned role. """ - return InventoryClient(get_offline_token()).update_host(host_id, infraenv_id, host_role=role).to_str() + token = await get_offline_token() + return InventoryClient(token).update_host(host_id, infraenv_id, host_role=role).to_str() if __name__ == "__main__": + # Create the authorization code flow helper if enabled: + if os.getenv("USE_AUTHORIZATION_CODE_FLOW", "false").lower() == "true": + auth_code_flow_helper = AuthCodeFlowHelper() + + # Run the server: mcp.run(transport="sse") diff --git a/service_client/assisted_service_api.py b/service_client/assisted_service_api.py index 4e7d36e..65f0919 100644 --- a/service_client/assisted_service_api.py +++ b/service_client/assisted_service_api.py @@ -18,7 +18,7 @@ def __init__(self, offline_token: str): def _get_access_token(self, offline_token: str) -> str: params = { - "client_id": "cloud-services", + "client_id": "ocm-cli", "grant_type": "refresh_token", "refresh_token": offline_token, }