Skip to content
103 changes: 103 additions & 0 deletions .github/scripts/tests/therock_matrix_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from copy import deepcopy
from pathlib import Path
import os
import sys
import unittest
from unittest.mock import patch

sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
import therock_matrix

# Store original project_map to restore between tests
ORIGINAL_PROJECT_MAP = deepcopy(therock_matrix.project_map)


class TheRockMatrixTest(unittest.TestCase):
def setUp(self):
therock_matrix.project_map = deepcopy(ORIGINAL_PROJECT_MAP)

def test_collect_projects_to_run_without_additional_option(self):
subtrees = ["projects/hipblaslt"]

Expand All @@ -33,11 +41,106 @@ def test_collect_projects_to_run_dependency_graph(self):
self.assertEqual(len(project_to_run), 1)

def test_collect_projects_to_run_dependency_graph_diff_projects(self):
# miopen and rocwmma: rocwmma adds to blas, miopen combines with blas
# via dependency_graph, so we end up with 1 combined project
subtrees = ["projects/miopen", "projects/rocwmma"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)
self.assertEqual(len(project_to_run), 1)
# Verify rocwmma tests are included in the combined project
projects_to_test = project_to_run[0]["projects_to_test"].split(",")
self.assertIn("rocwmma", projects_to_test)
self.assertIn("miopen", projects_to_test)

def test_collect_projects_to_run_truly_separate_projects(self):
# prim and fft are truly separate projects with no dependency overlap
subtrees = ["projects/rocprim", "projects/hipfft"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)
self.assertEqual(len(project_to_run), 2)


class TheRockDynamicDepsTest(unittest.TestCase):
"""Tests for dynamic test dependency resolution from TheRock."""

def setUp(self):
therock_matrix.project_map = deepcopy(ORIGINAL_PROJECT_MAP)

@patch("therock_matrix.get_test_dependencies_from_therock")
def test_rocblas_uses_therock_deps(self, mock_get_deps):
"""When TheRock returns deps for rocblas, use those deps."""
mock_get_deps.return_value = ["rocblas", "hipblas", "rocsolver"]
subtrees = ["projects/rocblas"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)

self.assertEqual(len(project_to_run), 1)
projects_to_test = project_to_run[0]["projects_to_test"].split(",")
self.assertIn("rocblas", projects_to_test)
self.assertIn("hipblas", projects_to_test)
self.assertIn("rocsolver", projects_to_test)
# Should NOT include hipblaslt or rocroller (from fallback)
self.assertNotIn("hipblaslt", projects_to_test)
self.assertNotIn("rocroller", projects_to_test)

@patch("therock_matrix.get_test_dependencies_from_therock")
def test_hipblaslt_falls_back_when_therock_returns_empty(self, mock_get_deps):
"""When TheRock returns empty, fall back to project_map."""
mock_get_deps.return_value = None
subtrees = ["projects/hipblaslt"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)

self.assertEqual(len(project_to_run), 1)
projects_to_test = project_to_run[0]["projects_to_test"].split(",")
# Should include all from project_map["blas"]["projects_to_test"]
self.assertIn("hipblaslt", projects_to_test)
self.assertIn("rocblas", projects_to_test)
self.assertIn("hipblas", projects_to_test)
self.assertIn("rocroller", projects_to_test)

@patch("therock_matrix.get_test_dependencies_from_therock")
def test_hipblaslt_falls_back_when_therock_returns_only_self(self, mock_get_deps):
"""When TheRock returns only the component itself, fall back to project_map."""
mock_get_deps.return_value = ["hipblaslt"]
subtrees = ["projects/hipblaslt"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)

self.assertEqual(len(project_to_run), 1)
projects_to_test = project_to_run[0]["projects_to_test"].split(",")
# Should include all from project_map["blas"]["projects_to_test"]
self.assertIn("hipblaslt", projects_to_test)
self.assertIn("rocblas", projects_to_test)
self.assertIn("hipblas", projects_to_test)
self.assertIn("rocroller", projects_to_test)

@patch("therock_matrix.get_test_dependencies_from_therock")
def test_mixed_components_combines_deps(self, mock_get_deps):
"""When both rocblas and hipblaslt change, combine TheRock deps with fallback."""

def mock_deps(component_names):
component = component_names[0]
if component == "rocblas":
return ["rocblas", "hipblas", "rocsolver"]
else:
return None # hipblaslt has no TheRock deps

