Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"assisted-service-client>=2.41.0.post3",
"fastmcp>=2.8.0",
"fastmcp>=2.10.6",
"netaddr>=1.3.0",
"requests>=2.32.3",
"retry>=0.9.2",
Expand Down
72 changes: 38 additions & 34 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
import json
import os
import asyncio
from typing import Literal, cast

import requests
import uvicorn
from assisted_service_client import models
from mcp.server.fastmcp import FastMCP

from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_headers

from service_client import InventoryClient, metrics, track_tool_usage, initiate_metrics
from service_client.logger import log


mcp = FastMCP("AssistedService", host="0.0.0.0")
mcp_server: FastMCP = FastMCP("AssistedService", host="0.0.0.0")


def format_presigned_url(presigned_url: models.PresignedUrl) -> str:
Expand Down Expand Up @@ -65,12 +66,11 @@ def get_offline_token() -> str:
log.debug("Found offline token in environment variables")
return token

request = mcp.get_context().request_context.request
if request is not None:
token = request.headers.get("OCM-Offline-Token")
if token:
log.debug("Found offline token in request headers")
return token
headers = get_http_headers()
token = headers.get("ocm-offline-token")
if token:
log.debug("Found offline token in request headers")
return token

log.error("No offline token found in environment or request headers")
raise RuntimeError("No offline token found in environment or request headers")
Expand All @@ -92,14 +92,13 @@ def get_access_token() -> str:
"""
log.debug("Attempting to retrieve access token")
# First try to get the token from the authorization header:
request = mcp.get_context().request_context.request
if request is not None:
header = request.headers.get("Authorization")
if header is not None:
parts = header.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
log.debug("Found access token in authorization header")
return parts[1]
headers = get_http_headers()
auth_header = headers.get("authorization")
if auth_header:
parts = auth_header.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
log.debug("Found access token in authorization header")
return parts[1]

# Now try to get the offline token, and generate a new access token from it:
log.debug("Generating new access token from offline token")
Expand All @@ -118,7 +117,7 @@ def get_access_token() -> str:
return response.json()["access_token"]


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def cluster_info(cluster_id: str) -> str:
"""
Expand All @@ -145,7 +144,7 @@ async def cluster_info(cluster_id: str) -> str:
return result.to_str()


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def list_clusters() -> str:
"""
Expand Down Expand Up @@ -179,7 +178,7 @@ async def list_clusters() -> str:
return json.dumps(resp)


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def cluster_events(cluster_id: str) -> str:
"""
Expand All @@ -203,7 +202,7 @@ async def cluster_events(cluster_id: str) -> str:
return result


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def host_events(cluster_id: str, host_id: str) -> str:
"""
Expand All @@ -229,7 +228,7 @@ async def host_events(cluster_id: str, host_id: str) -> str:
return result


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def cluster_iso_download_url(cluster_id: str) -> str:
"""
Expand Down Expand Up @@ -286,7 +285,7 @@ async def cluster_iso_download_url(cluster_id: str) -> str:
return "\n\n".join(iso_info)


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def create_cluster(
name: str, version: str, base_domain: str, single_node: bool
Expand Down Expand Up @@ -337,7 +336,7 @@ async def create_cluster(
return cluster.id


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str:
"""
Expand Down Expand Up @@ -372,7 +371,7 @@ async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> s
return result.to_str()


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def install_cluster(cluster_id: str) -> str:
"""
Expand Down Expand Up @@ -402,7 +401,7 @@ async def install_cluster(cluster_id: str) -> str:
return result.to_str()


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def list_versions() -> str:
"""
Expand All @@ -423,7 +422,7 @@ async def list_versions() -> str:
return json.dumps(result)


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def list_operator_bundles() -> str:
"""
Expand All @@ -444,7 +443,7 @@ async def list_operator_bundles() -> str:
return json.dumps(result)


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str:
"""
Expand Down Expand Up @@ -472,7 +471,7 @@ async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> s
return result.to_str()


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> str:
"""
Expand Down Expand Up @@ -516,7 +515,7 @@ async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> s
return format_presigned_url(result)


@mcp.tool()
@mcp_server.tool()
@track_tool_usage()
async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str:
"""
Expand Down Expand Up @@ -546,15 +545,20 @@ async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str:

def list_tools() -> list[str]:
"""List all MCP tools."""
return list(asyncio.run(mcp_server.get_tools()))

async def mcp_list_tools() -> list[str]:
return [t.name for t in await mcp.list_tools()]

return asyncio.run(mcp_list_tools())
def get_transport() -> Literal["http", "streamable-http", "sse"]:
"""Get the transport type from the environment."""
t = os.getenv("TRANSPORT", "sse")
if t not in ["http", "streamable-http", "sse"]:
t = "sse" # fallback to default
return cast(Literal["http", "streamable-http", "sse"], t)


if __name__ == "__main__":
app = mcp.sse_app()
transport = get_transport()
app = mcp_server.http_app(transport=transport)
initiate_metrics(list_tools())
app.add_route("/metrics", metrics)
uvicorn.run(app, host="0.0.0.0")
5 changes: 5 additions & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ parameters:
- name: SSO_URL
value: "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token"
description: "URL for Red Hat Single Sign-On (SSO) OpenID Connect token endpoint"
- name: TRANSPORT
value: "sse"
description: "MCP transport type. Valid values: 'http', 'streamable-http', 'sse'. Defaults to 'sse'."
- name: PULL_SECRET_URL
value: "https://api.openshift.com/api/accounts_mgmt/v1/access_token"
description: "URL for accessing pull secrets via the accounts management API"
Expand Down Expand Up @@ -81,6 +84,8 @@ objects:
value: ${INVENTORY_URL}
- name: SSO_URL
value: ${SSO_URL}
- name: TRANSPORT
value: ${TRANSPORT}
- name: PULL_SECRET_URL
value: ${PULL_SECRET_URL}
- name: CLIENT_DEBUG
Expand Down
Loading