Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
Expand All @@ -35,6 +37,13 @@ class CategorizerToolConfig(FunctionBaseConfig, name="categorizer"):
llm_name: LLMRef


def _extract_markdown_heading_level(report: str) -> str:
""" Extract the markdown heading level from first line (report title)."""
m = re.search(r'^(#+)', report, re.MULTILINE)
pound_signs = m.group(1) if m else "#"
return pound_signs


@register_function(config_type=CategorizerToolConfig)
async def categorizer_tool(config: CategorizerToolConfig, builder: Builder):
# Set up LLM and chain
Expand All @@ -49,8 +58,8 @@ async def _arun(report: str) -> str:

result = await categorization_chain.ainvoke({"msgs": [HumanMessage(content=report)]})

# Extract the markdown heading level from first line of report (e.g. '#' or '##')
pound_signs = report.split('\n')[0].split(' ')[0]
# Extract the title's heading level and add an additional '#' for the section heading
pound_signs = _extract_markdown_heading_level(report) + "#"

# Format the root cause category section:
# - Add newlines before and after section
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@
from .prompts import ToolReasoningLayerPrompts


class HardwareCheckToolConfig(FunctionBaseConfig, name="hardware_check"):
description: str = Field(
default=("This tool checks hardware health status using IPMI monitoring to detect power state, "
"hardware degradation, and anomalies that could explain alerts. Args: host_id: str"),
description="Description of the tool for the agent.")
llm_name: LLMRef
test_mode: bool = Field(default=True, description="Whether to run in test mode")


def _get_ipmi_monitor_data(ip_address, username, password):
"""
Capture IPMI monitoring data using the ipmimonitoring command.
Expand Down Expand Up @@ -58,15 +67,6 @@ def _get_ipmi_monitor_data(ip_address, username, password):
return None


class HardwareCheckToolConfig(FunctionBaseConfig, name="hardware_check"):
description: str = Field(
default=("This tool checks hardware health status using IPMI monitoring to detect power state, "
"hardware degradation, and anomalies that could explain alerts. Args: host_id: str"),
description="Description of the tool for the agent.")
llm_name: LLMRef
test_mode: bool = Field(default=True, description="Whether to run in test mode")


@register_function(config_type=HardwareCheckToolConfig)
async def hardware_check_tool(config: HardwareCheckToolConfig, builder: Builder):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,18 @@ async def _parse_stdout_lines(config, builder, stdout_lines):
Returns:
str: Structured data parsed from the output in string format.
"""
# Join the list of lines into a single text block
input_data = "\n".join(stdout_lines) if stdout_lines else ""

prompt = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format(input_data=input_data)

response = None
try:
response = await utils.llm_ainvoke(config, builder, user_prompt=prompt)
structured_data = response
# Join the list of lines into a single text block
input_data = "\n".join(stdout_lines) if stdout_lines else ""

prompt = ToolReasoningLayerPrompts.HOST_PERFORMANCE_CHECK_PARSING.format(input_data=input_data)

response = await utils.llm_ainvoke(config=config, builder=builder, user_prompt=prompt)
except Exception as e:
structured_data = ('{{"error": "Failed to parse nvda_nim response", '
'"exception": "{}", "raw_response": "{}"}}').format(str(e), response)
return structured_data
response = ('{{"error": "Failed to parse stdout from the playbook run.", '
'"exception": "{}", "raw_response": "{}"}}').format(str(e), response)
return response


