From c6cff5931d52bee76e4f3100bcd41f6e5fb05425 Mon Sep 17 00:00:00 2001 From: Nick Carboni Date: Wed, 13 Aug 2025 16:58:23 -0400 Subject: [PATCH] Fix logging and requests sync calls The blockbuster tool revealed that both logging and the use of the requests library for getting the token and pull secret were blocking the async nature of the mcp server. This commit handles the logging by splitting the log into a producer and consumer through a queue and using async.to_thread for the requests calls. --- server.py | 33 +++++++------- service_client/assisted_service_api.py | 20 ++++---- service_client/exceptions.py | 2 +- service_client/logger.py | 63 ++++++++++++++++---------- tests/test_assisted_service_api.py | 62 ++++++++++++++++--------- tests/test_server.py | 28 ++++++------ 6 files changed, 121 insertions(+), 87 deletions(-) diff --git a/server.py b/server.py index c3d3f67..7fc8d30 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,6 @@ from assisted_service_client import models from mcp.server.fastmcp import FastMCP - from service_client import InventoryClient, metrics, track_tool_usage, initiate_metrics from service_client.logger import log @@ -89,7 +88,7 @@ def get_offline_token() -> str: raise RuntimeError("No offline token found in environment or request headers") -def get_access_token() -> str: +async def get_access_token() -> str: """ Retrieve the access token. @@ -125,7 +124,7 @@ def get_access_token() -> str: "SSO_URL", "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token", ) - response = requests.post(sso_url, data=params, timeout=30) + response = await asyncio.to_thread(requests.post, sso_url, data=params, timeout=30) response.raise_for_status() log.debug("Successfully generated new access token") return response.json()["access_token"] @@ -152,7 +151,7 @@ async def cluster_info(cluster_id: str) -> str: - Host information and roles """ log.info("Retrieving cluster information for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_cluster(cluster_id=cluster_id) log.info("Successfully retrieved cluster information for %s", cluster_id) return result.to_str() @@ -177,7 +176,7 @@ async def list_clusters() -> str: - status (str): Current cluster status (e.g., 'ready', 'installing', 'error') """ log.info("Retrieving list of all clusters") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) clusters = await client.list_clusters() resp = [ { @@ -210,7 +209,7 @@ async def cluster_events(cluster_id: str) -> str: event types, and descriptive messages about cluster activities. """ log.info("Retrieving events for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_events(cluster_id=cluster_id) log.info("Successfully retrieved events for cluster %s", cluster_id) return result @@ -234,7 +233,7 @@ async def host_events(cluster_id: str, host_id: str) -> str: hardware validation results, installation steps, and error messages. """ log.info("Retrieving events for host %s in cluster %s", host_id, cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_events(cluster_id=cluster_id, host_id=host_id) log.info( "Successfully retrieved events for host %s in cluster %s", host_id, cluster_id @@ -260,7 +259,7 @@ async def cluster_iso_download_url(cluster_id: str) -> str: }] """ log.info("Retrieving InfraEnv ISO URLs for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) infra_envs = await client.list_infra_envs(cluster_id) if not infra_envs: @@ -354,7 +353,7 @@ async def create_cluster( # pylint: disable=too-many-arguments,too-many-positio cpu_architecture, ssh_public_key is not None, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) # Prepare cluster parameters cluster_params = { @@ -417,7 +416,7 @@ async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> s api_vip, ingress_vip, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.update_cluster( cluster_id, api_vip=api_vip, ingress_vip=ingress_vip ) @@ -449,7 +448,7 @@ async def install_cluster(cluster_id: str) -> str: - All cluster validations pass """ log.info("Initiating installation for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.install_cluster(cluster_id) log.info("Successfully triggered installation for cluster %s", cluster_id) return result.to_str() @@ -470,7 +469,7 @@ async def list_versions() -> str: including version numbers, release dates, and support status. """ log.info("Retrieving available OpenShift versions") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_openshift_versions(True) log.info("Successfully retrieved OpenShift versions") return json.dumps(result) @@ -490,7 +489,7 @@ async def list_operator_bundles() -> str: including bundle names, descriptions, and operator details. """ log.info("Retrieving available operator bundles") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_operator_bundles() log.info("Successfully retrieved %s operator bundles", len(result)) return json.dumps(result) @@ -516,7 +515,7 @@ async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> s showing the newly added operator bundle. """ log.info("Adding operator bundle '%s' to cluster %s", bundle_name, cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.add_operator_bundle_to_cluster(cluster_id, bundle_name) log.info( "Successfully added operator bundle '%s' to cluster %s", bundle_name, cluster_id @@ -558,7 +557,7 @@ async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> s cluster_id, file_name, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_presigned_for_cluster_credentials(cluster_id, file_name) log.info( "Successfully retrieved presigned URL for cluster %s credentials file %s - %s", @@ -592,7 +591,7 @@ async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: showing the newly assigned role. """ log.info("Setting role '%s' for host %s in InfraEnv %s", role, host_id, infraenv_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.update_host(host_id, infraenv_id, host_role=role) log.info("Successfully set role '%s' for host %s", role, host_id) return result.to_str() @@ -618,7 +617,7 @@ async def set_cluster_ssh_key(cluster_id: str, ssh_public_key: str) -> str: str: A formatted string containing the updated cluster configuration. """ log.info("Setting SSH public key for cluster %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) # Update the cluster with the new SSH public key result = await client.update_cluster(cluster_id, ssh_public_key=ssh_public_key) diff --git a/service_client/assisted_service_api.py b/service_client/assisted_service_api.py index 0392a6a..495c3b7 100644 --- a/service_client/assisted_service_api.py +++ b/service_client/assisted_service_api.py @@ -40,14 +40,11 @@ def __init__(self, access_token: str): ) self.client_debug = os.environ.get("CLIENT_DEBUG", "False").lower() == "true" - @property - def pull_secret(self) -> str: + async def get_pull_secret(self) -> str: """Lazy-load the pull secret when first accessed.""" - if self._pull_secret is None: - self._pull_secret = self._get_pull_secret() - return self._pull_secret + if self._pull_secret is not None: + return self._pull_secret - def _get_pull_secret(self) -> str: url = os.environ.get( "PULL_SECRET_URL", "https://api.openshift.com/api/accounts_mgmt/v1/access_token", @@ -56,10 +53,13 @@ def _get_pull_secret(self) -> str: try: log.info("Fetching pull secret from %s", url) - response = requests.post(url, headers=headers, timeout=30) + response = await asyncio.to_thread( + requests.post, url, headers=headers, timeout=30 + ) response.raise_for_status() log.info("Successfully fetched pull secret") - return response.text + self._pull_secret = response.text + return self._pull_secret except RequestException as e: log.error("Error while fetching pull secret from %s: %s", url, str(e)) raise @@ -242,7 +242,7 @@ async def create_cluster( params = models.ClusterCreateParams( name=name, openshift_version=version, - pull_secret=self.pull_secret, + pull_secret=await self.get_pull_secret(), **cluster_params, ) log.info( @@ -272,7 +272,7 @@ async def create_infra_env( models.InfraEnv: The created infrastructure environment object. """ infra_env = models.InfraEnvCreateParams( - name=name, pull_secret=self.pull_secret, **infra_env_params + name=name, pull_secret=await self.get_pull_secret(), **infra_env_params ) log.info("Creating infrastructure environment '%s'", name) result = await asyncio.to_thread( diff --git a/service_client/exceptions.py b/service_client/exceptions.py index 1875d91..0b74586 100644 --- a/service_client/exceptions.py +++ b/service_client/exceptions.py @@ -46,7 +46,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: ) raise AssistedServiceAPIError(f"API error: Status {e.status}") from e except Exception as e: - log.error("Unexpected error during %s: %s", operation_name, str(e)) + log.exception("Unexpected error during %s: %s", operation_name, str(e)) raise AssistedServiceAPIError("An internal error occurred") from e return wrapper diff --git a/service_client/logger.py b/service_client/logger.py index 39f8b5d..22cf2ea 100644 --- a/service_client/logger.py +++ b/service_client/logger.py @@ -11,6 +11,9 @@ import os import re import sys +import atexit +import queue +from logging.handlers import QueueHandler, QueueListener class SensitiveFormatter(logging.Formatter): @@ -87,33 +90,18 @@ def get_logging_level() -> int: logging.getLogger("asyncio").setLevel(logging.ERROR) -def add_log_file_handler(logger: logging.Logger, filename: str) -> logging.FileHandler: - """ - Add a file handler to the logger with sensitive information filtering. - - Args: - logger: The logger instance to add the handler to. - filename: The path to the log file. - - Returns: - logging.FileHandler: The created file handler. - """ +def _create_file_handler(filename: str) -> logging.FileHandler: + """Create a file handler with sensitive formatting.""" fh = logging.FileHandler(filename) fh.setFormatter(SensitiveFormatter()) - logger.addHandler(fh) return fh -def add_stream_handler(logger: logging.Logger) -> None: - """ - Add a stream handler to the logger with sensitive information filtering. - - Args: - logger: The logger instance to add the handler to. - """ +def _create_stream_handler() -> logging.StreamHandler: + """Create a stream handler to stderr with sensitive formatting.""" ch = logging.StreamHandler(sys.stderr) ch.setFormatter(SensitiveFormatter()) - logger.addHandler(ch) + return ch logger_name = os.environ.get("LOGGER_NAME", "") @@ -129,9 +117,36 @@ def add_stream_handler(logger: logging.Logger) -> None: # Check if we should log to file (default: True, set to False in containers) log_to_file = os.environ.get("LOG_TO_FILE", "true").lower() == "true" +# Configure non-blocking logging via a Queue +_log_queue: queue.Queue[logging.LogRecord] = queue.Queue() + +_handlers: list[logging.Handler] = [] if log_to_file: - add_log_file_handler(log, "assisted-service-mcp.log") - add_log_file_handler(urllib3_logger, "assisted-service-mcp.log") + _handlers.append(_create_file_handler("assisted-service-mcp.log")) +_handlers.append(_create_stream_handler()) + +# Start a single listener that will process records on a background thread +_queue_listener = QueueListener(_log_queue, *_handlers, respect_handler_level=True) +_queue_listener.start() + +# Attach QueueHandler to our loggers +_queue_handler = QueueHandler(_log_queue) + +# Avoid duplicate propagation if root logger is used +log.handlers = [_queue_handler] if _queue_handler not in log.handlers else [] +log.propagate = False + +urllib3_logger.handlers = ( + [_queue_handler] if _queue_handler not in urllib3_logger.handlers else [] +) +urllib3_logger.propagate = False + + +def _stop_queue_listener() -> None: + try: + _queue_listener.stop() + except Exception: # noqa: BLE001 - best effort stop at exit + pass + -add_stream_handler(log) -add_stream_handler(urllib3_logger) +atexit.register(_stop_queue_listener) diff --git a/tests/test_assisted_service_api.py b/tests/test_assisted_service_api.py index 7a5966e..79290e8 100644 --- a/tests/test_assisted_service_api.py +++ b/tests/test_assisted_service_api.py @@ -32,9 +32,10 @@ def mock_access_token(self) -> str: @pytest.fixture def client(self, mock_access_token: str) -> InventoryClient: """Create a test client instance.""" - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response return InventoryClient(mock_access_token) @pytest.fixture @@ -42,14 +43,16 @@ def mock_api_client(self) -> Mock: """Mock API client for testing.""" return Mock() - def test_init_with_access_token(self, mock_access_token: str) -> None: + async def test_init_with_access_token(self, mock_access_token: str) -> None: """Test client initialization with access token.""" - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "test-pull-secret" + mock_post.return_value = mock_response + client = InventoryClient(mock_access_token) assert client.access_token == mock_access_token - assert client.pull_secret == "test-pull-secret" + assert await client.get_pull_secret() == "test-pull-secret" assert ( client.inventory_url == "https://api.openshift.com/api/assisted-install/v2" @@ -62,15 +65,17 @@ def test_init_with_environment_variables(self, mock_access_token: str) -> None: with patch.dict( os.environ, {"INVENTORY_URL": test_url, "CLIENT_DEBUG": "true"} ): - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "test-pull-secret" + mock_post.return_value = mock_response + client = InventoryClient(mock_access_token) assert client.inventory_url == test_url assert client.client_debug is True @patch("requests.post") - def test_get_pull_secret_success( + async def test_get_pull_secret_success( self, mock_post: Mock, mock_access_token: str ) -> None: """Test successful pull secret retrieval.""" @@ -81,7 +86,7 @@ def test_get_pull_secret_success( client = InventoryClient(mock_access_token) # Access the pull_secret property to trigger lazy loading - pull_secret = client.pull_secret + pull_secret = await client.get_pull_secret() mock_post.assert_called_once_with( "https://api.openshift.com/api/accounts_mgmt/v1/access_token", @@ -91,7 +96,7 @@ def test_get_pull_secret_success( assert pull_secret == "pull-secret-content" @patch("requests.post") - def test_get_pull_secret_failure( + async def test_get_pull_secret_failure( self, mock_post: Mock, mock_access_token: str ) -> None: """Test pull secret retrieval failure.""" @@ -101,10 +106,10 @@ def test_get_pull_secret_failure( # Exception should be raised when accessing pull_secret property with pytest.raises(RequestException): - _ = client.pull_secret + _ = await client.get_pull_secret() @patch("requests.post") - def test_get_pull_secret_with_custom_url( + async def test_get_pull_secret_with_custom_url( self, mock_post: Mock, mock_access_token: str ) -> None: """Test pull secret retrieval with custom URL.""" @@ -117,7 +122,7 @@ def test_get_pull_secret_with_custom_url( client = InventoryClient(mock_access_token) # Access the pull_secret property to trigger lazy loading - _ = client.pull_secret + _ = await client.get_pull_secret() mock_post.assert_called_once_with( custom_url, @@ -368,8 +373,13 @@ async def test_create_cluster_success(self, client: InventoryClient) -> None: with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.v2_register_cluster.return_value = cluster mock_installer_api.return_value = mock_api @@ -401,8 +411,13 @@ async def test_create_cluster_single_node(self, client: InventoryClient) -> None with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.v2_register_cluster.return_value = cluster mock_installer_api.return_value = mock_api @@ -428,8 +443,13 @@ async def test_create_infra_env_success(self, client: InventoryClient) -> None: with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.register_infra_env.return_value = infra_env mock_installer_api.return_value = mock_api @@ -442,7 +462,7 @@ async def test_create_infra_env_success(self, client: InventoryClient) -> None: _args, kwargs = mock_api.register_infra_env.call_args infra_env_params = kwargs["infraenv_create_params"] assert infra_env_params.name == name - assert infra_env_params.pull_secret == client.pull_secret + assert infra_env_params.pull_secret == "mock-pull-secret" @pytest.mark.asyncio async def test_update_infra_env_success(self, client: InventoryClient) -> None: diff --git a/tests/test_server.py b/tests/test_server.py index 5056631..e05b97f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -100,7 +100,7 @@ def test_get_offline_token_no_request(self) -> None: server.get_offline_token() assert "No offline token found" in str(exc_info.value) - def test_get_access_token_from_authorization_header( + async def test_get_access_token_from_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test retrieving access token from Authorization header.""" @@ -108,11 +108,11 @@ def test_get_access_token_from_authorization_header( test_token = "test-access-token" mock_request.headers.get.return_value = f"Bearer {test_token}" - result = server.get_access_token() + result = await server.get_access_token() assert result == test_token mock_request.headers.get.assert_called_once_with("Authorization") - def test_get_access_token_invalid_authorization_header( + async def test_get_access_token_invalid_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token retrieval with invalid Authorization header.""" @@ -125,10 +125,10 @@ def test_get_access_token_invalid_authorization_header( mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token" - def test_get_access_token_no_authorization_header( + async def test_get_access_token_no_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token retrieval without Authorization header.""" @@ -141,11 +141,11 @@ def test_get_access_token_no_authorization_header( mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token" @patch("requests.post") - def test_get_access_token_generate_from_offline_token( + async def test_get_access_token_generate_from_offline_token( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test generating access token from offline token.""" @@ -160,7 +160,7 @@ def test_get_access_token_generate_from_offline_token( mock_post.return_value = mock_response with patch.object(server, "get_offline_token", return_value=offline_token): - result = server.get_access_token() + result = await server.get_access_token() assert result == access_token mock_post.assert_called_once_with( @@ -174,7 +174,7 @@ def test_get_access_token_generate_from_offline_token( ) @patch("requests.post") - def test_get_access_token_custom_sso_url( + async def test_get_access_token_custom_sso_url( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token generation with custom SSO URL.""" @@ -191,7 +191,7 @@ def test_get_access_token_custom_sso_url( with patch.dict(os.environ, {"SSO_URL": custom_sso_url}): with patch.object(server, "get_offline_token", return_value=offline_token): - result = server.get_access_token() + result = await server.get_access_token() assert result == access_token mock_post.assert_called_once_with( @@ -205,7 +205,7 @@ def test_get_access_token_custom_sso_url( ) @patch("requests.post") - def test_get_access_token_request_failure( + async def test_get_access_token_request_failure( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token generation request failure.""" @@ -216,9 +216,9 @@ def test_get_access_token_request_failure( with patch.object(server, "get_offline_token", return_value="offline-token"): with pytest.raises(RequestException): - server.get_access_token() + await server.get_access_token() - def test_get_access_token_no_request_context(self) -> None: + async def test_get_access_token_no_request_context(self) -> None: """Test access token retrieval when no request context is available.""" mock_context = Mock() mock_context.request_context.request = None @@ -232,7 +232,7 @@ def test_get_access_token_no_request_context(self) -> None: mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token"