Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ wheels/

# Virtual environments
.venv

# Log files
*.log
10 changes: 10 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"mcp": {
"servers": {
"assisted-sse": {
"type": "sse",
"url": "http://localhost:8000/sse"
}
}
}
}
205 changes: 176 additions & 29 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,144 @@
from mcp.server.fastmcp import FastMCP
import asyncio
import base64
import hashlib
import httpx
import json
import os
import starlette.requests
import starlette.responses
import urllib.parse
import webbrowser

from service_client import InventoryClient

mcp = FastMCP("AssistedService", host="0.0.0.0")

def get_offline_token():
class AuthCodeFlowHelper:
"""
Simplifies use of the OAuth authorization code flow.
"""

# These are the settings to use the 'ocm-cli' client. We should probably consider creating a
# client specifically for our use.
CLIENT_ID = "ocm-cli"
AUTH_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/auth"
TOKEN_URL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token"
CALLBACK_TIMEOUT = 300

def __init__(self):
# Create the async event that will be used to coordinate with the callback:
self._callback_event = asyncio.Event()
self._callback_code: str | None = None

# We will save the tokens in order to avoid running the flow for each request:
self._refresh_token: str | None = None
self._access_token: str | None = None

async def run(self) -> tuple[str, str]:
"""
Runs the OAuth authorization code flow.

This method opens the URL of the authorization server provider in a browser, so that the user
can provide the credentials. If credentials are correct, the identity provider will send a
request to the '/oauth/callback' endpoint, containing the authorization code, which will then
be exchanged for the access and refresh tokens.

Returns:
(str, str): A tuple containing the refresh and access tokens, in that order.

Raises:
RuntimeError: If something fails during the process.
"""
# Don't run the flow if we already have the tokens:
if self._refresh_token is not None and self._access_token is not None:
return (self._refresh_token, self._access_token)

# Generate a random challenge:
challenge_bytes = os.urandom(32)
challenge_verifier = base64.urlsafe_b64encode(challenge_bytes).rstrip(b'=').decode('utf-8')
challenge_hash = hashlib.sha256(challenge_verifier.encode('utf-8')).digest()
challenge_text = base64.urlsafe_b64encode(challenge_hash).rstrip(b'=').decode('utf-8')

# Build the ahtorization URL and open it in the browser:
callback_url = f"http://127.0.0.1:8000/oauth/callback"
auth_query = urllib.parse.urlencode({
"response_type": "code",
"client_id": self.CLIENT_ID,
"redirect_uri": callback_url,
"scope": "openid",
"code_challenge": challenge_text,
"code_challenge_method": "S256",
})
auth_url = f"{self.AUTH_URL}?{auth_query}"
webbrowser.open_new_tab(auth_url)

# Wait till the code has been received:
try:
await asyncio.wait_for(self._callback_event.wait(), timeout=self.CALLBACK_TIMEOUT)
except TimeoutError:
raise RuntimeError(f"Failed to get the auth code after waiting for {self.CALLBACK_TIMEOUT} seconds")

# Exchange the code for the tokens:
auth_response = httpx.post(
self.TOKEN_URL,
data={
"grant_type": "authorization_code",
"code": self._callback_code,
"client_id": self.CLIENT_ID,
"redirect_uri": callback_url,
"code_verifier": challenge_verifier,
},
)
auth_response.raise_for_status()
auth_body = auth_response.json()
self._refresh_token = auth_body.get("refresh_token")
if self._refresh_token is None:
raise RuntimeError("Response doesn't contain the refresh token")
self._access_token = auth_body.get("access_token")
if self._access_token is None:
raise RuntimeError("Response doesn't contain the access token")
return (self._refresh_token, self._access_token)

async def callback(self, request: starlette.requests.Request) -> starlette.responses.Response:
"""
Processes the callback sent by the authorization server.

Args:
request (starlette.requests.Request): The HTTP request.

Returns:
response (starlette.responses.Response): The HTTP response.
"""
self._callback_code = request.query_params["code"]
self._callback_event.set()
return starlette.responses.HTMLResponse(
status_code=200,
content="Received the authorization code, you can return to your MCP client now.",
)

# Single instance of the authorization code flow helper. It will be created only when the flow
# is enabled.
auth_code_flow_helper: AuthCodeFlowHelper | None = None

@mcp.custom_route(
path="/oauth/callback",
methods=[
"GET",
],
)
async def oauth_callback(request: starlette.requests.Request) -> starlette.responses.Response:
if auth_code_flow_helper is None:
return starlette.responses.Response(status_code=404)
return await auth_code_flow_helper.callback(request)

