diff --git a/.gitignore b/.gitignore index 505a3b1..e8fc2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ wheels/ # Virtual environments .venv + +# Log files +*.log diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..e25a9c5 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "mcp": { + "servers": { + "assisted-sse": { + "type": "sse", + "url": "http://localhost:8000/sse" + } + } + } +} diff --git a/server.py b/server.py index cf75860..522f036 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,144 @@ from mcp.server.fastmcp import FastMCP +import asyncio +import base64 +import hashlib +import httpx import json import os +import starlette.requests +import starlette.responses +import urllib.parse +import webbrowser from service_client import InventoryClient mcp = FastMCP("AssistedService", host="0.0.0.0") -def get_offline_token(): +class AuthCodeFlowHelper: + """ + Simplifies use of the OAuth authorization code flow. + """ + + # These are the settings to use the 'ocm-cli' client. We should probably consider creating a + # client specifically for our use. + CLIENT_ID = "ocm-cli" + AUTH_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/auth" + TOKEN_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token" + CALLBACK_TIMEOUT = 300 + + def __init__(self): + # Create the async event that will be used to coordinate with the callback: + self._callback_event = asyncio.Event() + self._callback_code: str | None = None + + # We will save the tokens in order to avoid running the flow for each request: + self._refresh_token: str | None = None + self._access_token: str | None = None + + async def run(self) -> tuple[str, str]: + """ + Runs the OAuth authorization code flow. + + This method opens the URL of the authorization server provider in a browser, so that the user + can provide the credentials. If credentials are correct, the identity provider will send a + request to the '/oauth/callback' endpoint, containing the authorization code, which will then + be exchanged for the access and refresh tokens. + + Returns: + (str, str): A tuple containing the refresh and access tokens, in that order. + + Raises: + RuntimeError: If something fails during the process. + """ + # Don't run the flow if we already have the tokens: + if self._refresh_token is not None and self._access_token is not None: + return (self._refresh_token, self._access_token) + + # Generate a random challenge: + challenge_bytes = os.urandom(32) + challenge_verifier = base64.urlsafe_b64encode(challenge_bytes).rstrip(b'=').decode('utf-8') + challenge_hash = hashlib.sha256(challenge_verifier.encode('utf-8')).digest() + challenge_text = base64.urlsafe_b64encode(challenge_hash).rstrip(b'=').decode('utf-8') + + # Build the ahtorization URL and open it in the browser: + callback_url = f"http://127.0.0.1:8000/oauth/callback" + auth_query = urllib.parse.urlencode({ + "response_type": "code", + "client_id": self.CLIENT_ID, + "redirect_uri": callback_url, + "scope": "openid", + "code_challenge": challenge_text, + "code_challenge_method": "S256", + }) + auth_url = f"{self.AUTH_URL}?{auth_query}" + webbrowser.open_new_tab(auth_url) + + # Wait till the code has been received: + try: + await asyncio.wait_for(self._callback_event.wait(), timeout=self.CALLBACK_TIMEOUT) + except TimeoutError: + raise RuntimeError(f"Failed to get the auth code after waiting for {self.CALLBACK_TIMEOUT} seconds") + + # Exchange the code for the tokens: + auth_response = httpx.post( + self.TOKEN_URL, + data={ + "grant_type": "authorization_code", + "code": self._callback_code, + "client_id": self.CLIENT_ID, + "redirect_uri": callback_url, + "code_verifier": challenge_verifier, + }, + ) + auth_response.raise_for_status() + auth_body = auth_response.json() + self._refresh_token = auth_body.get("refresh_token") + if self._refresh_token is None: + raise RuntimeError("Response doesn't contain the refresh token") + self._access_token = auth_body.get("access_token") + if self._access_token is None: + raise RuntimeError("Response doesn't contain the access token") + return (self._refresh_token, self._access_token) + + async def callback(self, request: starlette.requests.Request) -> starlette.responses.Response: + """ + Processes the callback sent by the authorization server. + + Args: + request (starlette.requests.Request): The HTTP request. + + Returns: + response (starlette.responses.Response): The HTTP response. + """ + self._callback_code = request.query_params["code"] + self._callback_event.set() + return starlette.responses.HTMLResponse( + status_code=200, + content="Received the authorization code, you can return to your MCP client now.", + ) + +# Single instance of the authorization code flow helper. It will be created only when the flow +# is enabled. +auth_code_flow_helper: AuthCodeFlowHelper | None = None + +@mcp.custom_route( + path="/oauth/callback", + methods=[ + "GET", + ], +) +async def oauth_callback(request: starlette.requests.Request) -> starlette.responses.Response: + if auth_code_flow_helper is None: + return starlette.responses.Response(status_code=404) + return await auth_code_flow_helper.callback(request) + +async def get_offline_token() -> str: """Retrieve the offline token from environment variables or request headers. - This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) offline token - first from the OFFLINE_TOKEN environment variable, then from the OCM-Offline-Token - request header. The token is required for authenticating with the Red Hat assisted - installer service. + This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) refresh token + first using the authorization code flow, if it is enabled, or trying to get it from + the OFFLINE_TOKEN environment variable, and finally from the OCM-Offline-Token request + header. The token is required for authenticating with the Red Hat assisted installer service. Returns: str: The offline token string used for authentication. @@ -21,6 +147,10 @@ def get_offline_token(): RuntimeError: If no offline token is found in either environment variables or request headers. """ + if auth_code_flow_helper is not None: + (refresh_token, _) = await auth_code_flow_helper.run() + return refresh_token + token = os.environ.get("OFFLINE_TOKEN") if token: return token @@ -32,7 +162,7 @@ def get_offline_token(): raise RuntimeError("No offline token found in environment or request headers") @mcp.tool() -def cluster_info(cluster_id: str) -> str: +async def cluster_info(cluster_id: str) -> str: """Get comprehensive information about a specific assisted installer cluster. Retrieves detailed cluster information including configuration, status, hosts, @@ -49,10 +179,11 @@ def cluster_info(cluster_id: str) -> str: - Network configuration (VIPs, subnets) - Host information and roles """ - return InventoryClient(get_offline_token()).get_cluster(cluster_id=cluster_id).to_str() + token = await get_offline_token() + return InventoryClient(token).get_cluster(cluster_id=cluster_id).to_str() @mcp.tool() -def list_clusters() -> str: +async def list_clusters() -> str: """List all assisted installer clusters for the current user. Retrieves a summary of all clusters associated with the current user's account. @@ -67,12 +198,13 @@ def list_clusters() -> str: - openshift_version (str): The OpenShift version being installed - status (str): Current cluster status (e.g., 'ready', 'installing', 'error') """ - clusters = InventoryClient(get_offline_token()).list_clusters() + token = await get_offline_token() + clusters = InventoryClient(token).list_clusters() resp = [{"name": cluster["name"], "id": cluster["id"], "openshift_version": cluster["openshift_version"], "status": cluster["status"]} for cluster in clusters] return json.dumps(resp) @mcp.tool() -def cluster_events(cluster_id: str) -> str: +async def cluster_events(cluster_id: str) -> str: """Get the events related to a cluster with the given cluster id. Retrieves chronological events related to cluster installation, configuration @@ -86,10 +218,11 @@ def cluster_events(cluster_id: str) -> str: str: A JSON-formatted string containing cluster events with timestamps, event types, and descriptive messages about cluster activities. """ - return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id) + token = await get_offline_token() + return InventoryClient(token).get_events(cluster_id=cluster_id) @mcp.tool() -def host_events(cluster_id: str, host_id: str) -> str: +async def host_events(cluster_id: str, host_id: str) -> str: """Get events specific to a particular host within a cluster. Retrieves events related to a specific host's installation progress, hardware @@ -103,10 +236,11 @@ def host_events(cluster_id: str, host_id: str) -> str: str: A JSON-formatted string containing host-specific events including hardware validation results, installation steps, and error messages. """ - return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id, host_id=host_id) + token = await get_offline_token() + return InventoryClient(token).get_events(cluster_id=cluster_id, host_id=host_id) @mcp.tool() -def infraenv_info(infraenv_id: str) -> str: +async def infraenv_info(infraenv_id: str) -> str: """Get detailed information about an infrastructure environment (InfraEnv). An InfraEnv contains the configuration and resources needed to boot and discover @@ -124,10 +258,11 @@ def infraenv_info(infraenv_id: str) -> str: - Associated cluster information - Static network configuration if applicable """ - return InventoryClient(get_offline_token()).get_infra_env(infraenv_id).to_str() + token = await get_offline_token() + return InventoryClient(token).get_infra_env(infraenv_id).to_str() @mcp.tool() -def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str: +async def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str: """Create a new OpenShift cluster and associated infrastructure environment. Creates both a cluster definition and an InfraEnv for host discovery. The cluster @@ -148,13 +283,14 @@ def create_cluster(name: str, version: str, base_domain: str, single_node: bool) - cluster_id (str): The unique identifier of the created cluster - infraenv_id (str): The unique identifier of the created InfraEnv """ - client = InventoryClient(get_offline_token()) + token = await get_offline_token() + client = InventoryClient(token) cluster = client.create_cluster(name, version, single_node, base_dns_domain=base_domain) infraenv = client.create_infra_env(name, cluster_id=cluster.id, openshift_version=cluster.openshift_version) return json.dumps({'cluster_id': cluster.id, 'infraenv_id': infraenv.id}) @mcp.tool() -def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: +async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: """Configure the virtual IP addresses (VIPs) for cluster API and ingress traffic. Sets the API VIP (for cluster management) and Ingress VIP (for application traffic) @@ -172,10 +308,11 @@ def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: str: A formatted string containing the updated cluster configuration showing the newly set VIP addresses. """ - return InventoryClient(get_offline_token()).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str() + token = await get_offline_token() + return InventoryClient(token).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str() @mcp.tool() -def install_cluster(cluster_id: str) -> str: +async def install_cluster(cluster_id: str) -> str: """Trigger the installation process for a prepared cluster. Initiates the OpenShift installation on all discovered and validated hosts. @@ -195,10 +332,11 @@ def install_cluster(cluster_id: str) -> str: - Network configuration is complete (VIPs set if required) - All cluster validations pass """ - return InventoryClient(get_offline_token()).install_cluster(cluster_id).to_str() + token = await get_offline_token() + return InventoryClient(token).install_cluster(cluster_id).to_str() @mcp.tool() -def list_versions() -> str: +async def list_versions() -> str: """List all available OpenShift versions for installation. Retrieves the complete list of OpenShift versions that can be installed @@ -209,10 +347,11 @@ def list_versions() -> str: str: A JSON string containing available OpenShift versions with metadata including version numbers, release dates, and support status. """ - return json.dumps(InventoryClient(get_offline_token()).get_openshift_versions(True)) + token = await get_offline_token() + return json.dumps(InventoryClient(token).get_openshift_versions(True)) @mcp.tool() -def list_operator_bundles() -> str: +async def list_operator_bundles() -> str: """List available operator bundles for cluster installation. Retrieves operator bundles that can be optionally installed during cluster @@ -223,10 +362,11 @@ def list_operator_bundles() -> str: str: A JSON string containing available operator bundles with metadata including bundle names, descriptions, and operator details. """ - return json.dumps(InventoryClient(get_offline_token()).get_operator_bundles()) + token = await get_offline_token() + return json.dumps(InventoryClient(token).get_operator_bundles()) @mcp.tool() -def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: +async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: """Add an operator bundle to be installed with the cluster. Configures the specified operator bundle to be automatically installed @@ -242,10 +382,11 @@ def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: str: A formatted string containing the updated cluster configuration showing the newly added operator bundle. """ - return InventoryClient(get_offline_token()).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str() + token = await get_offline_token() + return InventoryClient(token).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str() @mcp.tool() -def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: +async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: """Assign a specific role to a discovered host in the cluster. Sets the role for a host that has been discovered through the InfraEnv boot process. @@ -263,7 +404,13 @@ def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: str: A formatted string containing the updated host configuration showing the newly assigned role. """ - return InventoryClient(get_offline_token()).update_host(host_id, infraenv_id, host_role=role).to_str() + token = await get_offline_token() + return InventoryClient(token).update_host(host_id, infraenv_id, host_role=role).to_str() if __name__ == "__main__": + # Create the authorization code flow helper if enabled: + if os.getenv("USE_AUTHORIZATION_CODE_FLOW", "false").lower() == "true": + auth_code_flow_helper = AuthCodeFlowHelper() + + # Run the server: mcp.run(transport="sse") diff --git a/service_client/assisted_service_api.py b/service_client/assisted_service_api.py index 4e7d36e..65f0919 100644 --- a/service_client/assisted_service_api.py +++ b/service_client/assisted_service_api.py @@ -18,7 +18,7 @@ def __init__(self, offline_token: str): def _get_access_token(self, offline_token: str) -> str: params = { - "client_id": "cloud-services", + "client_id": "ocm-cli", "grant_type": "refresh_token", "refresh_token": offline_token, }