diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/categorizer.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/categorizer.py index d2f48418d..07f4427dd 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/categorizer.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/categorizer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import MessagesPlaceholder @@ -35,6 +37,13 @@ class CategorizerToolConfig(FunctionBaseConfig, name="categorizer"): llm_name: LLMRef +def _extract_markdown_heading_level(report: str) -> str: + """ Extract the markdown heading level from first line (report title).""" + m = re.search(r'^(#+)', report, re.MULTILINE) + pound_signs = m.group(1) if m else "#" + return pound_signs + + @register_function(config_type=CategorizerToolConfig) async def categorizer_tool(config: CategorizerToolConfig, builder: Builder): # Set up LLM and chain @@ -49,8 +58,8 @@ async def _arun(report: str) -> str: result = await categorization_chain.ainvoke({"msgs": [HumanMessage(content=report)]}) - # Extract the markdown heading level from first line of report (e.g. '#' or '##') - pound_signs = report.split('\n')[0].split(' ')[0] + # Extract the title's heading level and add an additional '#' for the section heading + pound_signs = _extract_markdown_heading_level(report) + "#" # Format the root cause category section: # - Add newlines before and after section diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/hardware_check_tool.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/hardware_check_tool.py index 0f0f5b9ea..7b9cb539e 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/hardware_check_tool.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/hardware_check_tool.py @@ -27,6 +27,15 @@ from .prompts import ToolReasoningLayerPrompts +class HardwareCheckToolConfig(FunctionBaseConfig, name="hardware_check"): + description: str = Field( + default=("This tool checks hardware health status using IPMI monitoring to detect power state, " + "hardware degradation, and anomalies that could explain alerts. Args: host_id: str"), + description="Description of the tool for the agent.") + llm_name: LLMRef + test_mode: bool = Field(default=True, description="Whether to run in test mode") + + def _get_ipmi_monitor_data(ip_address, username, password): """ Capture IPMI monitoring data using the ipmimonitoring command. @@ -58,15 +67,6 @@ def _get_ipmi_monitor_data(ip_address, username, password): return None -class HardwareCheckToolConfig(FunctionBaseConfig, name="hardware_check"): - description: str = Field( - default=("This tool checks hardware health status using IPMI monitoring to detect power state, " - "hardware degradation, and anomalies that could explain alerts. 
Args: host_id: str"), - description="Description of the tool for the agent.") - llm_name: LLMRef - test_mode: bool = Field(default=True, description="Whether to run in test mode") - - @register_function(config_type=HardwareCheckToolConfig) async def hardware_check_tool(config: HardwareCheckToolConfig, builder: Builder): diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/host_performance_check_tool.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/host_performance_check_tool.py index 79b762dc9..9e2015193 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/host_performance_check_tool.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/host_performance_check_tool.py @@ -92,19 +92,18 @@ async def _parse_stdout_lines(config, builder, stdout_lines): Returns: str: Structured data parsed from the output in string format. """ - # Join the list of lines into a single text block - input_data = "\n".join(stdout_lines) if stdout_lines else "" - - prompt = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format(input_data=input_data) - response = None try: - response = await utils.llm_ainvoke(config, builder, user_prompt=prompt) - structured_data = response + # Join the list of lines into a single text block + input_data = "\n".join(stdout_lines) if stdout_lines else "" + + prompt = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format(input_data=input_data) + + response = await utils.llm_ainvoke(config=config, builder=builder, user_prompt=prompt) except Exception as e: - structured_data = ('{{"error": "Failed to parse nvda_nim response", ' - '"exception": "{}", "raw_response": "{}"}}').format(str(e), response) - return structured_data + response = ('{{"error": "Failed to parse stdout from the playbook run.", ' + '"exception": "{}", "raw_response": "{}"}}').format(str(e), response) + return response @register_function(config_type=HostPerformanceCheckToolConfig) diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/maintenance_check.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/maintenance_check.py index a17f62944..10c3f0e6b 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/maintenance_check.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/maintenance_check.py @@ -83,6 +83,7 @@ def _load_maintenance_data(path: str) -> pd.DataFrame: required = {"host_id", "maintenance_start", "maintenance_end"} missing = required - set(df.columns) if missing: + missing = sorted(missing) utils.logger.error("Missing required columns: %s", ", ".join(missing)) raise ValueError(f"Missing required columns: {', '.join(missing)}") @@ -96,25 +97,27 @@ def _parse_alert_data(input_message: str) -> dict | None: """ Parse alert data from an input message containing JSON into a dictionary. - Note: This function assumes the input message contains exactly one JSON object (the alert) - with potential extra text before and/or after the JSON. + This function extracts and parses a JSON object from a text message that may contain + additional text before and/or after the JSON. It handles both double and single quoted + JSON strings and can parse nested JSON structures. Args: - input_message (str): Input message containing JSON alert data + input_message (str): Input message containing a JSON object, which may be surrounded + by additional text. The JSON object should contain alert details + like host_id and timestamp. 
Returns: - dict | None: The parsed alert data as a dictionary containing alert details, - or None if parsing fails - - Raises: - ValueError: If no JSON object is found in the input message - json.JSONDecodeError: If the JSON parsing fails + dict | None: The parsed alert data as a dictionary if successful parsing, + containing fields like host_id and timestamp. + Returns None if no valid JSON object is found or parsing fails. """ # Extract everything between first { and last } start = input_message.find("{") end = input_message.rfind("}") + 1 if start == -1 or end == 0: - raise ValueError("No JSON object found in input message") + utils.logger.error("No JSON object found in input message") + return None + alert_json_str = input_message[start:end] try: return json.loads(alert_json_str.replace("'", '"')) diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/network_connectivity_check_tool.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/network_connectivity_check_tool.py index cf85dd4d1..af157e61c 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/network_connectivity_check_tool.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/network_connectivity_check_tool.py @@ -28,6 +28,15 @@ from .prompts import ToolReasoningLayerPrompts +class NetworkConnectivityCheckToolConfig(FunctionBaseConfig, name="network_connectivity_check"): + description: str = Field( + default=("This tool checks network connectivity of a host by running ping and socket connection tests. " + "Args: host_id: str"), + description="Description of the tool for the agent.") + llm_name: LLMRef + test_mode: bool = Field(default=True, description="Whether to run in test mode") + + def _check_service_banner(host: str, port: int = 80, connect_timeout: float = 10, read_timeout: float = 10) -> str: """ Connects to host:port, reads until the Telnet banner (‘Escape character is '^]'.’) or times out. @@ -56,15 +65,6 @@ def _check_service_banner(host: str, port: int = 80, connect_timeout: float = 10 return '' -class NetworkConnectivityCheckToolConfig(FunctionBaseConfig, name="network_connectivity_check"): - description: str = Field( - default=("This tool checks network connectivity of a host by running ping and socket connection tests. " - "Args: host_id: str"), - description="Description of the tool for the agent.") - llm_name: LLMRef - test_mode: bool = Field(default=True, description="Whether to run in test mode") - - @register_function(config_type=NetworkConnectivityCheckToolConfig) async def network_connectivity_check_tool(config: NetworkConnectivityCheckToolConfig, builder: Builder): diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/run.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/run.py index e9ca041b4..d5abbfe9e 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/run.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/run.py @@ -139,19 +139,25 @@ def receive_alert(): HTTP endpoint to receive a JSON alert via POST. Expects application/json with a single alert dict or a list of alerts. 
""" + # use the globals-set ENV_FILE + if ENV_FILE is None: + raise ValueError("ENV_FILE must be set before processing alerts") + try: data = request.get_json(force=True) except Exception: return jsonify({"error": "Invalid JSON"}), 400 alerts = data if isinstance(data, list) else [data] + if not all(isinstance(alert, dict) for alert in alerts): + return jsonify({"error": "Alerts not represented as dictionaries"}), 400 for alert in alerts: - alert_id = alert.get('alert_id') + if 'alert_id' not in alert: + return jsonify({"error": "`alert_id` is absent in the alert payload"}), 400 + + alert_id = alert['alert_id'] processed_alerts.append(alert_id) - # use the globals-set ENV_FILE - if ENV_FILE is None: - raise ValueError("ENV_FILE must be set before processing alerts") start_process(alert, ENV_FILE) return jsonify({"received_alert_count": len(alerts), "total_launched": len(processed_alerts)}), 200 diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/telemetry_metrics_host_performance_check_tool.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/telemetry_metrics_host_performance_check_tool.py index df76440ab..391411702 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/telemetry_metrics_host_performance_check_tool.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/telemetry_metrics_host_performance_check_tool.py @@ -51,6 +51,9 @@ def _timeseries_stats(ts): Returns: str: Markdown formatted string containing summary statistics """ + if len(ts) == 0: + return "No data points" + count = len(ts) max_val = max(ts) min_val = min(ts) @@ -89,7 +92,11 @@ def _get_llm_analysis_input(timestamp_value_list): str: Formatted string containing: - JSON array of [datetime_str, value] pairs with human readable timestamps - Summary statistics of the metric values + - "No data points" if input list is empty """ + if len(timestamp_value_list) == 0: + return "No data points" + # Convert Unix timestamps to ISO format datetime strings and preserve values # Example: "2022-01-17 12:00:00" for timestamp 1642435200 data = [[datetime.fromtimestamp(entry[0]).strftime("%Y-%m-%d %H:%M:%S"), entry[1]] @@ -120,7 +127,7 @@ async def _arun(host_id: str) -> str: # Customize query based on your monitoring setup and metrics # This example queries the CPU usage percentage by subtracting idle CPU from 100% - query = '(100 - cpu_usage_idle{cpu="cpu-total",instance=~"{host_id}:9100"})' + query = f'(100 - cpu_usage_idle{{cpu="cpu-total",instance=~"{host_id}:9100"}})' url = f"{monitoring_url}/api/query_range" # Example values - users should customize these based on their monitoring requirements diff --git a/examples/alert_triage_agent/src/aiq_alert_triage_agent/utils.py b/examples/alert_triage_agent/src/aiq_alert_triage_agent/utils.py index 900513c2d..ceeb46536 100644 --- a/examples/alert_triage_agent/src/aiq_alert_triage_agent/utils.py +++ b/examples/alert_triage_agent/src/aiq_alert_triage_agent/utils.py @@ -111,14 +111,14 @@ def preload_test_data(test_data_path: str | None, benign_fallback_data_path: str def get_test_data() -> pd.DataFrame: """Returns the preloaded test data.""" if _DATA_CACHE['test_data'] is None: - raise ValueError("Test data not preloaded. Call preload_test_data() first.") + raise ValueError("Test data not preloaded. 
Call `preload_test_data` first.") return pd.DataFrame(_DATA_CACHE['test_data']) def _get_static_data(): """Returns the preloaded benign fallback test data.""" if _DATA_CACHE['benign_fallback_test_data'] is None: - raise ValueError("Benign fallback test data not preloaded. Call preload_test_data() first.") + raise ValueError("Benign fallback test data not preloaded. Call `preload_test_data` first.") return _DATA_CACHE['benign_fallback_test_data'] @@ -147,7 +147,7 @@ def load_column_or_static(df, host_id, column): try: return static_data[column] except KeyError as exc: - raise KeyError(f"Column '{column}' not found in static data") from exc + raise KeyError(f"Column '{column}' not found in test and benign fallback data") from exc # Column exists in DataFrame, get value for this host # Assumption: In test dataset, host_ids are unique and used to locate specific tool return values # If multiple rows found for a host_id, this indicates data inconsistency diff --git a/examples/alert_triage_agent/tests/test_categorizer.py b/examples/alert_triage_agent/tests/test_categorizer.py new file mode 100644 index 000000000..76080d6f8 --- /dev/null +++ b/examples/alert_triage_agent/tests/test_categorizer.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from aiq_alert_triage_agent.categorizer import _extract_markdown_heading_level + + +@pytest.mark.parametrize( + "test_input,expected", + [ + pytest.param("# Title", "#", id="single_hash"), + pytest.param("### Title", "###", id="multiple_hashes"), + pytest.param("No heading", "#", id="no_heading_default"), + pytest.param("", "#", id="empty_string"), + pytest.param("## My Title\n### Heading", "##", id="first_of_many"), + pytest.param("Here is a title\n## Title Line", "##", id="first_after_text"), + pytest.param("## Heading first\n# Title", "##", id="heading_precedence"), + pytest.param("###No space between # and title", "###", id="no_space_after_hashes"), + ], +) +def test_extract_markdown_heading_level(test_input, expected): + assert _extract_markdown_heading_level(test_input) == expected diff --git a/examples/alert_triage_agent/tests/test_hardware_check_tool.py b/examples/alert_triage_agent/tests/test_hardware_check_tool.py new file mode 100644 index 000000000..00c4e711b --- /dev/null +++ b/examples/alert_triage_agent/tests/test_hardware_check_tool.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from aiq_alert_triage_agent.hardware_check_tool import _get_ipmi_monitor_data + + +# Fixtures for inputs and expected command +@pytest.fixture +def ipmi_args(): + return "1.1.1.1", "test_user", "test_pass" + + +@pytest.fixture +def expected_cmd(ipmi_args): + ip, user, pwd = ipmi_args + return [ + "ipmimonitoring", + "-h", + ip, + "-u", + user, + "-p", + pwd, + "--privilege-level=USER", + ] + + +# Fixture to mock subprocess.run +@pytest.fixture +def mock_run(): + with patch('subprocess.run') as m: + yield m + + +# Parameterized test covering both success and failure +@pytest.mark.parametrize( + "stdout, side_effect, expected", + [ + # success case: subprocess returns stdout + pytest.param("Sample IPMI output", None, "Sample IPMI output", id="success"), + # failure case: subprocess raises CalledProcessError + pytest.param( + "unused output", + subprocess.CalledProcessError(returncode=1, cmd=["ipmimonitoring"], stderr="Command failed"), + None, # expected None when ipmimonitoring command raises error + id="failure"), + ]) +def test_get_ipmi_monitor_data(mock_run, ipmi_args, expected_cmd, stdout, side_effect, expected): + # configure mock + if side_effect: + mock_run.side_effect = side_effect + else: + mock_result = MagicMock() + mock_result.stdout = stdout + mock_run.return_value = mock_result + + # invoke + result = _get_ipmi_monitor_data(*ipmi_args) + + # assertions + assert result == expected + mock_run.assert_called_once_with(expected_cmd, capture_output=True, text=True, check=True) diff --git a/examples/alert_triage_agent/tests/test_host_performance_check_tool.py b/examples/alert_triage_agent/tests/test_host_performance_check_tool.py new file mode 100644 index 000000000..51269fa2a --- /dev/null +++ b/examples/alert_triage_agent/tests/test_host_performance_check_tool.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from unittest.mock import patch + +from aiq_alert_triage_agent.host_performance_check_tool import _parse_stdout_lines +from aiq_alert_triage_agent.prompts import ToolReasoningLayerPrompts + +EXAMPLE_CPU_USAGE_OUTPUT = """ +03:45:00 PM CPU %usr %nice %sys %iowait %irq %soft %steal %guest %gnice %idle +03:45:01 PM all 60.00 0.00 5.00 1.00 0.00 0.50 0.00 0.00 0.00 33.50 +03:45:01 PM 0 95.00 0.00 3.00 0.50 0.00 0.50 0.00 0.00 0.00 1.00 +03:45:01 PM 1 25.00 0.00 7.00 1.50 0.00 0.50 0.00 0.00 0.00 66.00""" + +EXAMPLE_MEMORY_USAGE_OUTPUT = """ + total used free shared buff/cache available +Mem: 7989 1234 512 89 6243 6521 +Swap: 2047 0 2047""" + +EXAMPLE_DISK_IO_OUTPUT = """ +Device r/s w/s rkB/s wkB/s rrqm/s wrqm/s %util await svctm +sca 20.0 80.0 1024.0 4096.0 0.0 0.0 98.0 120.0 1.2""" + +EXAMPLE_LLM_PARSED_OUTPUT = json.dumps( + { + "cpu_usage": [{ + "timestamp": "03:45:01 PM", + "cpu": "all", + "user": 60.00, + "nice": 0.00, + "system": 5.00, + "iowait": 1.00, + "irq": 0.00, + "softirq": 0.50, + "steal": 0.00, + "guest": 0.00, + "gnice": 0.00, + "idle": 33.50, + }, + { + "timestamp": "03:45:01 PM", + "cpu": "0", + "user": 95.00, + "nice": 0.00, + "system": 3.00, + "iowait": 0.50, + "irq": 0.00, + "softirq": 0.50, + "steal": 0.00, + "guest": 0.00, + "gnice": 0.00, + "idle": 1.00, + }, + { + "timestamp": "03:45:01 PM", + "cpu": "1", + "user": 25.00, + "nice": 0.00, + "system": 7.00, + "iowait": 1.50, + "irq": 0.00, + "softirq": 0.50, + "steal": 0.00, + "guest": 0.00, + "gnice": 0.00, + "idle": 66.00, + }], + "memory_usage": { + "total": 7989, + "used": 1234, + "free": 512, + "shared": 89, + "buff_cache": 6243, + "available": 6521, + }, + "swap_usage": { + "total": 2047, + "used": 0, + "free": 2047, + }, + "disk_io": [{ + "device": "sca", + "read_per_sec": 20.0, + "write_per_sec": 80.0, + "read_kB_per_sec": 1024.0, + "write_kB_per_sec": 4096.0, + "read_merge_per_sec": 0.0, + "write_merge_per_sec": 0.0, + "util_percent": 98.0, + "await_ms": 120.0, + "service_time_ms": 1.2, + }] + }, + sort_keys=True) + + +async def test_parse_stdout_lines_success(): + # Test data + test_stdout_lines = [EXAMPLE_CPU_USAGE_OUTPUT, EXAMPLE_MEMORY_USAGE_OUTPUT, EXAMPLE_DISK_IO_OUTPUT] + + # Mock the LLM response + with patch('aiq_alert_triage_agent.utils.llm_ainvoke') as mock_llm: + mock_llm.return_value = EXAMPLE_LLM_PARSED_OUTPUT + + # Call the function + result = await _parse_stdout_lines( + config=None, # unused, mocked + builder=None, # unused, mocked + stdout_lines=test_stdout_lines) + + # Verify the result + assert result == EXAMPLE_LLM_PARSED_OUTPUT + + # Verify llm_ainvoke was called with correct prompt + mock_llm.assert_called_once() + call_args = mock_llm.call_args[1] + assert 'config' in call_args + assert 'builder' in call_args + assert 'user_prompt' in call_args + input_data = "\n".join(test_stdout_lines) + assert call_args['user_prompt'] == ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format( + input_data=input_data) + + +async def test_parse_stdout_lines_llm_error(): + # Simulate LLM throwing an exception + with patch('aiq_alert_triage_agent.utils.llm_ainvoke') as mock_llm: + mock_llm.side_effect = Exception("LLM error") + mock_llm.return_value = None + + result = await _parse_stdout_lines( + config=None, # unused, mocked + builder=None, # unused, mocked + stdout_lines=["Some test output"]) + + # Verify error is properly captured in response + assert result == ('{"error": "Failed to parse stdout from the playbook run.",' + ' "exception": "LLM error", "raw_response": "None"}') 
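A minimal, standalone sketch (illustrative values only) of the str.format behavior the assertion above relies on: doubled braces in the template are emitted as literal braces, so the error message keeps its JSON-like shape.

    # Doubled braces ('{{' / '}}') become literal '{' / '}' after str.format,
    # which is how _parse_stdout_lines builds its JSON-shaped error string.
    template = ('{{"error": "Failed to parse stdout from the playbook run.", '
                '"exception": "{}", "raw_response": "{}"}}')
    print(template.format("LLM error", None))
    # {"error": "Failed to parse stdout from the playbook run.", "exception": "LLM error", "raw_response": "None"}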
diff --git a/examples/alert_triage_agent/tests/test_maintenance_check.py b/examples/alert_triage_agent/tests/test_maintenance_check.py new file mode 100644 index 000000000..04880ce9e --- /dev/null +++ b/examples/alert_triage_agent/tests/test_maintenance_check.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import importlib.resources +import inspect +import os +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pandas as pd +import pytest +import yaml +from aiq_alert_triage_agent.maintenance_check import NO_ONGOING_MAINTENANCE_STR +from aiq_alert_triage_agent.maintenance_check import MaintenanceCheckToolConfig +from aiq_alert_triage_agent.maintenance_check import _get_active_maintenance +from aiq_alert_triage_agent.maintenance_check import _load_maintenance_data +from aiq_alert_triage_agent.maintenance_check import _parse_alert_data +from aiq_alert_triage_agent.register import AlertTriageAgentWorkflowConfig + +from aiq.builder.framework_enum import LLMFrameworkEnum +from aiq.builder.workflow_builder import WorkflowBuilder +from aiq.data_models.component_ref import LLMRef + + +def test_load_maintenance_data(): + # Load paths from config like in test_utils.py + package_name = inspect.getmodule(AlertTriageAgentWorkflowConfig).__package__ + config_file: Path = importlib.resources.files(package_name).joinpath("configs", "config_test_mode.yml").absolute() + with open(config_file, "r") as file: + config = yaml.safe_load(file) + maintenance_data_path = config["functions"]["maintenance_check"]["static_data_path"] + maintenance_data_path_abs = importlib.resources.files(package_name).joinpath("../../../../", + maintenance_data_path).absolute() + + # Test successful loading with actual maintenance data file + df = _load_maintenance_data(maintenance_data_path_abs) + + # Verify DataFrame structure + assert isinstance(df, pd.DataFrame) + assert not df.empty + required_columns = {"host_id", "maintenance_start", "maintenance_end"} + assert all(col in df.columns for col in required_columns) + + # Verify data types + assert pd.api.types.is_datetime64_dtype(df["maintenance_start"]) + assert pd.api.types.is_datetime64_dtype(df["maintenance_end"]) + + # Test with missing required columns + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: + try: + # Create CSV with missing columns + f.write("host_id,some_other_column\n") + f.write("test-host,value\n") + f.flush() + + with pytest.raises(ValueError, match="Missing required columns: maintenance_end, maintenance_start"): + _load_maintenance_data(f.name) + finally: + os.unlink(f.name) + + # Test with non-existent file + with pytest.raises(FileNotFoundError): + _load_maintenance_data("nonexistent.csv") + + 
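For reference, a small sketch of the alert-parsing behavior exercised by the parametrized cases below, mirroring _parse_alert_data in maintenance_check.py (the sample message is illustrative): the text between the first '{' and the last '}' is extracted, single quotes are normalized to double quotes, and the result is passed to json.loads.

    import json

    msg = "Alert received: {'host_id': 'server1', 'timestamp': '2024-03-21T10:00:00.000'} - Please check"
    span = msg[msg.find("{"):msg.rfind("}") + 1]   # everything between first '{' and last '}'
    alert = json.loads(span.replace("'", '"'))     # normalize single quotes, then parse
    print(alert)  # {'host_id': 'server1', 'timestamp': '2024-03-21T10:00:00.000'}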
+@pytest.mark.parametrize( + "input_msg,expected", + [ + pytest.param("Alert received: {'host_id': 'server1', 'timestamp': '2024-03-21T10:00:00.000'} - Please check", { + "host_id": "server1", "timestamp": "2024-03-21T10:00:00.000" + }, + id="valid_json_with_surrounding_text"), + pytest.param('{"host_id": "server2", "timestamp": "2024-03-21T11:00:00.000"}', { + "host_id": "server2", "timestamp": "2024-03-21T11:00:00.000" + }, + id="clean_json_without_surrounding_text"), + pytest.param("{'host_id': 'server3', 'timestamp': '2024-03-21T12:00:00.000'}", { + "host_id": "server3", "timestamp": "2024-03-21T12:00:00.000" + }, + id="json_with_single_quotes"), + pytest.param("This is a message with no JSON", None, id="no_json_in_input"), + pytest.param("Alert: {invalid json format} received", None, id="invalid_json_format"), + pytest.param("{'host_id': 'server1'} {'host_id': 'server2'}", None, id="multiple_json_objects"), + pytest.param( + ("Nested JSON Alert: {'host_id': 'server4', 'details': {'location': 'rack1', 'metrics': " + "{'cpu': 90, 'memory': 85}}, 'timestamp': '2024-03-21T13:00:00.000'}"), + { + "host_id": "server4", + "details": { + "location": "rack1", "metrics": { + "cpu": 90, "memory": 85 + } + }, + "timestamp": "2024-03-21T13:00:00.000" + }, + id="nested_json_structure"), + pytest.param("Alert received:\n{'host_id': 'server5', 'timestamp': '2024-03-21T14:00:00.000'}\nPlease check", { + "host_id": "server5", "timestamp": "2024-03-21T14:00:00.000" + }, + id="json_with_newlines"), + ]) +def test_parse_alert_data(input_msg, expected): + result = _parse_alert_data(input_msg) + assert result == expected + + +def test_get_active_maintenance(): + # Create test data + test_data = { + 'host_id': ['host1', 'host1', 'host2', 'host3', 'host4'], + 'maintenance_start': [ + '2024-03-21 09:00:00', # Active maintenance with end time + '2024-03-21 14:00:00', # Future maintenance + '2024-03-21 09:00:00', # Ongoing maintenance (no end time) + '2024-03-21 08:00:00', # Past maintenance + '2024-03-21 09:00:00', # Different host + ], + 'maintenance_end': [ + '2024-03-21 11:00:00', + '2024-03-21 16:00:00', + None, + '2024-03-21 09:00:00', + '2024-03-21 11:00:00', + ] + } + df = pd.DataFrame(test_data) + df['maintenance_start'] = pd.to_datetime(df['maintenance_start']) + df['maintenance_end'] = pd.to_datetime(df['maintenance_end']) + + # Test 1: Active maintenance with end time + alert_time = datetime(2024, 3, 21, 10, 0, 0) + result = _get_active_maintenance(df, 'host1', alert_time) + assert result is not None + start_str, end_str = result + assert start_str == '2024-03-21 09:00:00' + assert end_str == '2024-03-21 11:00:00' + + # Test 2: No active maintenance (future maintenance) + alert_time = datetime(2024, 3, 21, 13, 0, 0) + result = _get_active_maintenance(df, 'host1', alert_time) + assert result is None + + # Test 3: Ongoing maintenance (no end time) + alert_time = datetime(2024, 3, 21, 10, 0, 0) + result = _get_active_maintenance(df, 'host2', alert_time) + assert result is not None + start_str, end_str = result + assert start_str == '2024-03-21 09:00:00' + assert end_str == '' # Empty string for ongoing maintenance + + # Test 4: Past maintenance + alert_time = datetime(2024, 3, 21, 10, 0, 0) + result = _get_active_maintenance(df, 'host3', alert_time) + assert result is None + + # Test 5: Non-existent host + alert_time = datetime(2024, 3, 21, 10, 0, 0) + result = _get_active_maintenance(df, 'host5', alert_time) + assert result is None + + +async def test_maintenance_check_tool(): + # Create a temporary 
maintenance data file + test_data = { + 'host_id': ['host1', 'host2'], + 'maintenance_start': ['2024-03-21 09:00:00', '2024-03-21 09:00:00'], + 'maintenance_end': ['2024-03-21 11:00:00', None] + } + # Test cases + test_cases = [ + # Test 1: Valid alert during maintenance + { + 'input': "{'host_id': 'host1', 'timestamp': '2024-03-21T10:00:00.000'}", + 'expected_maintenance': True, + 'mock_summary': 'Maintenance summary report' + }, + # Test 2: Valid alert not during maintenance + { + 'input': "{'host_id': 'host1', 'timestamp': '2024-03-21T12:00:00.000'}", 'expected_maintenance': False + }, + # Test 3: Invalid JSON format + { + 'input': "Invalid JSON data", 'expected_maintenance': False + }, + # Test 4: Missing required fields + { + 'input': "{'host_id': 'host1'}", # Missing timestamp + 'expected_maintenance': False + }, + # Test 5: Invalid timestamp format + { + 'input': "{'host_id': 'host1', 'timestamp': 'invalid-time'}", 'expected_maintenance': False + }, + # Test 6: Host under ongoing maintenance (no end time) + { + 'input': "{'host_id': 'host2', 'timestamp': '2024-03-21T10:00:00.000'}", + 'expected_maintenance': True, + 'mock_summary': 'Ongoing maintenance summary' + } + ] + + # Create a temporary CSV file to store test maintenance data + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: + try: + # Write test data to CSV file + df = pd.DataFrame(test_data) + df.to_csv(f.name, index=False) + f.flush() + + # Set up mock builder and LLM + mock_builder = AsyncMock() + mock_llm = MagicMock() + mock_builder.get_llm.return_value = mock_llm + + # Configure maintenance check tool + config = MaintenanceCheckToolConfig( + llm_name=LLMRef(value="dummy"), + description="direct test", + static_data_path=f.name, + ) + + # Initialize workflow builder and add maintenance check function + async with WorkflowBuilder() as builder: + builder.get_llm = mock_builder.get_llm + await builder.add_function("maintenance_check", config) + maintenance_check_tool = builder.get_tool("maintenance_check", wrapper_type=LLMFrameworkEnum.LANGCHAIN) + + # Run test cases + for case in test_cases: + # Mock the alert summarization function + with patch('aiq_alert_triage_agent.maintenance_check._summarize_alert') as mock_summarize: + if case['expected_maintenance']: + mock_summarize.return_value = case['mock_summary'] + + # Invoke maintenance check tool with test input + result = await maintenance_check_tool.ainvoke(input=case['input']) + + # Verify results based on whether maintenance was expected + if case['expected_maintenance']: + assert result == case['mock_summary'] + mock_summarize.assert_called_once() + mock_summarize.reset_mock() + else: + assert result == NO_ONGOING_MAINTENANCE_STR + mock_summarize.assert_not_called() + + finally: + # Clean up temporary file + os.unlink(f.name) diff --git a/examples/alert_triage_agent/tests/test_monitoring_process_check_tool.py b/examples/alert_triage_agent/tests/test_monitoring_process_check_tool.py new file mode 100644 index 000000000..90d213101 --- /dev/null +++ b/examples/alert_triage_agent/tests/test_monitoring_process_check_tool.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import patch + +from aiq_alert_triage_agent.monitoring_process_check_tool import _run_ansible_playbook_for_monitor_process_check +from aiq_alert_triage_agent.playbooks import MONITOR_PROCESS_CHECK_PLAYBOOK + + +async def test_run_ansible_playbook_for_monitor_process_check(): + # Test data + ansible_host = "test.example.com" + ansible_user = "testuser" + ansible_port = 22 + ansible_private_key_path = "/path/to/key.pem" + + # Mock playbook output + mock_playbook_output = { + "task_results": [{ + "task": "Check process status", + "host": ansible_host, + "result": { + "cmd": + "ps aux | grep monitoring", + "stdout_lines": [ + "user1 1234 0.0 0.2 12345 5678 ? Ss 10:00 0:00 /usr/bin/monitoring-agent", + "user1 5678 2.0 1.0 23456 7890 ? Sl 10:01 0:05 /usr/bin/monitoring-collector" + ] + } + }, + { + "task": "Check service status", + "host": ansible_host, + "result": { + "cmd": + "systemctl status monitoring-service", + "stdout_lines": [ + "● monitoring-service.service - Monitoring Service", " Active: active (running)" + ] + } + }] + } + + # Mock the run_ansible_playbook function + with patch("aiq_alert_triage_agent.utils.run_ansible_playbook", new_callable=AsyncMock) as mock_run: + mock_run.return_value = mock_playbook_output + + # Call the function + result = await _run_ansible_playbook_for_monitor_process_check( + ansible_host=ansible_host, + ansible_user=ansible_user, + ansible_port=ansible_port, + ansible_private_key_path=ansible_private_key_path) + + # Verify run_ansible_playbook was called with correct arguments + mock_run.assert_called_once_with(playbook=MONITOR_PROCESS_CHECK_PLAYBOOK, + ansible_host=ansible_host, + ansible_user=ansible_user, + ansible_port=ansible_port, + ansible_private_key_path=ansible_private_key_path) + + # Verify the result structure + assert isinstance(result, list) + assert len(result) == 2 + + # Verify first task details + first_task = result[0] + assert first_task["task"] == "Check process status" + assert first_task["host"] == ansible_host + assert first_task["cmd"] == "ps aux | grep monitoring" + assert len(first_task["stdout_lines"]) == 2 + assert "monitoring-agent" in first_task["stdout_lines"][0] + assert "monitoring-collector" in first_task["stdout_lines"][1] + + # Verify second task details + second_task = result[1] + assert second_task["task"] == "Check service status" + assert second_task["host"] == ansible_host + assert second_task["cmd"] == "systemctl status monitoring-service" + assert len(second_task["stdout_lines"]) == 2 + assert "monitoring-service.service" in second_task["stdout_lines"][0] + assert "Active: active" in second_task["stdout_lines"][1] diff --git a/examples/alert_triage_agent/tests/test_network_connectivity_check_tool.py b/examples/alert_triage_agent/tests/test_network_connectivity_check_tool.py new file mode 100644 index 000000000..9a0140a71 --- /dev/null +++ b/examples/alert_triage_agent/tests/test_network_connectivity_check_tool.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from aiq_alert_triage_agent.network_connectivity_check_tool import _check_service_banner + + +@pytest.fixture +def mock_sock(): + """A reusable mock socket whose recv and settimeout we can configure.""" + sock = MagicMock() + return sock + + +@patch('socket.create_connection') +def test_successful_banner_read(mock_create_conn, mock_sock): + # Simulate a two‐chunk banner (one before the pattern, the pattern itself) then EOF + mock_sock.recv.side_effect = [ + b"Welcome to test server\n", + b"Escape character is '^]'.\n", + b"" # EOF + ] + mock_create_conn.return_value.__enter__.return_value = mock_sock + + result = _check_service_banner("my.host", port=8080) + assert "Welcome to test server" in result + assert "Escape character is '^]'." in result + + mock_create_conn.assert_called_once_with(("my.host", 8080), timeout=10) + mock_sock.settimeout.assert_called_once_with(10) + + +@pytest.mark.parametrize( + "side_effect, port, conn_to, read_to", + [ + (socket.timeout(), 80, 10, 10), + (ConnectionRefusedError(), 80, 10, 10), + (OSError(), 1234, 5, 2), + ], +) +@patch('socket.create_connection') +def test_error_conditions(mock_create_conn, side_effect, port, conn_to, read_to): + """ + If create_connection raises timeout/conn refused/OS error, + _check_service_banner should return empty string and + propagate the connection parameters correctly. + """ + mock_create_conn.side_effect = side_effect + + result = _check_service_banner("any.host", port=port, connect_timeout=conn_to, read_timeout=read_to) + assert result == "" + mock_create_conn.assert_called_once_with(("any.host", port), timeout=conn_to) + + +@patch('socket.create_connection') +def test_reading_until_eof_without_banner(mock_create_conn, mock_sock): + """ + If the server never emits the banner and closes the connection, + we should still return whatever was read before EOF (even empty). + """ + # Single empty chunk simulates immediate EOF + mock_sock.recv.side_effect = [b""] + mock_create_conn.return_value.__enter__.return_value = mock_sock + + result = _check_service_banner("no.banner.host") + assert result == "" # nothing was ever received + + mock_create_conn.assert_called_once_with(("no.banner.host", 80), timeout=10) + mock_sock.settimeout.assert_called_once_with(10) diff --git a/examples/alert_triage_agent/tests/test_run.py b/examples/alert_triage_agent/tests/test_run.py new file mode 100644 index 000000000..f379f342c --- /dev/null +++ b/examples/alert_triage_agent/tests/test_run.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import patch + +import pytest +from aiq_alert_triage_agent import run + + +@pytest.fixture +def client(): + """Create a test client for the Flask application.""" + run.app.config['TESTING'] = True + with run.app.test_client() as client: + yield client + + +@pytest.fixture(autouse=True) +def reset_global_state(): + """Reset global state before each test.""" + run.processed_alerts = [] + run.ENV_FILE = '.placeholder_env_file_value' + + +def test_hsts_header(client): + """Test that HSTS header is properly set.""" + response = client.get('/') + assert response.headers['Strict-Transport-Security'] == 'max-age=31536000; includeSubDomains; preload' + + +@pytest.mark.parametrize('alert', + [{ + "alert_id": 1, + "alert_name": "InstanceDown", + "host_id": "test-instance-1.example.com", + "severity": "critical", + "description": "Test description", + "summary": "Test summary", + "timestamp": "2025-04-28T05:00:00.000000" + }, + { + "alert_id": 2, + "alert_name": "CPUUsageHighError", + "host_id": "test-instance-2.example.com", + "severity": "warning", + "description": "High CPU usage", + "summary": "CPU at 95%", + "timestamp": "2025-04-28T06:00:00.000000" + }]) +def test_receive_single_alert(client, alert): + """Test receiving a single alert with different alert types.""" + with patch('aiq_alert_triage_agent.run.start_process') as mock_start_process: + response = client.post('/alerts', data=json.dumps(alert), content_type='application/json') + + data = json.loads(response.data) + assert response.status_code == 200 + assert data['received_alert_count'] == 1 + assert data['total_launched'] == 1 + mock_start_process.assert_called_once() + + +def test_receive_multiple_alerts(client): + """Test receiving multiple alerts in a single request with different counts.""" + alert_count = 3 + test_alerts = [{ + "alert_id": i, + "alert_name": f"TestAlert{i}", + "host_id": f"test-instance-{i}.example.com", + "severity": "critical", + "timestamp": "2025-04-28T05:00:00.000000" + } for i in range(alert_count)] + + with patch('aiq_alert_triage_agent.run.start_process') as mock_start_process: + response = client.post('/alerts', data=json.dumps(test_alerts), content_type='application/json') + + data = json.loads(response.data) + assert response.status_code == 200 + assert data['received_alert_count'] == alert_count + assert data['total_launched'] == alert_count + assert mock_start_process.call_count == alert_count + + # post again to test that the total_launched is cumulative + response = client.post('/alerts', data=json.dumps(test_alerts), content_type='application/json') + + data = json.loads(response.data) + assert response.status_code == 200 + assert data['received_alert_count'] == alert_count + assert data['total_launched'] == alert_count * 2 + assert mock_start_process.call_count == alert_count * 2 + + +@pytest.mark.parametrize( + 'invalid_data,expected_error', + [ + pytest.param('invalid json', 'Invalid JSON', id='invalid_syntax'), + pytest.param('{incomplete json', 'Invalid JSON', id='incomplete_json'), + pytest.param('[1, 2, 3]', "Alerts not represented as dictionaries", 
+ id='wrong_alert_format'), # Valid JSON but invalid alert format + pytest.param('{"key": "value"}', "`alert_id` is absent in the alert payload", + id='missing_alert_id') # Valid JSON but invalid alert format + ]) +def test_invalid_json(client, invalid_data, expected_error): + """Test handling of various invalid JSON data formats.""" + response = client.post('/alerts', data=invalid_data, content_type='application/json') + + assert response.status_code == 400 + data = json.loads(response.data) + assert data['error'] == expected_error + + +@pytest.mark.parametrize( + 'args,expected', + [ + pytest.param(['--host', '127.0.0.1', '--port', '8080', '--env_file', '/custom/.env'], { + 'host': '127.0.0.1', 'port': 8080, 'env_file': '/custom/.env' + }, + id='custom_host_port_env_file'), + pytest.param([], { + 'host': '0.0.0.0', 'port': 5000, 'env_file': '.env' + }, id='default_args'), + pytest.param(['--port', '3000'], { + 'host': '0.0.0.0', 'port': 3000, 'env_file': '.env' + }, id='partial_override') + ]) +def test_parse_args(args, expected): + """Test command line argument parsing with different argument combinations.""" + with patch('sys.argv', ['script.py'] + args): + parsed_args = run.parse_args() + assert parsed_args.host == expected['host'] + assert parsed_args.port == expected['port'] + assert parsed_args.env_file == expected['env_file'] diff --git a/examples/alert_triage_agent/tests/test_telemetry_metrics_host_heartbeat_check_tool.py b/examples/alert_triage_agent/tests/test_telemetry_metrics_host_heartbeat_check_tool.py new file mode 100644 index 000000000..2bcb2c9bf --- /dev/null +++ b/examples/alert_triage_agent/tests/test_telemetry_metrics_host_heartbeat_check_tool.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +import requests +from aiq_alert_triage_agent.telemetry_metrics_host_heartbeat_check_tool import \ + TelemetryMetricsHostHeartbeatCheckToolConfig + +from aiq.builder.framework_enum import LLMFrameworkEnum +from aiq.builder.workflow_builder import WorkflowBuilder +from aiq.data_models.component_ref import LLMRef + + +async def test_telemetry_metrics_host_heartbeat_check_tool(): + # Test cases with expected API responses and outcomes + test_cases = [ + # Test 1: Host is up and reporting metrics + { + 'host_id': 'host1', + 'api_response': { + 'data': { + 'result': [{ + 'metric': { + 'instance': 'host1:9100' + }, + 'value': [1234567890, '1'] # Timestamp and "up" value + }] + } + }, + 'expected_success': True, + 'mock_llm_conclusion': 'Host host1 is up and reporting metrics normally.' 
+ }, + # Test 2: Host is down (no metrics reported) + { + 'host_id': 'host2', + 'api_response': { + 'data': { + 'result': [] # Empty result indicates no metrics reported + } + }, + 'expected_success': True, + 'mock_llm_conclusion': 'Host host2 appears to be down - no heartbeat metrics reported.' + }, + # Test 3: API error scenario + { + 'host_id': 'host3', + 'api_error': requests.exceptions.RequestException('Connection failed'), + 'expected_success': False + } + ] + + # Configure the tool + config = TelemetryMetricsHostHeartbeatCheckToolConfig( + llm_name=LLMRef(value="dummy"), + test_mode=False, # Important: testing in live mode + metrics_url="http://test-monitoring-system:9090") + + # Set up mock builder and LLM + mock_builder = AsyncMock() + mock_llm = MagicMock() + mock_builder.get_llm.return_value = mock_llm + + # Initialize workflow builder and add the function + async with WorkflowBuilder() as builder: + builder.get_llm = mock_builder.get_llm + await builder.add_function("telemetry_metrics_host_heartbeat_check", config) + heartbeat_check_tool = builder.get_tool("telemetry_metrics_host_heartbeat_check", + wrapper_type=LLMFrameworkEnum.LANGCHAIN) + + # Run test cases + for case in test_cases: + # Mock the requests.get call + with patch('requests.get') as mock_get, \ + patch('aiq_alert_triage_agent.utils.llm_ainvoke') as mock_llm_invoke: + + if 'api_error' in case: + # Simulate API error + mock_get.side_effect = case['api_error'] + else: + # Mock successful API response + mock_response = MagicMock() + mock_response.json.return_value = case['api_response'] + mock_get.return_value = mock_response + + if case['expected_success']: + # Set up LLM mock response for successful cases + mock_llm_invoke.return_value = case['mock_llm_conclusion'] + + # Invoke tool and verify results + result = await heartbeat_check_tool.ainvoke(input=case['host_id']) + + # Verify the result matches expected LLM conclusion + assert result == case['mock_llm_conclusion'] + + # Verify API call was made correctly + mock_get.assert_called_once() + args, kwargs = mock_get.call_args + assert kwargs['params']['query'] == f'up{{instance=~"{case["host_id"]}:9100"}}' + + # Verify LLM was called + mock_llm_invoke.assert_called_once() + else: + # Test error case + with pytest.raises(requests.exceptions.RequestException): + await heartbeat_check_tool.ainvoke(input=case['host_id']) diff --git a/examples/alert_triage_agent/tests/test_telemetry_metrics_host_performance_check_tool.py b/examples/alert_triage_agent/tests/test_telemetry_metrics_host_performance_check_tool.py new file mode 100644 index 000000000..05cf4b14f --- /dev/null +++ b/examples/alert_triage_agent/tests/test_telemetry_metrics_host_performance_check_tool.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+from datetime import datetime
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+import requests
+from aiq_alert_triage_agent.telemetry_metrics_host_performance_check_tool import \
+    TelemetryMetricsHostPerformanceCheckToolConfig
+from aiq_alert_triage_agent.telemetry_metrics_host_performance_check_tool import _get_llm_analysis_input
+from aiq_alert_triage_agent.telemetry_metrics_host_performance_check_tool import _timeseries_stats
+
+from aiq.builder.framework_enum import LLMFrameworkEnum
+from aiq.builder.workflow_builder import WorkflowBuilder
+from aiq.data_models.component_ref import LLMRef
+
+
+async def test_telemetry_metrics_host_performance_check_tool():
+    # Test cases with expected API responses and outcomes
+    test_cases = [
+        # Test 1: Normal CPU usage pattern
+        {
+            'host_id': 'host1',
+            'api_response': {
+                'data': {
+                    'result': [{
+                        'values': [
+                            [1642435200, "45.2"],  # Example timestamp and CPU usage
+                            [1642438800, "47.8"],
+                            [1642442400, "42.5"],
+                        ]
+                    }]
+                }
+            },
+            'expected_success': True,
+            'mock_llm_conclusion': 'CPU usage for host1 shows normal patterns with average utilization around 45%.'
+        },
+        # Test 2: High CPU usage pattern
+        {
+            'host_id':
+                'host2',
+            'api_response': {
+                'data': {
+                    'result': [{
+                        'values': [
+                            [1642435200, "85.2"],
+                            [1642438800, "87.8"],
+                            [1642442400, "92.5"],
+                        ]
+                    }]
+                }
+            },
+            'expected_success':
+                True,
+            'mock_llm_conclusion':
+                'Host host2 shows consistently high CPU utilization above 85%, indicating potential performance issues.'
+        },
+        # Test 3: API error scenario
+        {
+            'host_id': 'host3',
+            'api_error': requests.exceptions.RequestException('Connection failed'),
+            'expected_success': False
+        }
+    ]
+
+    # Configure the tool
+    config = TelemetryMetricsHostPerformanceCheckToolConfig(
+        llm_name=LLMRef(value="dummy"),
+        test_mode=False,  # Testing in live mode
+        metrics_url="http://test-monitoring-system:9090")
+
+    # Set up mock builder and LLM
+    mock_builder = AsyncMock()
+    mock_llm = MagicMock()
+    mock_builder.get_llm.return_value = mock_llm
+
+    # Initialize workflow builder and add the function
+    async with WorkflowBuilder() as builder:
+        builder.get_llm = mock_builder.get_llm
+        await builder.add_function("telemetry_metrics_host_performance_check", config)
+        performance_check_tool = builder.get_tool("telemetry_metrics_host_performance_check",
+                                                  wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+
+        # Run test cases
+        for case in test_cases:
+            # Mock the requests.get call
+            with patch('requests.get') as mock_get, \
+                 patch('aiq_alert_triage_agent.utils.llm_ainvoke') as mock_llm_invoke:
+
+                if 'api_error' in case:
+                    # Simulate API error
+                    mock_get.side_effect = case['api_error']
+                else:
+                    # Mock successful API response
+                    mock_response = MagicMock()
+                    mock_response.json.return_value = case['api_response']
+                    mock_get.return_value = mock_response
+
+                if case['expected_success']:
+                    # Set up LLM mock response for successful cases
+                    mock_llm_invoke.return_value = case['mock_llm_conclusion']
+
+                    # Invoke tool and verify results
+                    result = await performance_check_tool.ainvoke(input=case['host_id'])
+
+                    # Verify the result matches expected LLM conclusion
+                    assert result == case['mock_llm_conclusion']
+
+                    # Verify API call was made correctly
+                    mock_get.assert_called_once()
+                    args, kwargs = mock_get.call_args
+
+                    # Verify the query parameters
+                    params = kwargs['params']
+                    host_id = case["host_id"]
+                    assert params['query'] == f'(100 - cpu_usage_idle{{cpu="cpu-total",instance=~"{host_id}:9100"}})'
+                    assert 'step' in params
+                    # Should parse without error
+                    datetime.fromisoformat(params['start'].replace('Z', '+00:00'))
+                    datetime.fromisoformat(params['end'].replace('Z', '+00:00'))
+
+                    # Verify LLM was called with processed data
+                    mock_llm_invoke.assert_called_once()
+                    # Verify LLM was called with correctly formatted data input
+                    llm_call_args = mock_llm_invoke.call_args
+                    user_prompt = llm_call_args[1]['user_prompt']
+                    assert user_prompt.startswith('Timeseries:\n')  # Check format starts with timeseries
+                    assert '\n\nTime Series Statistics' in user_prompt  # Check statistics section exists
+                    assert all(stat in user_prompt for stat in [
+                        'Number of Data Points:', 'Maximum Value:', 'Minimum Value:', 'Mean Value:', 'Median Value:'
+                    ])  # Check all statistics are present
+
+                else:
+                    # Test error case
+                    with pytest.raises(requests.exceptions.RequestException):
+                        await performance_check_tool.ainvoke(input=case['host_id'])
+
+
+def test_timeseries_stats():
+    # Test case 1: Normal sequence of values
+    ts1 = [45.2, 47.8, 42.5, 44.1, 46.3]
+    result1 = _timeseries_stats(ts1)
+
+    # Verify all expected statistics are present
+    assert 'Number of Data Points: 5' in result1
+    assert 'Maximum Value: 47.8' in result1
+    assert 'Minimum Value: 42.5' in result1
+    assert 'Mean Value: 45.18' in result1  # 225.9/5
+    assert 'Median Value: 45.2' in result1
+
+    # Test case 2: Single value
+    ts2 = [42.0]
+    result2 = _timeseries_stats(ts2)
+    assert 'Number of Data Points: 1' in result2
+    assert 'Maximum Value: 42.0' in result2
+    assert 'Minimum Value: 42.0' in result2
+    assert 'Mean Value: 42.00' in result2
+    assert 'Median Value: 42.0' in result2
+
+    # Test case 3: Empty list
+    ts3 = []
+    result3 = _timeseries_stats(ts3)
+    assert "No data points" == result3
+
+    # Test case 4: List with integer values
+    ts4 = [1, 2, 3, 4, 5]
+    result4 = _timeseries_stats(ts4)
+    assert 'Number of Data Points: 5' in result4
+    assert 'Maximum Value: 5' in result4
+    assert 'Minimum Value: 1' in result4
+    assert 'Mean Value: 3.00' in result4
+    assert 'Median Value: 3' in result4
+
+
+def test_get_llm_analysis_input():
+    # Test case 1: Normal sequence of timestamp-value pairs
+    def to_timestamp(date_str):
+        return int(datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S").timestamp())
+
+    timestamp_value_list1 = [[to_timestamp("2025-04-17 12:00:00"),
+                              "45.2"], [to_timestamp("2025-04-17 13:00:00"), "47.8"],
+                             [to_timestamp("2025-04-17 14:00:00"), "42.5"]]
+    result1 = _get_llm_analysis_input(timestamp_value_list1)
+
+    # Parse the JSON part of the output
+    timeseries_str = result1.split('\n\n')[0].replace('Timeseries:\n', '')
+    timeseries_data = json.loads(timeseries_str)
+
+    # Verify timestamp conversion and format
+    assert len(timeseries_data) == 3
+    assert timeseries_data[0][0] == "2025-04-17 12:00:00"
+    assert timeseries_data[0][1] == "45.2"
+
+    # Verify statistics section exists and contains all required fields
+    assert 'Time Series Statistics' in result1
+    assert 'Number of Data Points: 3' in result1
+    assert 'Maximum Value: 47.8' in result1
+    assert 'Minimum Value: 42.5' in result1
+    assert 'Mean Value: 45.17' in result1
+    assert 'Median Value: 45.2' in result1
+
+    # Test case 2: Single timestamp-value pair
+    timestamp_value_list2 = [[to_timestamp("2025-04-20 10:00:00"), "82.0"]]
+    result2 = _get_llm_analysis_input(timestamp_value_list2)
+
+    timeseries_str2 = result2.split('\n\n')[0].replace('Timeseries:\n', '')
+    timeseries_data2 = json.loads(timeseries_str2)
+
+    assert len(timeseries_data2) == 1
+    assert timeseries_data2[0][0] == "2025-04-20 10:00:00"
+    assert timeseries_data2[0][1] == "82.0"
+    assert 'Number of Data Points: 1' in result2
+
+    # Test case 3: Empty list
+    timestamp_value_list3 = []
+    result3 = _get_llm_analysis_input(timestamp_value_list3)
+    assert "No data points" == result3
+
+    # Test case 4: Mixed numeric types (integers and floats)
+    timestamp_value_list4 = [
+        [to_timestamp("2025-04-17 12:00:00"), "100"],  # Integer value
+        [to_timestamp("2025-04-17 13:00:00"), "47.8"],  # Float value
+        [to_timestamp("2025-04-17 14:00:00"), "50"]  # Integer value
+    ]
+    result4 = _get_llm_analysis_input(timestamp_value_list4)
+
+    timeseries_str4 = result4.split('\n\n')[0].replace('Timeseries:\n', '')
+    timeseries_data4 = json.loads(timeseries_str4)
+
+    assert len(timeseries_data4) == 3
+    assert all(isinstance(entry[1], str) for entry in timeseries_data4)  # All values should be strings
+    assert 'Maximum Value: 100' in result4
+    assert 'Minimum Value: 47.8' in result4
diff --git a/examples/alert_triage_agent/tests/test_utils.py b/examples/alert_triage_agent/tests/test_utils.py
new file mode 100644
index 000000000..7851f2744
--- /dev/null
+++ b/examples/alert_triage_agent/tests/test_utils.py
@@ -0,0 +1,277 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
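+
+# Tests for the shared helpers in aiq_alert_triage_agent.utils: the LLM cache,
+# test-data preloading, column/static fallback lookups, and the Ansible playbook runner.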
+ +import importlib +import importlib.resources +import inspect +from pathlib import Path +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pandas as pd +import pytest +import yaml +from aiq_alert_triage_agent.register import AlertTriageAgentWorkflowConfig +from aiq_alert_triage_agent.utils import _DATA_CACHE +from aiq_alert_triage_agent.utils import _LLM_CACHE +from aiq_alert_triage_agent.utils import _get_llm +from aiq_alert_triage_agent.utils import load_column_or_static +from aiq_alert_triage_agent.utils import preload_test_data +from aiq_alert_triage_agent.utils import run_ansible_playbook + +from aiq.builder.framework_enum import LLMFrameworkEnum + + +async def test_get_llm(): + # Clear the cache before test + _LLM_CACHE.clear() + + llm_name_1 = "test_llm" + llm_name_2 = "different_llm" + wrapper_type = LLMFrameworkEnum.LANGCHAIN + + # Create mock builder + mock_builder = MagicMock() + llms = { + (llm_name_1, wrapper_type): object(), + (llm_name_2, wrapper_type): object(), + } + mock_builder.get_llm = AsyncMock(side_effect=lambda llm_name, wrapper_type: llms[(llm_name, wrapper_type)]) + + # Test first call - should create new LLM + result = await _get_llm(mock_builder, llm_name_1, wrapper_type) + + # Verify LLM was created with correct parameters + mock_builder.get_llm.assert_called_once_with(llm_name=llm_name_1, wrapper_type=wrapper_type) + assert result is llms[(llm_name_1, wrapper_type)] + + # Verify cache state after first call + assert len(_LLM_CACHE) == 1 + assert _LLM_CACHE[(llm_name_1, wrapper_type)] is llms[(llm_name_1, wrapper_type)] + + # Test second call with same parameters - should return cached LLM + result2 = await _get_llm(mock_builder, llm_name_1, wrapper_type) + + # Verify get_llm was not called again + mock_builder.get_llm.assert_called_once() + assert result2 is llms[(llm_name_1, wrapper_type)] + + # Verify cache state hasn't changed + assert len(_LLM_CACHE) == 1 + assert _LLM_CACHE[(llm_name_1, wrapper_type)] is llms[(llm_name_1, wrapper_type)] + + # Test with different parameters - should create new LLM + result3 = await _get_llm(mock_builder, llm_name_2, wrapper_type) + + # Verify get_llm was called again with new parameters + assert mock_builder.get_llm.call_count == 2 + mock_builder.get_llm.assert_called_with(llm_name=llm_name_2, wrapper_type=wrapper_type) + assert result3 is llms[(llm_name_2, wrapper_type)] + + # Verify cache state after adding second LLM + assert len(_LLM_CACHE) == 2 + assert _LLM_CACHE[(llm_name_1, wrapper_type)] is llms[(llm_name_1, wrapper_type)] + assert _LLM_CACHE[(llm_name_2, wrapper_type)] is llms[(llm_name_2, wrapper_type)] + + +def test_preload_test_data(): + # Clear the data cache before test + _DATA_CACHE.clear() + _DATA_CACHE.update({'test_data': None, 'benign_fallback_test_data': None}) + + # Load paths from config + package_name = inspect.getmodule(AlertTriageAgentWorkflowConfig).__package__ + config_file: Path = importlib.resources.files(package_name).joinpath("configs", "config_test_mode.yml").absolute() + with open(config_file, "r") as file: + config = yaml.safe_load(file) + test_data_path = config["workflow"]["test_data_path"] + benign_fallback_data_path = config["workflow"]["benign_fallback_data_path"] + test_data_path_abs = importlib.resources.files(package_name).joinpath("../../../../", test_data_path).absolute() + benign_fallback_data_path_abs = importlib.resources.files(package_name).joinpath( + "../../../../", benign_fallback_data_path).absolute() + 
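+    # The config stores the data paths relative to the repository root, so resolve them
+    # four directories up from the installed package location before loading.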
+ # Test successful loading with actual test files + preload_test_data(test_data_path_abs, benign_fallback_data_path_abs) + + # Verify data was loaded correctly + assert len(_DATA_CACHE) == 2 + assert isinstance(_DATA_CACHE['test_data'], pd.DataFrame) + assert isinstance(_DATA_CACHE['benign_fallback_test_data'], dict) + assert not _DATA_CACHE['test_data'].empty + assert len(_DATA_CACHE['benign_fallback_test_data']) > 0 + + # Test error cases + with pytest.raises(ValueError, match="test_data_path must be provided"): + preload_test_data(None, benign_fallback_data_path) + + with pytest.raises(ValueError, match="benign_fallback_data_path must be provided"): + preload_test_data(test_data_path, None) + + # Test with non-existent files + with pytest.raises(FileNotFoundError): + preload_test_data("nonexistent.csv", benign_fallback_data_path) + + with pytest.raises(FileNotFoundError): + preload_test_data(test_data_path, "nonexistent.json") + + +def test_load_column_or_static(): + # Clear and initialize the data cache with test data + _DATA_CACHE.clear() + _DATA_CACHE.update({ + 'test_data': None, + 'benign_fallback_test_data': { + 'static_column': 'static_value', 'another_static': 'another_value' + } + }) + + # Create test DataFrame + df = pd.DataFrame({ + 'host_id': ['host1', 'host2', 'host3'], + 'string_column': ['value1', 'value2', 'value3'], + 'integer_column': [1, 2, 3] + }) + + # Test successful DataFrame column access + assert load_column_or_static(df, 'host1', 'string_column') == 'value1' + assert load_column_or_static(df, 'host2', 'integer_column') == 2 + + # Test fallback to static JSON when column not in DataFrame + assert load_column_or_static(df, 'host1', 'static_column') == 'static_value' + assert load_column_or_static(df, 'host2', 'another_static') == 'another_value' + + # Test error when column not found in either source + with pytest.raises(KeyError, match="Column 'nonexistent' not found in test and benign fallback data"): + load_column_or_static(df, 'host1', 'nonexistent') + + # Test error when host_id not found + with pytest.raises(KeyError, match="No row for host_id='unknown_host' in DataFrame"): + load_column_or_static(df, 'unknown_host', 'string_column') + + # Test error when multiple rows found for same host_id + df_duplicate = pd.DataFrame({ + 'host_id': ['host1', 'host1', 'host2'], 'string_column': ['value1', 'value1_dup', 'value2'] + }) + with pytest.raises(ValueError, match="Multiple rows found for host_id='host1' in DataFrame"): + load_column_or_static(df_duplicate, 'host1', 'string_column') + + # Test error when benign fallback data not preloaded + _DATA_CACHE['benign_fallback_test_data'] = None + with pytest.raises(ValueError, match="Benign fallback test data not preloaded. Call `preload_test_data` first."): + load_column_or_static(df, 'host1', 'static_column') + + +def _mock_ansible_runner(status="successful", rc=0, events=None, stdout=None): + """ + Build a dummy ansible_runner.Runner-like object. 
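+
+    status and rc are set directly on the mock. If events is given it is attached
+    as-is; otherwise the runner gets an empty event list and, when stdout is provided,
+    a file-like stdout whose read() returns that text (or None when stdout is absent).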
+ """ + runner = MagicMock() + runner.status = status + runner.rc = rc + # Only set .events if given + if events is not None: + runner.events = events + else: + # Simulate no events + if stdout is not None: + runner.stdout = MagicMock() + runner.stdout.read.return_value = stdout + else: + runner.stdout = None + # Leave runner.events unset or empty + runner.events = [] + + return runner + + +@pytest.mark.parametrize( + "status, rc, events, stdout, expected_tasks, expected_raw", + [ + # 1) Successful run with two events + ( + "successful", + 0, + [ + { + "event": "runner_on_ok", + "event_data": { + "task": "test task", "host": "host1", "res": { + "changed": True, "stdout": "hello" + } + }, + "stdout": "Task output", + }, + { + "event": "runner_on_failed", + "event_data": { + "task": "failed task", "host": "host1", "res": { + "failed": True, "msg": "error" + } + }, + "stdout": "Error output", + }, + ], + None, + # Build expected task_results from events + lambda evs: [{ + "task": ev["event_data"]["task"], + "host": ev["event_data"]["host"], + "status": ev["event"], + "stdout": ev["stdout"], + "result": ev["event_data"]["res"], } + for ev in evs if ev["event"] in ("runner_on_ok", "runner_on_failed")], + None, + ), + # 2) No events but stdout present + ("failed", 1, None, "Command failed output", lambda _: [], "Command failed output"), + # 3) No events and no stdout + ("failed", 1, None, None, lambda _: [], "No output captured."), + ], +) +async def test_run_ansible_playbook_various(status, rc, events, stdout, expected_tasks, expected_raw): + # Ansible parameters + playbook = [{"name": "test task", "command": "echo hello"}] + ansible_host = "test.example.com" + ansible_user = "testuser" + ansible_port = 22 + ansible_private_key_path = "/path/to/key.pem" + + runner = _mock_ansible_runner(status=status, rc=rc, events=events, stdout=stdout) + + # Patch ansible_runner.run + with patch("ansible_runner.run", return_value=runner) as mock_run: + result = await run_ansible_playbook(playbook, + ansible_host, + ansible_user, + ansible_port, + ansible_private_key_path) + + # Verify the call + mock_run.assert_called_once() + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs["playbook"] == playbook + inv = call_kwargs["inventory"]["all"]["hosts"]["host1"] + assert inv["ansible_host"] == ansible_host + assert inv["ansible_user"] == ansible_user + assert inv["ansible_ssh_private_key_file"] == ansible_private_key_path + assert inv["ansible_port"] == ansible_port + + # Verify returned dict + assert result["ansible_status"] == status + assert result["return_code"] == rc + assert result["task_results"] == expected_tasks(events or []) + if not events: + assert result["raw_output"] == expected_raw