async def get_offline_token() -> str:
"""Retrieve the offline token from environment variables or request headers.

This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) offline token
first from the OFFLINE_TOKEN environment variable, then from the OCM-Offline-Token
request header. The token is required for authenticating with the Red Hat assisted
installer service.
This function attempts to get the Red Hat OpenShift Cluster Manager (OCM) refresh token
first using the authorization code flow, if it is enabled, or trying to get it from
the OFFLINE_TOKEN environment variable, and finally from the OCM-Offline-Token request
header. The token is required for authenticating with the Red Hat assisted installer service.

Returns:
str: The offline token string used for authentication.
Expand All @@ -21,6 +147,10 @@ def get_offline_token():
RuntimeError: If no offline token is found in either environment variables
or request headers.
"""
if auth_code_flow_helper is not None:
(refresh_token, _) = await auth_code_flow_helper.run()
return refresh_token

token = os.environ.get("OFFLINE_TOKEN")
if token:
return token
Expand All @@ -32,7 +162,7 @@ def get_offline_token():
raise RuntimeError("No offline token found in environment or request headers")

@mcp.tool()
def cluster_info(cluster_id: str) -> str:
async def cluster_info(cluster_id: str) -> str:
"""Get comprehensive information about a specific assisted installer cluster.

Retrieves detailed cluster information including configuration, status, hosts,
Expand All @@ -49,10 +179,11 @@ def cluster_info(cluster_id: str) -> str:
- Network configuration (VIPs, subnets)
- Host information and roles
"""
return InventoryClient(get_offline_token()).get_cluster(cluster_id=cluster_id).to_str()
token = await get_offline_token()
return InventoryClient(token).get_cluster(cluster_id=cluster_id).to_str()

@mcp.tool()
def list_clusters() -> str:
async def list_clusters() -> str:
"""List all assisted installer clusters for the current user.

Retrieves a summary of all clusters associated with the current user's account.
Expand All @@ -67,12 +198,13 @@ def list_clusters() -> str:
- openshift_version (str): The OpenShift version being installed
- status (str): Current cluster status (e.g., 'ready', 'installing', 'error')
"""
clusters = InventoryClient(get_offline_token()).list_clusters()
token = await get_offline_token()
clusters = InventoryClient(token).list_clusters()
resp = [{"name": cluster["name"], "id": cluster["id"], "openshift_version": cluster["openshift_version"], "status": cluster["status"]} for cluster in clusters]
return json.dumps(resp)

@mcp.tool()
def cluster_events(cluster_id: str) -> str:
async def cluster_events(cluster_id: str) -> str:
"""Get the events related to a cluster with the given cluster id.

Retrieves chronological events related to cluster installation, configuration
Expand All @@ -86,10 +218,11 @@ def cluster_events(cluster_id: str) -> str:
str: A JSON-formatted string containing cluster events with timestamps,
event types, and descriptive messages about cluster activities.
"""
return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id)
token = await get_offline_token()
return InventoryClient(token).get_events(cluster_id=cluster_id)

@mcp.tool()
def host_events(cluster_id: str, host_id: str) -> str:
async def host_events(cluster_id: str, host_id: str) -> str:
"""Get events specific to a particular host within a cluster.

Retrieves events related to a specific host's installation progress, hardware
Expand All @@ -103,10 +236,11 @@ def host_events(cluster_id: str, host_id: str) -> str:
str: A JSON-formatted string containing host-specific events including
hardware validation results, installation steps, and error messages.
"""
return InventoryClient(get_offline_token()).get_events(cluster_id=cluster_id, host_id=host_id)
token = await get_offline_token()
return InventoryClient(token).get_events(cluster_id=cluster_id, host_id=host_id)

@mcp.tool()
def infraenv_info(infraenv_id: str) -> str:
async def infraenv_info(infraenv_id: str) -> str:
"""Get detailed information about an infrastructure environment (InfraEnv).

An InfraEnv contains the configuration and resources needed to boot and discover
Expand All @@ -124,10 +258,11 @@ def infraenv_info(infraenv_id: str) -> str:
- Associated cluster information
- Static network configuration if applicable
"""
return InventoryClient(get_offline_token()).get_infra_env(infraenv_id).to_str()
token = await get_offline_token()
return InventoryClient(token).get_infra_env(infraenv_id).to_str()

