diff --git a/server.py b/server.py index c3d3f67..7fc8d30 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,6 @@ from assisted_service_client import models from mcp.server.fastmcp import FastMCP - from service_client import InventoryClient, metrics, track_tool_usage, initiate_metrics from service_client.logger import log @@ -89,7 +88,7 @@ def get_offline_token() -> str: raise RuntimeError("No offline token found in environment or request headers") -def get_access_token() -> str: +async def get_access_token() -> str: """ Retrieve the access token. @@ -125,7 +124,7 @@ def get_access_token() -> str: "SSO_URL", "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token", ) - response = requests.post(sso_url, data=params, timeout=30) + response = await asyncio.to_thread(requests.post, sso_url, data=params, timeout=30) response.raise_for_status() log.debug("Successfully generated new access token") return response.json()["access_token"] @@ -152,7 +151,7 @@ async def cluster_info(cluster_id: str) -> str: - Host information and roles """ log.info("Retrieving cluster information for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_cluster(cluster_id=cluster_id) log.info("Successfully retrieved cluster information for %s", cluster_id) return result.to_str() @@ -177,7 +176,7 @@ async def list_clusters() -> str: - status (str): Current cluster status (e.g., 'ready', 'installing', 'error') """ log.info("Retrieving list of all clusters") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) clusters = await client.list_clusters() resp = [ { @@ -210,7 +209,7 @@ async def cluster_events(cluster_id: str) -> str: event types, and descriptive messages about cluster activities. """ log.info("Retrieving events for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_events(cluster_id=cluster_id) log.info("Successfully retrieved events for cluster %s", cluster_id) return result @@ -234,7 +233,7 @@ async def host_events(cluster_id: str, host_id: str) -> str: hardware validation results, installation steps, and error messages. """ log.info("Retrieving events for host %s in cluster %s", host_id, cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_events(cluster_id=cluster_id, host_id=host_id) log.info( "Successfully retrieved events for host %s in cluster %s", host_id, cluster_id @@ -260,7 +259,7 @@ async def cluster_iso_download_url(cluster_id: str) -> str: }] """ log.info("Retrieving InfraEnv ISO URLs for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) infra_envs = await client.list_infra_envs(cluster_id) if not infra_envs: @@ -354,7 +353,7 @@ async def create_cluster( # pylint: disable=too-many-arguments,too-many-positio cpu_architecture, ssh_public_key is not None, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) # Prepare cluster parameters cluster_params = { @@ -417,7 +416,7 @@ async def set_cluster_vips(cluster_id: str, api_vip: str, ingress_vip: str) -> s api_vip, ingress_vip, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.update_cluster( cluster_id, api_vip=api_vip, ingress_vip=ingress_vip ) @@ -449,7 +448,7 @@ async def install_cluster(cluster_id: str) -> str: - All cluster validations pass """ log.info("Initiating installation for cluster_id: %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.install_cluster(cluster_id) log.info("Successfully triggered installation for cluster %s", cluster_id) return result.to_str() @@ -470,7 +469,7 @@ async def list_versions() -> str: including version numbers, release dates, and support status. """ log.info("Retrieving available OpenShift versions") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_openshift_versions(True) log.info("Successfully retrieved OpenShift versions") return json.dumps(result) @@ -490,7 +489,7 @@ async def list_operator_bundles() -> str: including bundle names, descriptions, and operator details. """ log.info("Retrieving available operator bundles") - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_operator_bundles() log.info("Successfully retrieved %s operator bundles", len(result)) return json.dumps(result) @@ -516,7 +515,7 @@ async def add_operator_bundle_to_cluster(cluster_id: str, bundle_name: str) -> s showing the newly added operator bundle. """ log.info("Adding operator bundle '%s' to cluster %s", bundle_name, cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.add_operator_bundle_to_cluster(cluster_id, bundle_name) log.info( "Successfully added operator bundle '%s' to cluster %s", bundle_name, cluster_id @@ -558,7 +557,7 @@ async def cluster_credentials_download_url(cluster_id: str, file_name: str) -> s cluster_id, file_name, ) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.get_presigned_for_cluster_credentials(cluster_id, file_name) log.info( "Successfully retrieved presigned URL for cluster %s credentials file %s - %s", @@ -592,7 +591,7 @@ async def set_host_role(host_id: str, infraenv_id: str, role: str) -> str: showing the newly assigned role. """ log.info("Setting role '%s' for host %s in InfraEnv %s", role, host_id, infraenv_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) result = await client.update_host(host_id, infraenv_id, host_role=role) log.info("Successfully set role '%s' for host %s", role, host_id) return result.to_str() @@ -618,7 +617,7 @@ async def set_cluster_ssh_key(cluster_id: str, ssh_public_key: str) -> str: str: A formatted string containing the updated cluster configuration. """ log.info("Setting SSH public key for cluster %s", cluster_id) - client = InventoryClient(get_access_token()) + client = InventoryClient(await get_access_token()) # Update the cluster with the new SSH public key result = await client.update_cluster(cluster_id, ssh_public_key=ssh_public_key) diff --git a/service_client/assisted_service_api.py b/service_client/assisted_service_api.py index 0392a6a..495c3b7 100644 --- a/service_client/assisted_service_api.py +++ b/service_client/assisted_service_api.py @@ -40,14 +40,11 @@ def __init__(self, access_token: str): ) self.client_debug = os.environ.get("CLIENT_DEBUG", "False").lower() == "true" - @property - def pull_secret(self) -> str: + async def get_pull_secret(self) -> str: """Lazy-load the pull secret when first accessed.""" - if self._pull_secret is None: - self._pull_secret = self._get_pull_secret() - return self._pull_secret + if self._pull_secret is not None: + return self._pull_secret - def _get_pull_secret(self) -> str: url = os.environ.get( "PULL_SECRET_URL", "https://api.openshift.com/api/accounts_mgmt/v1/access_token", @@ -56,10 +53,13 @@ def _get_pull_secret(self) -> str: try: log.info("Fetching pull secret from %s", url) - response = requests.post(url, headers=headers, timeout=30) + response = await asyncio.to_thread( + requests.post, url, headers=headers, timeout=30 + ) response.raise_for_status() log.info("Successfully fetched pull secret") - return response.text + self._pull_secret = response.text + return self._pull_secret except RequestException as e: log.error("Error while fetching pull secret from %s: %s", url, str(e)) raise @@ -242,7 +242,7 @@ async def create_cluster( params = models.ClusterCreateParams( name=name, openshift_version=version, - pull_secret=self.pull_secret, + pull_secret=await self.get_pull_secret(), **cluster_params, ) log.info( @@ -272,7 +272,7 @@ async def create_infra_env( models.InfraEnv: The created infrastructure environment object. """ infra_env = models.InfraEnvCreateParams( - name=name, pull_secret=self.pull_secret, **infra_env_params + name=name, pull_secret=await self.get_pull_secret(), **infra_env_params ) log.info("Creating infrastructure environment '%s'", name) result = await asyncio.to_thread( diff --git a/service_client/exceptions.py b/service_client/exceptions.py index 1875d91..0b74586 100644 --- a/service_client/exceptions.py +++ b/service_client/exceptions.py @@ -46,7 +46,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: ) raise AssistedServiceAPIError(f"API error: Status {e.status}") from e except Exception as e: - log.error("Unexpected error during %s: %s", operation_name, str(e)) + log.exception("Unexpected error during %s: %s", operation_name, str(e)) raise AssistedServiceAPIError("An internal error occurred") from e return wrapper diff --git a/service_client/logger.py b/service_client/logger.py index 39f8b5d..22cf2ea 100644 --- a/service_client/logger.py +++ b/service_client/logger.py @@ -11,6 +11,9 @@ import os import re import sys +import atexit +import queue +from logging.handlers import QueueHandler, QueueListener class SensitiveFormatter(logging.Formatter): @@ -87,33 +90,18 @@ def get_logging_level() -> int: logging.getLogger("asyncio").setLevel(logging.ERROR) -def add_log_file_handler(logger: logging.Logger, filename: str) -> logging.FileHandler: - """ - Add a file handler to the logger with sensitive information filtering. - - Args: - logger: The logger instance to add the handler to. - filename: The path to the log file. - - Returns: - logging.FileHandler: The created file handler. - """ +def _create_file_handler(filename: str) -> logging.FileHandler: + """Create a file handler with sensitive formatting.""" fh = logging.FileHandler(filename) fh.setFormatter(SensitiveFormatter()) - logger.addHandler(fh) return fh -def add_stream_handler(logger: logging.Logger) -> None: - """ - Add a stream handler to the logger with sensitive information filtering. - - Args: - logger: The logger instance to add the handler to. - """ +def _create_stream_handler() -> logging.StreamHandler: + """Create a stream handler to stderr with sensitive formatting.""" ch = logging.StreamHandler(sys.stderr) ch.setFormatter(SensitiveFormatter()) - logger.addHandler(ch) + return ch logger_name = os.environ.get("LOGGER_NAME", "") @@ -129,9 +117,36 @@ def add_stream_handler(logger: logging.Logger) -> None: # Check if we should log to file (default: True, set to False in containers) log_to_file = os.environ.get("LOG_TO_FILE", "true").lower() == "true" +# Configure non-blocking logging via a Queue +_log_queue: queue.Queue[logging.LogRecord] = queue.Queue() + +_handlers: list[logging.Handler] = [] if log_to_file: - add_log_file_handler(log, "assisted-service-mcp.log") - add_log_file_handler(urllib3_logger, "assisted-service-mcp.log") + _handlers.append(_create_file_handler("assisted-service-mcp.log")) +_handlers.append(_create_stream_handler()) + +# Start a single listener that will process records on a background thread +_queue_listener = QueueListener(_log_queue, *_handlers, respect_handler_level=True) +_queue_listener.start() + +# Attach QueueHandler to our loggers +_queue_handler = QueueHandler(_log_queue) + +# Avoid duplicate propagation if root logger is used +log.handlers = [_queue_handler] if _queue_handler not in log.handlers else [] +log.propagate = False + +urllib3_logger.handlers = ( + [_queue_handler] if _queue_handler not in urllib3_logger.handlers else [] +) +urllib3_logger.propagate = False + + +def _stop_queue_listener() -> None: + try: + _queue_listener.stop() + except Exception: # noqa: BLE001 - best effort stop at exit + pass + -add_stream_handler(log) -add_stream_handler(urllib3_logger) +atexit.register(_stop_queue_listener) diff --git a/tests/test_assisted_service_api.py b/tests/test_assisted_service_api.py index 7a5966e..79290e8 100644 --- a/tests/test_assisted_service_api.py +++ b/tests/test_assisted_service_api.py @@ -32,9 +32,10 @@ def mock_access_token(self) -> str: @pytest.fixture def client(self, mock_access_token: str) -> InventoryClient: """Create a test client instance.""" - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response return InventoryClient(mock_access_token) @pytest.fixture @@ -42,14 +43,16 @@ def mock_api_client(self) -> Mock: """Mock API client for testing.""" return Mock() - def test_init_with_access_token(self, mock_access_token: str) -> None: + async def test_init_with_access_token(self, mock_access_token: str) -> None: """Test client initialization with access token.""" - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "test-pull-secret" + mock_post.return_value = mock_response + client = InventoryClient(mock_access_token) assert client.access_token == mock_access_token - assert client.pull_secret == "test-pull-secret" + assert await client.get_pull_secret() == "test-pull-secret" assert ( client.inventory_url == "https://api.openshift.com/api/assisted-install/v2" @@ -62,15 +65,17 @@ def test_init_with_environment_variables(self, mock_access_token: str) -> None: with patch.dict( os.environ, {"INVENTORY_URL": test_url, "CLIENT_DEBUG": "true"} ): - with patch.object( - InventoryClient, "_get_pull_secret", return_value="test-pull-secret" - ): + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.text = "test-pull-secret" + mock_post.return_value = mock_response + client = InventoryClient(mock_access_token) assert client.inventory_url == test_url assert client.client_debug is True @patch("requests.post") - def test_get_pull_secret_success( + async def test_get_pull_secret_success( self, mock_post: Mock, mock_access_token: str ) -> None: """Test successful pull secret retrieval.""" @@ -81,7 +86,7 @@ def test_get_pull_secret_success( client = InventoryClient(mock_access_token) # Access the pull_secret property to trigger lazy loading - pull_secret = client.pull_secret + pull_secret = await client.get_pull_secret() mock_post.assert_called_once_with( "https://api.openshift.com/api/accounts_mgmt/v1/access_token", @@ -91,7 +96,7 @@ def test_get_pull_secret_success( assert pull_secret == "pull-secret-content" @patch("requests.post") - def test_get_pull_secret_failure( + async def test_get_pull_secret_failure( self, mock_post: Mock, mock_access_token: str ) -> None: """Test pull secret retrieval failure.""" @@ -101,10 +106,10 @@ def test_get_pull_secret_failure( # Exception should be raised when accessing pull_secret property with pytest.raises(RequestException): - _ = client.pull_secret + _ = await client.get_pull_secret() @patch("requests.post") - def test_get_pull_secret_with_custom_url( + async def test_get_pull_secret_with_custom_url( self, mock_post: Mock, mock_access_token: str ) -> None: """Test pull secret retrieval with custom URL.""" @@ -117,7 +122,7 @@ def test_get_pull_secret_with_custom_url( client = InventoryClient(mock_access_token) # Access the pull_secret property to trigger lazy loading - _ = client.pull_secret + _ = await client.get_pull_secret() mock_post.assert_called_once_with( custom_url, @@ -368,8 +373,13 @@ async def test_create_cluster_success(self, client: InventoryClient) -> None: with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.v2_register_cluster.return_value = cluster mock_installer_api.return_value = mock_api @@ -401,8 +411,13 @@ async def test_create_cluster_single_node(self, client: InventoryClient) -> None with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.v2_register_cluster.return_value = cluster mock_installer_api.return_value = mock_api @@ -428,8 +443,13 @@ async def test_create_infra_env_success(self, client: InventoryClient) -> None: with ( patch.object(client, "_installer_api") as mock_installer_api, - patch.object(client, "_get_pull_secret", return_value="mock-pull-secret"), + patch("requests.post") as mock_post, ): + # Mock pull secret HTTP response + mock_response = Mock() + mock_response.text = "mock-pull-secret" + mock_post.return_value = mock_response + mock_api = Mock() mock_api.register_infra_env.return_value = infra_env mock_installer_api.return_value = mock_api @@ -442,7 +462,7 @@ async def test_create_infra_env_success(self, client: InventoryClient) -> None: _args, kwargs = mock_api.register_infra_env.call_args infra_env_params = kwargs["infraenv_create_params"] assert infra_env_params.name == name - assert infra_env_params.pull_secret == client.pull_secret + assert infra_env_params.pull_secret == "mock-pull-secret" @pytest.mark.asyncio async def test_update_infra_env_success(self, client: InventoryClient) -> None: diff --git a/tests/test_server.py b/tests/test_server.py index 5056631..e05b97f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -100,7 +100,7 @@ def test_get_offline_token_no_request(self) -> None: server.get_offline_token() assert "No offline token found" in str(exc_info.value) - def test_get_access_token_from_authorization_header( + async def test_get_access_token_from_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test retrieving access token from Authorization header.""" @@ -108,11 +108,11 @@ def test_get_access_token_from_authorization_header( test_token = "test-access-token" mock_request.headers.get.return_value = f"Bearer {test_token}" - result = server.get_access_token() + result = await server.get_access_token() assert result == test_token mock_request.headers.get.assert_called_once_with("Authorization") - def test_get_access_token_invalid_authorization_header( + async def test_get_access_token_invalid_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token retrieval with invalid Authorization header.""" @@ -125,10 +125,10 @@ def test_get_access_token_invalid_authorization_header( mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token" - def test_get_access_token_no_authorization_header( + async def test_get_access_token_no_authorization_header( self, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token retrieval without Authorization header.""" @@ -141,11 +141,11 @@ def test_get_access_token_no_authorization_header( mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token" @patch("requests.post") - def test_get_access_token_generate_from_offline_token( + async def test_get_access_token_generate_from_offline_token( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test generating access token from offline token.""" @@ -160,7 +160,7 @@ def test_get_access_token_generate_from_offline_token( mock_post.return_value = mock_response with patch.object(server, "get_offline_token", return_value=offline_token): - result = server.get_access_token() + result = await server.get_access_token() assert result == access_token mock_post.assert_called_once_with( @@ -174,7 +174,7 @@ def test_get_access_token_generate_from_offline_token( ) @patch("requests.post") - def test_get_access_token_custom_sso_url( + async def test_get_access_token_custom_sso_url( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token generation with custom SSO URL.""" @@ -191,7 +191,7 @@ def test_get_access_token_custom_sso_url( with patch.dict(os.environ, {"SSO_URL": custom_sso_url}): with patch.object(server, "get_offline_token", return_value=offline_token): - result = server.get_access_token() + result = await server.get_access_token() assert result == access_token mock_post.assert_called_once_with( @@ -205,7 +205,7 @@ def test_get_access_token_custom_sso_url( ) @patch("requests.post") - def test_get_access_token_request_failure( + async def test_get_access_token_request_failure( self, mock_post: Mock, mock_mcp_get_context: Tuple[Mock, Mock] ) -> None: """Test access token generation request failure.""" @@ -216,9 +216,9 @@ def test_get_access_token_request_failure( with patch.object(server, "get_offline_token", return_value="offline-token"): with pytest.raises(RequestException): - server.get_access_token() + await server.get_access_token() - def test_get_access_token_no_request_context(self) -> None: + async def test_get_access_token_no_request_context(self) -> None: """Test access token retrieval when no request context is available.""" mock_context = Mock() mock_context.request_context.request = None @@ -232,7 +232,7 @@ def test_get_access_token_no_request_context(self) -> None: mock_response.json.return_value = {"access_token": "new-token"} mock_post.return_value = mock_response - result = server.get_access_token() + result = await server.get_access_token() assert result == "new-token"