mock_get_deps.side_effect = mock_deps
subtrees = ["projects/rocblas", "projects/hipblaslt"]

project_to_run = therock_matrix.collect_projects_to_run(subtrees)

self.assertEqual(len(project_to_run), 1)
projects_to_test = project_to_run[0]["projects_to_test"].split(",")
# Should include rocblas deps from TheRock
self.assertIn("rocblas", projects_to_test)
self.assertIn("hipblas", projects_to_test)
self.assertIn("rocsolver", projects_to_test)
# Should also include fallback deps for hipblaslt
self.assertIn("hipblaslt", projects_to_test)
self.assertIn("rocroller", projects_to_test)


if __name__ == "__main__":
unittest.main()
87 changes: 85 additions & 2 deletions .github/scripts/therock_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,33 @@
"""

import os
import sys
from pathlib import Path

# TheRock is checked out at ./TheRock during CI
THEROCK_DIR = Path.cwd() / "TheRock"


def get_test_dependencies_from_therock(component_names):
"""
Get test dependencies using TheRock's determine_rocm_test_dependencies.py.
Returns None if TheRock or the script is not available.
"""
test_tools_path = THEROCK_DIR / "test_tools"
if not test_tools_path.exists():
return None

sys.path.insert(0, str(test_tools_path))
try:
from determine_rocm_test_dependencies import get_rocm_test_dependencies

return get_rocm_test_dependencies(component_names, THEROCK_DIR)
except ImportError:
return None
finally:
if str(test_tools_path) in sys.path:
sys.path.remove(str(test_tools_path))


subtree_to_project_map = {
"dnn-providers/fusilli-provider": "fusilli-provider",
Expand Down Expand Up @@ -153,6 +180,20 @@
}


def extract_component_names_from_subtrees(subtrees):
"""Extract component names from subtree paths.

E.g., 'projects/rocprim' -> 'rocprim', 'shared/tensile' -> 'tensile'
"""
components = []
for subtree in subtrees:
# Extract the last part of the path as the component name
parts = subtree.split("/")
if len(parts) >= 2:
components.append(parts[-1])
return components


def collect_projects_to_run(subtrees):
platform = os.getenv("PLATFORM")
projects = set()
Expand All @@ -161,6 +202,30 @@ def collect_projects_to_run(subtrees):
if subtree in subtree_to_project_map:
projects.add(subtree_to_project_map.get(subtree))

# For each component, get tests from TheRock's script.
# If TheRock returns meaningful deps (more than just the component itself),
# use them; otherwise fall back to project_map.
# Also track which components belong to which project.
component_names = extract_component_names_from_subtrees(subtrees)
tests_per_component = {}
components_needing_fallback = set()
component_to_project = {}

# Build mapping from component to project
for subtree in subtrees:
if subtree in subtree_to_project_map:
component = subtree.split("/")[-1] if "/" in subtree else subtree
project = subtree_to_project_map[subtree]
component_to_project[component] = project

for component in component_names:
deps = get_test_dependencies_from_therock([component])
# Only use TheRock deps if it returns more than just the component itself
if deps and set(deps) != {component}:
tests_per_component[component] = deps
else:
components_needing_fallback.add(component)

for project in list(projects):
# Check if an optional math component was included.
if project in additional_options:
Expand Down Expand Up @@ -227,9 +292,27 @@ def collect_projects_to_run(subtrees):
project_map_data["cmake_options"] = list(
set(project_map_data["cmake_options"])
)
project_map_data["projects_to_test"] = list(
set(project_map_data["projects_to_test"])

# Collect tests: use TheRock deps for components that have them,
# fall back to hardcoded for components that don't.
# Only include test deps for components that belong to this project.
tests_to_run = set()

# Add test deps only for components that belong to this project
for component, deps in tests_per_component.items():
if component_to_project.get(component) == project:
tests_to_run.update(deps)

# For components needing fallback that belong to this project,
# use hardcoded projects_to_test
project_needs_fallback = any(
component_to_project.get(c) == project
for c in components_needing_fallback
)
if project_needs_fallback:
tests_to_run.update(project_map_data["projects_to_test"])

project_map_data["projects_to_test"] = list(tests_to_run)

cmake_flag_options = " ".join(project_map_data["cmake_options"])
projects_to_test_options = ",".join(project_map_data["projects_to_test"])
Expand Down
Loading