@mcp.tool()
def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str:
async def create_cluster(name: str, version: str, base_domain: str, single_node: bool) -> str:
"""Create a new OpenShift cluster and associated infrastructure environment.

Creates both a cluster definition and an InfraEnv for host discovery. The cluster
Expand All @@ -148,13 +283,14 @@ def create_cluster(name: str, version: str, base_domain: str, single_node: bool)
- cluster_id (str): The unique identifier of the created cluster
- infraenv_id (str): The unique identifier of the created InfraEnv
"""
client = InventoryClient(get_offline_token())
token = await get_offline_token()
client = InventoryClient(token)
cluster = client.create_cluster(name, version, single_node, base_dns_domain=base_domain)
infraenv = client.create_infra_env(name, cluster_id=cluster.id, openshift_version=cluster.openshift_version)
return json.dumps({'cluster_id': cluster.id, 'infraenv_id': infraenv.id})

@mcp.tool()
def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str:
async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str:
"""Configure the virtual IP addresses (VIPs) for cluster API and ingress traffic.

Sets the API VIP (for cluster management) and Ingress VIP (for application traffic)
Expand All @@ -172,10 +308,11 @@ def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str:
str: A formatted string containing the updated cluster configuration
showing the newly set VIP addresses.
"""
return InventoryClient(get_offline_token()).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str()
token = await get_offline_token()
return InventoryClient(token).update_cluster(cluster_id, api_vip=api_vip, ingress_vip=ingress_vip).to_str()

@mcp.tool()
def install_cluster(cluster_id: str) -> str:
async def install_cluster(cluster_id: str) -> str:
"""Trigger the installation process for a prepared cluster.

Initiates the OpenShift installation on all discovered and validated hosts.
Expand All @@ -195,10 +332,11 @@ def install_cluster(cluster_id: str) -> str:
- Network configuration is complete (VIPs set if required)
- All cluster validations pass
"""
return InventoryClient(get_offline_token()).install_cluster(cluster_id).to_str()
token = await get_offline_token()
return InventoryClient(token).install_cluster(cluster_id).to_str()

@mcp.tool()
def list_versions() -> str:
async def list_versions() -> str:
"""List all available OpenShift versions for installation.

Retrieves the complete list of OpenShift versions that can be installed
Expand All @@ -209,10 +347,11 @@ def list_versions() -> str:
str: A JSON string containing available OpenShift versions with metadata
including version numbers, release dates, and support status.
"""
return json.dumps(InventoryClient(get_offline_token()).get_openshift_versions(True))
token = await get_offline_token()
return json.dumps(InventoryClient(token).get_openshift_versions(True))

@mcp.tool()
def list_operator_bundles() -> str:
async def list_operator_bundles() -> str:
"""List available operator bundles for cluster installation.

Retrieves operator bundles that can be optionally installed during cluster
Expand All @@ -223,10 +362,11 @@ def list_operator_bundles() -> str:
str: A JSON string containing available operator bundles with metadata
including bundle names, descriptions, and operator details.
"""
return json.dumps(InventoryClient(get_offline_token()).get_operator_bundles())
token = await get_offline_token()
return json.dumps(InventoryClient(token).get_operator_bundles())

@mcp.tool()
def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str:
async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str:
"""Add an operator bundle to be installed with the cluster.

Configures the specified operator bundle to be automatically installed
Expand All @@ -242,10 +382,11 @@ def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str:
str: A formatted string containing the updated cluster configuration
showing the newly added operator bundle.
"""
return InventoryClient(get_offline_token()).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str()
token = await get_offline_token()
return InventoryClient(token).add_operator_bundle_to_cluster(cluster_id, bundle_name).to_str()

@mcp.tool()
def set_host_role(host_id: str, infraenv_id: str, role: str) -> str:
async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str:
"""Assign a specific role to a discovered host in the cluster.

Sets the role for a host that has been discovered through the InfraEnv boot process.
Expand All @@ -263,7 +404,13 @@ def set_host_role(host_id: str, infraenv_id: str, role: str) -> str:
str: A formatted string containing the updated host configuration
showing the newly assigned role.
"""
return InventoryClient(get_offline_token()).update_host(host_id, infraenv_id, host_role=role).to_str()
token = await get_offline_token()
return InventoryClient(token).update_host(host_id, infraenv_id, host_role=role).to_str()

if __name__ == "__main__":
# Create the authorization code flow helper if enabled:
if os.getenv("USE_AUTHORIZATION_CODE_FLOW", "false").lower() == "true":
auth_code_flow_helper = AuthCodeFlowHelper()

# Run the server:
mcp.run(transport="sse")
2 changes: 1 addition & 1 deletion service_client/assisted_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, offline_token: str):

def _get_access_token(self, offline_token: str) -> str:
params = {
"client_id": "cloud-services",
"client_id": "ocm-cli",
"grant_type": "refresh_token",
"refresh_token": offline_token,
}
Expand Down