diff --git a/pyproject.toml b/pyproject.toml index 860d8f9..505b2f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "assisted-service-client>=2.41.0.post3", - "fastmcp>=2.8.0", + "fastmcp>=2.10.6", "netaddr>=1.3.0", "requests>=2.32.3", "retry>=0.9.2", diff --git a/server.py b/server.py index f56a77d..fd60d6f 100644 --- a/server.py +++ b/server.py @@ -8,18 +8,19 @@ import json import os import asyncio +from typing import Literal, cast import requests import uvicorn from assisted_service_client import models -from mcp.server.fastmcp import FastMCP - +from fastmcp import FastMCP +from fastmcp.server.dependencies import get_http_headers from service_client import InventoryClient, metrics, track_tool_usage, initiate_metrics from service_client.logger import log -mcp = FastMCP("AssistedService", host="0.0.0.0") +mcp_server: FastMCP = FastMCP("AssistedService", host="0.0.0.0") def format_presigned_url(presigned_url: models.PresignedUrl) -> str: @@ -65,12 +66,11 @@ def get_offline_token() -> str: log.debug("Found offline token in environment variables") return token - request = mcp.get_context().request_context.request - if request is not None: - token = request.headers.get("OCM-Offline-Token") - if token: - log.debug("Found offline token in request headers") - return token + headers = get_http_headers() + token = headers.get("ocm-offline-token") + if token: + log.debug("Found offline token in request headers") + return token log.error("No offline token found in environment or request headers") raise RuntimeError("No offline token found in environment or request headers") @@ -92,14 +92,13 @@ def get_access_token() -> str: """ log.debug("Attempting to retrieve access token") # First try to get the token from the authorization header: - request = mcp.get_context().request_context.request - if request is not None: - header = request.headers.get("Authorization") - if header is not None: - parts = header.split() - if len(parts) == 2 and parts[0].lower() == "bearer": - log.debug("Found access token in authorization header") - return parts[1] + headers = get_http_headers() + auth_header = headers.get("authorization") + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == "bearer": + log.debug("Found access token in authorization header") + return parts[1] # Now try to get the offline token, and generate a new access token from it: log.debug("Generating new access token from offline token") @@ -118,7 +117,7 @@ def get_access_token() -> str: return response.json()["access_token"] -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def cluster_info(cluster_id: str) -> str: """ @@ -145,7 +144,7 @@ async def cluster_info(cluster_id: str) -> str: return result.to_str() -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def list_clusters() -> str: """ @@ -179,7 +178,7 @@ async def list_clusters() -> str: return json.dumps(resp) -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def cluster_events(cluster_id: str) -> str: """ @@ -203,7 +202,7 @@ async def cluster_events(cluster_id: str) -> str: return result -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def host_events(cluster_id: str, host_id: str) -> str: """ @@ -229,7 +228,7 @@ async def host_events(cluster_id: str, host_id: str) -> str: return result -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def cluster_iso_download_url(cluster_id: str) -> str: """ @@ -286,7 +285,7 @@ async def cluster_iso_download_url(cluster_id: str) -> str: return "\n\n".join(iso_info) -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def create_cluster( name: str, version: str, base_domain: str, single_node: bool @@ -337,7 +336,7 @@ async def create_cluster( return cluster.id -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> str: """ @@ -372,7 +371,7 @@ async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> s return result.to_str() -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def install_cluster(cluster_id: str) -> str: """ @@ -402,7 +401,7 @@ async def install_cluster(cluster_id: str) -> str: return result.to_str() -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def list_versions() -> str: """ @@ -423,7 +422,7 @@ async def list_versions() -> str: return json.dumps(result) -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def list_operator_bundles() -> str: """ @@ -444,7 +443,7 @@ async def list_operator_bundles() -> str: return json.dumps(result) -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> str: """ @@ -472,7 +471,7 @@ async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> s return result.to_str() -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> str: """ @@ -516,7 +515,7 @@ async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> s return format_presigned_url(result) -@mcp.tool() +@mcp_server.tool() @track_tool_usage() async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: """ @@ -546,15 +545,20 @@ async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: def list_tools() -> list[str]: """List all MCP tools.""" + return list(asyncio.run(mcp_server.get_tools())) - async def mcp_list_tools() -> list[str]: - return [t.name for t in await mcp.list_tools()] - return asyncio.run(mcp_list_tools()) +def get_transport() -> Literal["http", "streamable-http", "sse"]: + """Get the transport type from the environment.""" + t = os.getenv("TRANSPORT", "sse") + if t not in ["http", "streamable-http", "sse"]: + t = "sse" # fallback to default + return cast(Literal["http", "streamable-http", "sse"], t) if __name__ == "__main__": - app = mcp.sse_app() + transport = get_transport() + app = mcp_server.http_app(transport=transport) initiate_metrics(list_tools()) app.add_route("/metrics", metrics) uvicorn.run(app, host="0.0.0.0") diff --git a/template.yaml b/template.yaml index ee02cb4..8e132e9 100644 --- a/template.yaml +++ b/template.yaml @@ -13,6 +13,9 @@ parameters: - name: SSO_URL value: "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token" description: "URL for Red Hat Single Sign-On (SSO) OpenID Connect token endpoint" +- name: TRANSPORT + value: "sse" + description: "MCP transport type. Valid values: 'http', 'streamable-http', 'sse'. Defaults to 'sse'." - name: PULL_SECRET_URL value: "https://api.openshift.com/api/accounts_mgmt/v1/access_token" description: "URL for accessing pull secrets via the accounts management API" @@ -81,6 +84,8 @@ objects: value: ${INVENTORY_URL} - name: SSO_URL value: ${SSO_URL} + - name: TRANSPORT + value: ${TRANSPORT} - name: PULL_SECRET_URL value: ${PULL_SECRET_URL} - name: CLIENT_DEBUG diff --git a/tests/test_server.py b/tests/test_server.py index 3a27b1c..916d1ac 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,12 +4,15 @@ import json import os -from typing import Generator, Tuple +from typing import Generator from unittest.mock import Mock, patch, call import pytest from requests.exceptions import RequestException +from fastmcp import Client +from mcp.types import TextContent + from service_client import InventoryClient import server from tests.test_utils import ( @@ -25,14 +28,10 @@ class TestTokenFunctions: """Test cases for token handling functions.""" @pytest.fixture - def mock_mcp_get_context(self) -> Generator[Tuple[Mock, Mock], None, None]: - """Mock MCP context for testing.""" - mock_context = Mock() - mock_request = Mock() - mock_context.request_context.request = mock_request - - with patch.object(server.mcp, "get_context", return_value=mock_context): - yield mock_context, mock_request + def mock_http_headers(self) -> Generator[Mock, None, None]: + """Mock HTTP headers for testing.""" + with patch("server.get_http_headers") as mock_get_headers: + yield mock_get_headers def test_get_offline_token_from_environment(self) -> None: """Test retrieving offline token from environment variables.""" @@ -42,15 +41,14 @@ def test_get_offline_token_from_environment(self) -> None: assert result == test_token def test_get_offline_token_environment_takes_precedence( - self, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_http_headers: Mock ) -> None: """Test that environment token takes precedence over request header token.""" - _mock_context, mock_request = mock_mcp_get_context env_token = "environment-token" header_token = "header-token" # Set up both environment and header tokens - mock_request.headers.get.return_value = header_token + mock_http_headers.return_value = {"ocm-offline-token": header_token} with patch.dict(os.environ, {"OFFLINE_TOKEN": env_token}): result = server.get_offline_token() @@ -59,28 +57,22 @@ def test_get_offline_token_environment_takes_precedence( assert result == env_token # Should not even check the request headers since env token was found - mock_request.headers.get.assert_not_called() + mock_http_headers.assert_not_called() - def test_get_offline_token_from_headers( - self, mock_mcp_get_context: Tuple[Mock, Mock] - ) -> None: + def test_get_offline_token_from_headers(self, mock_http_headers: Mock) -> None: """Test retrieving offline token from request headers.""" - _mock_context, mock_request = mock_mcp_get_context test_token = "test-offline-token-header" - mock_request.headers.get.return_value = test_token + mock_http_headers.return_value = {"ocm-offline-token": test_token} # Ensure environment variable is not set with patch.dict(os.environ, {}, clear=True): result = server.get_offline_token() assert result == test_token - mock_request.headers.get.assert_called_once_with("OCM-Offline-Token") + mock_http_headers.assert_called_once() - def test_get_offline_token_not_found( - self, mock_mcp_get_context: Tuple[Mock, Mock] - ) -> None: + def test_get_offline_token_not_found(self, mock_http_headers: Mock) -> None: """Test error when offline token is not found.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = None + mock_http_headers.return_value = {} with patch.dict(os.environ, {}, clear=True): with pytest.raises(RuntimeError) as exc_info: @@ -89,33 +81,28 @@ def test_get_offline_token_not_found( def test_get_offline_token_no_request(self) -> None: """Test offline token retrieval when no request is available.""" - mock_context = Mock() - mock_context.request_context.request = None - - with patch.object(server.mcp, "get_context", return_value=mock_context): + with patch("server.get_http_headers", return_value={}): with patch.dict(os.environ, {}, clear=True): with pytest.raises(RuntimeError) as exc_info: server.get_offline_token() assert "No offline token found" in str(exc_info.value) def test_get_access_token_from_authorization_header( - self, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_http_headers: Mock ) -> None: """Test retrieving access token from Authorization header.""" - _mock_context, mock_request = mock_mcp_get_context test_token = "test-access-token" - mock_request.headers.get.return_value = f"Bearer {test_token}" + mock_http_headers.return_value = {"authorization": f"Bearer {test_token}"} result = server.get_access_token() assert result == test_token - mock_request.headers.get.assert_called_once_with("Authorization") + mock_http_headers.assert_called_once() def test_get_access_token_invalid_authorization_header( - self, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_http_headers: Mock ) -> None: """Test access token retrieval with invalid Authorization header.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = "Invalid header format" + mock_http_headers.return_value = {"authorization": "Invalid header format"} with patch.object(server, "get_offline_token", return_value="offline-token"): with patch("requests.post") as mock_post: @@ -127,11 +114,10 @@ def test_get_access_token_invalid_authorization_header( assert result == "new-token" def test_get_access_token_no_authorization_header( - self, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_http_headers: Mock ) -> None: """Test access token retrieval without Authorization header.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = None + mock_http_headers.return_value = {} with patch.object(server, "get_offline_token", return_value="offline-token"): with patch("requests.post") as mock_post: @@ -144,11 +130,10 @@ def test_get_access_token_no_authorization_header( @patch("requests.post") def test_get_access_token_generate_from_offline_token( - self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_post: Mock, mock_http_headers: Mock ) -> None: """Test generating access token from offline token.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = None + mock_http_headers.return_value = {} offline_token = "test-offline-token" access_token = "generated-access-token" @@ -173,11 +158,10 @@ def test_get_access_token_generate_from_offline_token( @patch("requests.post") def test_get_access_token_custom_sso_url( - self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_post: Mock, mock_http_headers: Mock ) -> None: """Test access token generation with custom SSO URL.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = None + mock_http_headers.return_value = {} custom_sso_url = "https://custom-sso.example.com/token" offline_token = "test-offline-token" @@ -204,11 +188,10 @@ def test_get_access_token_custom_sso_url( @patch("requests.post") def test_get_access_token_request_failure( - self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] + self, mock_post: Mock, mock_http_headers: Mock ) -> None: """Test access token generation request failure.""" - _mock_context, mock_request = mock_mcp_get_context - mock_request.headers.get.return_value = None + mock_http_headers.return_value = {} mock_post.side_effect = RequestException("Network error") @@ -218,10 +201,7 @@ def test_get_access_token_request_failure( def test_get_access_token_no_request_context(self) -> None: """Test access token retrieval when no request context is available.""" - mock_context = Mock() - mock_context.request_context.request = None - - with patch.object(server.mcp, "get_context", return_value=mock_context): + with patch("server.get_http_headers", return_value={}): with patch.object( server, "get_offline_token", return_value="offline-token" ): @@ -258,16 +238,19 @@ async def test_cluster_info_success( cluster_id = "test-cluster-id" cluster = create_test_cluster(cluster_id=cluster_id) mock_inventory_client.get_cluster.return_value = cluster - with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_info(cluster_id) - - assert result == cluster.to_str() - mock_inventory_client.get_cluster.assert_called_once_with( - cluster_id=cluster_id - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_info", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == cluster.to_str() + mock_inventory_client.get_cluster.assert_called_once_with( + cluster_id=cluster_id + ) @pytest.mark.asyncio async def test_list_clusters_success( @@ -295,11 +278,13 @@ async def test_list_clusters_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.list_clusters() - - expected_result = json.dumps(mock_clusters) - assert result == expected_result - mock_inventory_client.list_clusters.assert_called_once() + async with Client(server.mcp_server) as client: + resp = await client.call_tool("list_clusters", {}) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = json.dumps(mock_clusters) + assert resp.content[0].text == expected_result + mock_inventory_client.list_clusters.assert_called_once() @pytest.mark.asyncio async def test_cluster_events_success( @@ -315,12 +300,16 @@ async def test_cluster_events_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_events(cluster_id) - - assert result == mock_events - mock_inventory_client.get_events.assert_called_once_with( - cluster_id=cluster_id - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_events", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == mock_events + mock_inventory_client.get_events.assert_called_once_with( + cluster_id=cluster_id + ) @pytest.mark.asyncio async def test_host_events_success( @@ -337,12 +326,16 @@ async def test_host_events_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.host_events(cluster_id, host_id) - - assert result == mock_events - mock_inventory_client.get_events.assert_called_once_with( - cluster_id=cluster_id, host_id=host_id - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "host_events", {"cluster_id": cluster_id, "host_id": host_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == mock_events + mock_inventory_client.get_events.assert_called_once_with( + cluster_id=cluster_id, host_id=host_id + ) @pytest.mark.asyncio async def test_cluster_iso_download_url_success( @@ -367,14 +360,20 @@ async def test_cluster_iso_download_url_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_iso_download_url(cluster_id) - - expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image\nExpires at: 2023-12-31T23:59:59Z" - assert result == expected_result - mock_inventory_client.list_infra_envs.assert_called_once_with(cluster_id) - mock_inventory_client.get_infra_env_download_url.assert_called_once_with( - "test-infraenv-id" - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_iso_download_url", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image\nExpires at: 2023-12-31T23:59:59Z" + assert resp.content[0].text == expected_result + mock_inventory_client.list_infra_envs.assert_called_once_with( + cluster_id + ) + mock_inventory_client.get_infra_env_download_url.assert_called_once_with( + "test-infraenv-id" + ) @pytest.mark.asyncio async def test_cluster_iso_download_url_multiple_infraenvs( @@ -421,22 +420,28 @@ async def test_cluster_iso_download_url_multiple_infraenvs( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_iso_download_url(cluster_id) - - expected_result = ( - "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id-1/downloads/image\n" - "Expires at: 2023-12-31T23:59:59Z\n\n" - "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id-2/downloads/image\n" - "Expires at: 2024-01-15T12:00:00Z" - ) - assert result == expected_result - mock_inventory_client.list_infra_envs.assert_called_once_with(cluster_id) - mock_inventory_client.get_infra_env_download_url.assert_has_calls( - [ - call("test-infraenv-id-1"), - call("test-infraenv-id-2"), - ] - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_iso_download_url", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = ( + "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id-1/downloads/image\n" + "Expires at: 2023-12-31T23:59:59Z\n\n" + "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id-2/downloads/image\n" + "Expires at: 2024-01-15T12:00:00Z" + ) + assert resp.content[0].text == expected_result + mock_inventory_client.list_infra_envs.assert_called_once_with( + cluster_id + ) + mock_inventory_client.get_infra_env_download_url.assert_has_calls( + [ + call("test-infraenv-id-1"), + call("test-infraenv-id-2"), + ] + ) @pytest.mark.asyncio async def test_cluster_iso_download_url_no_expiration( @@ -461,14 +466,20 @@ async def test_cluster_iso_download_url_no_expiration( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_iso_download_url(cluster_id) - - expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image" - assert result == expected_result - mock_inventory_client.list_infra_envs.assert_called_once_with(cluster_id) - mock_inventory_client.get_infra_env_download_url.assert_called_once_with( - "test-infraenv-id" - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_iso_download_url", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image" + assert resp.content[0].text == expected_result + mock_inventory_client.list_infra_envs.assert_called_once_with( + cluster_id + ) + mock_inventory_client.get_infra_env_download_url.assert_called_once_with( + "test-infraenv-id" + ) @pytest.mark.asyncio async def test_cluster_iso_download_url_zero_expiration( @@ -493,15 +504,21 @@ async def test_cluster_iso_download_url_zero_expiration( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_iso_download_url(cluster_id) - - # Should not include expiration time since it's a zero/default value - expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image" - assert result == expected_result - mock_inventory_client.list_infra_envs.assert_called_once_with(cluster_id) - mock_inventory_client.get_infra_env_download_url.assert_called_once_with( - "test-infraenv-id" - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_iso_download_url", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + # Should not include expiration time since it's a zero/default value + expected_result = "URL: https://api.openshift.com/api/assisted-install/v2/infra-envs/test-id/downloads/image" + assert resp.content[0].text == expected_result + mock_inventory_client.list_infra_envs.assert_called_once_with( + cluster_id + ) + mock_inventory_client.get_infra_env_download_url.assert_called_once_with( + "test-infraenv-id" + ) @pytest.mark.asyncio async def test_cluster_iso_download_url_no_infraenvs( @@ -516,10 +533,19 @@ async def test_cluster_iso_download_url_no_infraenvs( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_iso_download_url(cluster_id) - - assert result == "No ISO download URLs found for this cluster." - mock_inventory_client.list_infra_envs.assert_called_once_with(cluster_id) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_iso_download_url", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert ( + resp.content[0].text + == "No ISO download URLs found for this cluster." + ) + mock_inventory_client.list_infra_envs.assert_called_once_with( + cluster_id + ) @pytest.mark.asyncio async def test_create_cluster_success( @@ -549,17 +575,30 @@ async def test_create_cluster_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.create_cluster( - name, version, base_domain, single_node - ) - assert result == cluster.id - - mock_inventory_client.create_cluster.assert_called_once_with( - name, version, single_node, base_dns_domain=base_domain, tags="chatbot" - ) - mock_inventory_client.create_infra_env.assert_called_once_with( - name, cluster_id="cluster-id", openshift_version=version - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "create_cluster", + { + "name": name, + "version": version, + "base_domain": base_domain, + "single_node": single_node, + }, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == cluster.id + + mock_inventory_client.create_cluster.assert_called_once_with( + name, + version, + single_node, + base_dns_domain=base_domain, + tags="chatbot", + ) + mock_inventory_client.create_infra_env.assert_called_once_with( + name, cluster_id="cluster-id", openshift_version=version + ) @pytest.mark.asyncio async def test_set_cluster_vips_success( @@ -578,12 +617,21 @@ async def test_set_cluster_vips_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.set_cluster_vips(cluster_id, api_vip, ingress_vip) - - assert result == cluster.to_str() - mock_inventory_client.update_cluster.assert_called_once_with( - cluster_id, api_vip=api_vip, ingress_vip=ingress_vip - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "set_cluster_vips", + { + "cluster_id": cluster_id, + "api_vip": api_vip, + "ingress_vip": ingress_vip, + }, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == cluster.to_str() + mock_inventory_client.update_cluster.assert_called_once_with( + cluster_id, api_vip=api_vip, ingress_vip=ingress_vip + ) @pytest.mark.asyncio async def test_install_cluster_success( @@ -599,10 +647,16 @@ async def test_install_cluster_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.install_cluster(cluster_id) - - assert result == cluster.to_str() - mock_inventory_client.install_cluster.assert_called_once_with(cluster_id) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "install_cluster", {"cluster_id": cluster_id} + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == cluster.to_str() + mock_inventory_client.install_cluster.assert_called_once_with( + cluster_id + ) @pytest.mark.asyncio async def test_list_versions_success( @@ -617,11 +671,15 @@ async def test_list_versions_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.list_versions() - - expected_result = json.dumps(mock_versions) - assert result == expected_result - mock_inventory_client.get_openshift_versions.assert_called_once_with(True) + async with Client(server.mcp_server) as client: + resp = await client.call_tool("list_versions", {}) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = json.dumps(mock_versions) + assert resp.content[0].text == expected_result + mock_inventory_client.get_openshift_versions.assert_called_once_with( + True + ) @pytest.mark.asyncio async def test_list_operator_bundles_success( @@ -639,11 +697,13 @@ async def test_list_operator_bundles_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.list_operator_bundles() - - expected_result = json.dumps(mock_bundles) - assert result == expected_result - mock_inventory_client.get_operator_bundles.assert_called_once() + async with Client(server.mcp_server) as client: + resp = await client.call_tool("list_operator_bundles", {}) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = json.dumps(mock_bundles) + assert resp.content[0].text == expected_result + mock_inventory_client.get_operator_bundles.assert_called_once() @pytest.mark.asyncio async def test_add_operator_bundle_to_cluster_success( @@ -661,14 +721,17 @@ async def test_add_operator_bundle_to_cluster_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.add_operator_bundle_to_cluster( - cluster_id, bundle_name - ) - - assert result == cluster.to_str() - mock_inventory_client.add_operator_bundle_to_cluster.assert_called_once_with( - cluster_id, bundle_name - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "add_operator_bundle_to_cluster", + {"cluster_id": cluster_id, "bundle_name": bundle_name}, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == cluster.to_str() + mock_inventory_client.add_operator_bundle_to_cluster.assert_called_once_with( + cluster_id, bundle_name + ) @pytest.mark.asyncio async def test_set_host_role_success( @@ -687,12 +750,17 @@ async def test_set_host_role_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.set_host_role(host_id, infraenv_id, role) - - assert result == host.to_str() - mock_inventory_client.update_host.assert_called_once_with( - host_id, infraenv_id, host_role=role - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "set_host_role", + {"host_id": host_id, "infraenv_id": infraenv_id, "role": role}, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + assert resp.content[0].text == host.to_str() + mock_inventory_client.update_host.assert_called_once_with( + host_id, infraenv_id, host_role=role + ) @pytest.mark.asyncio async def test_cluster_credentials_download_url_success( @@ -712,15 +780,18 @@ async def test_cluster_credentials_download_url_success( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_credentials_download_url( - cluster_id, file_name - ) - - expected_result = "URL: https://example.com/presigned-url\nExpires at: 2023-12-31T23:59:59Z" - assert result == expected_result - mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( - cluster_id, file_name - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_credentials_download_url", + {"cluster_id": cluster_id, "file_name": file_name}, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = "URL: https://example.com/presigned-url\nExpires at: 2023-12-31T23:59:59Z" + assert resp.content[0].text == expected_result + mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( + cluster_id, file_name + ) @pytest.mark.asyncio async def test_cluster_credentials_download_url_no_expiration( @@ -740,15 +811,18 @@ async def test_cluster_credentials_download_url_no_expiration( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_credentials_download_url( - cluster_id, file_name - ) - - expected_result = "URL: https://example.com/presigned-url" - assert result == expected_result - mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( - cluster_id, file_name - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_credentials_download_url", + {"cluster_id": cluster_id, "file_name": file_name}, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + expected_result = "URL: https://example.com/presigned-url" + assert resp.content[0].text == expected_result + mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( + cluster_id, file_name + ) @pytest.mark.asyncio async def test_cluster_credentials_download_url_zero_expiration( @@ -770,13 +844,16 @@ async def test_cluster_credentials_download_url_zero_expiration( with patch.object( server, "InventoryClient", return_value=mock_inventory_client ): - result = await server.cluster_credentials_download_url( - cluster_id, file_name - ) - - # Should not include expiration time since it's a zero/default value - expected_result = "URL: https://example.com/presigned-url" - assert result == expected_result - mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( - cluster_id, file_name - ) + async with Client(server.mcp_server) as client: + resp = await client.call_tool( + "cluster_credentials_download_url", + {"cluster_id": cluster_id, "file_name": file_name}, + ) + assert resp.content is not None and len(resp.content) > 0 + assert isinstance(resp.content[0], TextContent) + # Should not include expiration time since it's a zero/default value + expected_result = "URL: https://example.com/presigned-url" + assert resp.content[0].text == expected_result + mock_inventory_client.get_presigned_for_cluster_credentials.assert_called_once_with( + cluster_id, file_name + ) diff --git a/uv.lock b/uv.lock index 17bb90a..4f4e3de 100644 --- a/uv.lock +++ b/uv.lock @@ -47,6 +47,7 @@ dependencies = [ { name = "assisted-service-client" }, { name = "fastmcp" }, { name = "netaddr" }, + { name = "prometheus-client" }, { name = "requests" }, { name = "retry" }, { name = "types-requests" }, @@ -71,8 +72,9 @@ test = [ [package.metadata] requires-dist = [ { name = "assisted-service-client", specifier = ">=2.41.0.post3" }, - { name = "fastmcp", specifier = ">=2.8.0" }, + { name = "fastmcp", specifier = ">=2.10.6" }, { name = "netaddr", specifier = ">=1.3.0" }, + { name = "prometheus-client", specifier = ">=0.22.1" }, { name = "requests", specifier = ">=2.32.3" }, { name = "retry", specifier = ">=0.9.2" }, { name = "types-requests", specifier = ">=2.32.4.20250611" }, @@ -368,7 +370,7 @@ wheels = [ [[package]] name = "fastmcp" -version = "2.10.5" +version = "2.10.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "authlib" }, @@ -382,9 +384,9 @@ dependencies = [ { name = "python-dotenv" }, { name = "rich" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/08/7e9c8dc9c2712ccc6393383ef6d7999b84f658ee37cabc42f853e72f86e1/fastmcp-2.10.5.tar.gz", hash = "sha256:f829e0b11c4d136db1d81e20e8acb19cf5108f64059482d1853f3c940326cf04", size = 1618410, upload-time = "2025-07-11T22:23:32.968Z" } +sdist = { url = "https://files.pythonhosted.org/packages/00/a0/eceb88277ef9e3a442e099377a9b9c29fb2fa724e234486e03a44ca1c677/fastmcp-2.10.6.tar.gz", hash = "sha256:5a7b3301f9f1b64610430caef743ac70175c4b812e1949f037e4db65b0a42c5a", size = 1640538, upload-time = "2025-07-19T20:02:12.543Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/74/453a1e6d7673b831a04ac0167d34a3c21cf2a17d55b4d242f262474fff1f/fastmcp-2.10.5-py3-none-any.whl", hash = "sha256:ab218f6a66b61f6f83c413d37aa18f5c30882c44c8925f39ecd02dd855826540", size = 201275, upload-time = "2025-07-11T22:23:31.314Z" }, + { url = "https://files.pythonhosted.org/packages/dc/05/4958cccbe862958d862b6a15f2d10d2f5ec3c411268dcb131a433e5e7a0d/fastmcp-2.10.6-py3-none-any.whl", hash = "sha256:9782416a8848cc0f4cfcc578e5c17834da620bef8ecf4d0daabf5dd1272411a2", size = 202613, upload-time = "2025-07-19T20:02:11.47Z" }, ] [[package]] @@ -634,6 +636,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prometheus-client" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/cf/40dde0a2be27cc1eb41e333d1a674a74ce8b8b0457269cc640fd42b07cf7/prometheus_client-0.22.1.tar.gz", hash = "sha256:190f1331e783cf21eb60bca559354e0a4d4378facecf78f5428c39b675d20d28", size = 69746, upload-time = "2025-06-02T14:29:01.152Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/ae/ec06af4fe3ee72d16973474f122541746196aaa16cea6f66d18b963c6177/prometheus_client-0.22.1-py3-none-any.whl", hash = "sha256:cca895342e308174341b2cbf99a56bef291fbc0ef7b9e5412a0f26d653ba7094", size = 58694, upload-time = "2025-06-02T14:29:00.068Z" }, +] + [[package]] name = "py" version = "1.11.0"