@register_function(config_type=HostPerformanceCheckToolConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _load_maintenance_data(path: str) -> pd.DataFrame:
required = {"host_id", "maintenance_start", "maintenance_end"}
missing = required - set(df.columns)
if missing:
missing = sorted(missing)
utils.logger.error("Missing required columns: %s", ", ".join(missing))
raise ValueError(f"Missing required columns: {', '.join(missing)}")

Expand All @@ -96,25 +97,27 @@ def _parse_alert_data(input_message: str) -> dict | None:
"""
Parse alert data from an input message containing JSON into a dictionary.

Note: This function assumes the input message contains exactly one JSON object (the alert)
with potential extra text before and/or after the JSON.
This function extracts and parses a JSON object from a text message that may contain
additional text before and/or after the JSON. It handles both double and single quoted
JSON strings and can parse nested JSON structures.

Args:
input_message (str): Input message containing JSON alert data
input_message (str): Input message containing a JSON object, which may be surrounded
by additional text. The JSON object should contain alert details
like host_id and timestamp.

Returns:
dict | None: The parsed alert data as a dictionary containing alert details,
or None if parsing fails

Raises:
ValueError: If no JSON object is found in the input message
json.JSONDecodeError: If the JSON parsing fails
dict | None: The parsed alert data as a dictionary if successful parsing,
containing fields like host_id and timestamp.
Returns None if no valid JSON object is found or parsing fails.
"""
# Extract everything between first { and last }
start = input_message.find("{")
end = input_message.rfind("}") + 1
if start == -1 or end == 0:
raise ValueError("No JSON object found in input message")
utils.logger.error("No JSON object found in input message")
return None

alert_json_str = input_message[start:end]
try:
return json.loads(alert_json_str.replace("'", '"'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
from .prompts import ToolReasoningLayerPrompts


class NetworkConnectivityCheckToolConfig(FunctionBaseConfig, name="network_connectivity_check"):
description: str = Field(
default=("This tool checks network connectivity of a host by running ping and socket connection tests. "
"Args: host_id: str"),
description="Description of the tool for the agent.")
llm_name: LLMRef
test_mode: bool = Field(default=True, description="Whether to run in test mode")


def _check_service_banner(host: str, port: int = 80, connect_timeout: float = 10, read_timeout: float = 10) -> str:
"""
Connects to host:port, reads until the Telnet banner (‘Escape character is '^]'.’) or times out.
Expand Down Expand Up @@ -56,15 +65,6 @@ def _check_service_banner(host: str, port: int = 80, connect_timeout: float = 10
return ''


class NetworkConnectivityCheckToolConfig(FunctionBaseConfig, name="network_connectivity_check"):
description: str = Field(
default=("This tool checks network connectivity of a host by running ping and socket connection tests. "
"Args: host_id: str"),
description="Description of the tool for the agent.")
llm_name: LLMRef
test_mode: bool = Field(default=True, description="Whether to run in test mode")


@register_function(config_type=NetworkConnectivityCheckToolConfig)
async def network_connectivity_check_tool(config: NetworkConnectivityCheckToolConfig, builder: Builder):

Expand Down
14 changes: 10 additions & 4 deletions examples/alert_triage_agent/src/aiq_alert_triage_agent/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,25 @@ def receive_alert():
HTTP endpoint to receive a JSON alert via POST.
Expects application/json with a single alert dict or a list of alerts.
"""
# use the globals-set ENV_FILE
if ENV_FILE is None:
raise ValueError("ENV_FILE must be set before processing alerts")

try:
data = request.get_json(force=True)
except Exception:
return jsonify({"error": "Invalid JSON"}), 400

alerts = data if isinstance(data, list) else [data]
if not all(isinstance(alert, dict) for alert in alerts):
return jsonify({"error": "Alerts not represented as dictionaries"}), 400

for alert in alerts:
alert_id = alert.get('alert_id')
if 'alert_id' not in alert:
return jsonify({"error": "`alert_id` is absent in the alert payload"}), 400

alert_id = alert['alert_id']
processed_alerts.append(alert_id)
# use the globals-set ENV_FILE
if ENV_FILE is None:
raise ValueError("ENV_FILE must be set before processing alerts")
start_process(alert, ENV_FILE)

return jsonify({"received_alert_count": len(alerts), "total_launched": len(processed_alerts)}), 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def _timeseries_stats(ts):
Returns:
str: Markdown formatted string containing summary statistics
"""
if len(ts) == 0:
return "No data points"

count = len(ts)
max_val = max(ts)
min_val = min(ts)
Expand Down Expand Up @@ -89,7 +92,11 @@ def _get_llm_analysis_input(timestamp_value_list):
str: Formatted string containing:
- JSON array of [datetime_str, value] pairs with human readable timestamps
- Summary statistics of the metric values
- "No data points" if input list is empty
"""
if len(timestamp_value_list) == 0:
return "No data points"

# Convert Unix timestamps to ISO format datetime strings and preserve values
# Example: "2022-01-17 12:00:00" for timestamp 1642435200
data = [[datetime.fromtimestamp(entry[0]).strftime("%Y-%m-%d %H:%M:%S"), entry[1]]
Expand Down Expand Up @@ -120,7 +127,7 @@ async def _arun(host_id: str) -> str:

# Customize query based on your monitoring setup and metrics
# This example queries the CPU usage percentage by subtracting idle CPU from 100%
query = '(100 - cpu_usage_idle{cpu="cpu-total",instance=~"{host_id}:9100"})'
query = f'(100 - cpu_usage_idle{{cpu="cpu-total",instance=~"{host_id}:9100"}})'
url = f"{monitoring_url}/api/query_range"

# Example values - users should customize these based on their monitoring requirements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def preload_test_data(test_data_path: str | None, benign_fallback_data_path: str
def get_test_data() -> pd.DataFrame:
"""Returns the preloaded test data."""
if _DATA_CACHE['test_data'] is None:
raise ValueError("Test data not preloaded. Call preload_test_data() first.")
raise ValueError("Test data not preloaded. Call `preload_test_data` first.")
return pd.DataFrame(_DATA_CACHE['test_data'])


def _get_static_data():
"""Returns the preloaded benign fallback test data."""
if _DATA_CACHE['benign_fallback_test_data'] is None:
raise ValueError("Benign fallback test data not preloaded. Call preload_test_data() first.")
raise ValueError("Benign fallback test data not preloaded. Call `preload_test_data` first.")
return _DATA_CACHE['benign_fallback_test_data']


Expand Down Expand Up @@ -147,7 +147,7 @@ def load_column_or_static(df, host_id, column):
try:
return static_data[column]
except KeyError as exc:
raise KeyError(f"Column '{column}' not found in static data") from exc
raise KeyError(f"Column '{column}' not found in test and benign fallback data") from exc
# Column exists in DataFrame, get value for this host
# Assumption: In test dataset, host_ids are unique and used to locate specific tool return values
# If multiple rows found for a host_id, this indicates data inconsistency
Expand Down
34 changes: 34 additions & 0 deletions examples/alert_triage_agent/tests/test_categorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from aiq_alert_triage_agent.categorizer import _extract_markdown_heading_level


@pytest.mark.parametrize(
"test_input,expected",
[
pytest.param("# Title", "#", id="single_hash"),
pytest.param("### Title", "###", id="multiple_hashes"),
pytest.param("No heading", "#", id="no_heading_default"),
pytest.param("", "#", id="empty_string"),
pytest.param("## My Title\n### Heading", "##", id="first_of_many"),
pytest.param("Here is a title\n## Title Line", "##", id="first_after_text"),
pytest.param("## Heading first\n# Title", "##", id="heading_precedence"),
pytest.param("###No space between # and title", "###", id="no_space_after_hashes"),
],
)
def test_extract_markdown_heading_level(test_input, expected):
assert _extract_markdown_heading_level(test_input) == expected
79 changes: 79 additions & 0 deletions examples/alert_triage_agent/tests/test_hardware_check_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from aiq_alert_triage_agent.hardware_check_tool import _get_ipmi_monitor_data


# Fixtures for inputs and expected command
@pytest.fixture
def ipmi_args():
return "1.1.1.1", "test_user", "test_pass"


@pytest.fixture
def expected_cmd(ipmi_args):
ip, user, pwd = ipmi_args
return [
"ipmimonitoring",
"-h",
ip,
"-u",
user,
"-p",
pwd,
"--privilege-level=USER",
]


# Fixture to mock subprocess.run
@pytest.fixture
def mock_run():
with patch('subprocess.run') as m:
yield m


# Parameterized test covering both success and failure
@pytest.mark.parametrize(
"stdout, side_effect, expected",
[
# success case: subprocess returns stdout
pytest.param("Sample IPMI output", None, "Sample IPMI output", id="success"),
# failure case: subprocess raises CalledProcessError
pytest.param(
"unused output",
subprocess.CalledProcessError(returncode=1, cmd=["ipmimonitoring"], stderr="Command failed"),
None, # expected None when ipmimonitoring command raises error
id="failure"),
])
def test_get_ipmi_monitor_data(mock_run, ipmi_args, expected_cmd, stdout, side_effect, expected):
# configure mock
if side_effect:
mock_run.side_effect = side_effect
else:
mock_result = MagicMock()
mock_result.stdout = stdout
mock_run.return_value = mock_result

# invoke
result = _get_ipmi_monitor_data(*ipmi_args)

# assertions
assert result == expected
mock_run.assert_called_once_with(expected_cmd, capture_output=True, text=True, check=True)
Loading