Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion projects/hipblaslt/tensilelite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ While full test suites can be run with a single `tox` command, developers may wi
build the hipBLASLt tensilelite client executable (`tensilelite-client`) and run individual tests separately.
This is useful for debugging specific problems or isolating issues in a specific test.

### Run Full Test Suite with Tox
### Run Test Suite with Tox

The standard workflow for running the entire test suite is to use `tox`. This command will build
`tensilelite-client` and execute all tests.
Expand All @@ -16,6 +16,12 @@ cd rocm-libraries/projects/hipblaslt/tensilelite
tox -e py3 -- Tensile/Tests -m common
```

Subsequently, you can run just the Tensile unit tests via:

```
tox -e unit -- Tensile/Tests/unit
```

### Build client with invoke and Run a Test (Default Path)

This workflow uses `invoke` to build the client into the default `build_tmp` directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def test_filterLogicFilesByPredicates_no_match(mock_logic_file):
result = filterLogicFilesByPredicates(logicFiles, predicateMap)
assert len(result) == 0

@pytest.mark.xfail
def test_filterLogicFilesByPredicates_match_emulation_ids(mock_logic_file):
logicFiles = ["file1.yaml"]
predicateMap = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import pytest
from unittest.mock import MagicMock

from Tensile.Components.CustomSchedule import hasCustomSchedule, ScheduleInfo
from Tensile.Common import IsaVersion

# Helper to create a mock data type
def _mock_dtype(is_16bit=False, is_8bit=False, num_bytes=4):
mock = MagicMock()
mock.isHalf.return_value = is_16bit
mock.isBFloat16.return_value = False # Assuming isHalf is enough for is16bit
mock.isInt8.return_value = is_8bit
mock.is8bitFloat.return_value = False # Assuming isInt8 is enough for is8bit
mock.numBytes.return_value = num_bytes
return mock

# Base kernel configuration factory
def create_base_kernel():
kernel = {
"UseCustomMainLoopSchedule": True,
"EnableMatrixInstruction": True,
"ISA": IsaVersion(9,5,0),
"ProblemType": {
"DataType": _mock_dtype(),
"DataTypeA": _mock_dtype(),
"DataTypeB": _mock_dtype(),
"TransposeA": False,
"TransposeB": False,
},
"MacroTile0": 0, "MacroTile1": 0, "DepthU": 0,
"PrefetchGlobalRead": 0, "PrefetchLocalRead": 0, "DirectToLds": False,
"GlobalReadVectorWidthA": 0, "GlobalReadVectorWidthB": 0,
"LocalReadVectorWidth": 0,
"MatrixInstruction": [],
"MIWaveGroup": [],
"LDSTrInst": False,
"TransposeLDS": 0,
"ForceUnrollSubIter": False,
"SwapGlobalReadOrder": False, # For asserting it gets set
"UsePLRPack": False, # For asserting it gets set
}
return kernel

class TestCustomSchedule:
def test_no_custom_schedule(self):
"""Test that a kernel that doesn't match any condition returns False."""
kernel = create_base_kernel()
# An empty kernel should not have a custom schedule
has_schedule, schedule_info = hasCustomSchedule(kernel)
assert not has_schedule
assert schedule_info is None

def test_schedule_256x256x64_16bit_TN(self):
"""Tests the 256x256x64 16-bit TN schedule."""
kernel = create_base_kernel()
dtype_16bit = _mock_dtype(is_16bit=True, num_bytes=2)
kernel["ProblemType"].update({
"DataType": dtype_16bit, "DataTypeA": dtype_16bit, "DataTypeB": dtype_16bit,
"TransposeA": True, "TransposeB": False
})
kernel.update({
"MacroTile0": 256, "MacroTile1": 256, "DepthU": 64,
"PrefetchGlobalRead": 2, "PrefetchLocalRead": 1, "DirectToLds": True,
"GlobalReadVectorWidthA": 8, "GlobalReadVectorWidthB": 8, "LocalReadVectorWidth": 8,
"MatrixInstruction": [16,16,32,1], "MIWaveGroup": [2,2], "TransposeLDS": 1
})

has_schedule, schedule_info = hasCustomSchedule(kernel)

assert has_schedule
assert isinstance(schedule_info, ScheduleInfo)
assert schedule_info.numCodePaths == 2
assert schedule_info.numMfma == 128
assert 'PackA0' not in schedule_info.optSchedule
assert not kernel["UsePLRPack"]

def test_schedule_256x256x64_16bit_NT(self):
"""Tests the 256x256x64 16-bit NT schedule."""
kernel = create_base_kernel()
dtype_16bit = _mock_dtype(is_16bit=True, num_bytes=2)
kernel["ProblemType"].update({
"DataType": dtype_16bit, "DataTypeA": dtype_16bit, "DataTypeB": dtype_16bit,
"TransposeA": False, "TransposeB": True
})
kernel.update({
"MacroTile0": 256, "MacroTile1": 256, "DepthU": 64,
"PrefetchGlobalRead": 2, "PrefetchLocalRead": 1, "DirectToLds": True,
"GlobalReadVectorWidthA": 8, "GlobalReadVectorWidthB": 8, "LocalReadVectorWidth": 8,
"MatrixInstruction": [16,16,32,1], "MIWaveGroup": [2,2],
"LDSTrInst": False, "TransposeLDS": 0
})

has_schedule, schedule_info = hasCustomSchedule(kernel)

assert has_schedule
assert isinstance(schedule_info, ScheduleInfo)
assert schedule_info.numCodePaths == 2
assert schedule_info.numMfma == 128
assert 'PackA0' in schedule_info.optSchedule
assert kernel["UsePLRPack"]

@pytest.mark.parametrize("transA, transB", [(False, False), (True, True)])
def test_schedule_256x256x64_16bit_NN_TT(self, transA, transB):
"""Tests the 256x256x64 16-bit NN and TT schedules."""
kernel = create_base_kernel()
dtype_16bit = _mock_dtype(is_16bit=True, num_bytes=2)
kernel["ProblemType"].update({
"DataType": dtype_16bit, "DataTypeA": dtype_16bit, "DataTypeB": dtype_16bit,
"TransposeA": transA, "TransposeB": transB
})
kernel.update({
"MacroTile0": 256, "MacroTile1": 256, "DepthU": 64,
"PrefetchGlobalRead": 2, "PrefetchLocalRead": 1, "DirectToLds": True,
"GlobalReadVectorWidthA": 8, "GlobalReadVectorWidthB": 8, "LocalReadVectorWidth": 8,
"MatrixInstruction": [16,16,32,1], "MIWaveGroup": [2,2],
"LDSTrInst": False, "TransposeLDS": 1
})

has_schedule, schedule_info = hasCustomSchedule(kernel)

assert has_schedule
assert isinstance(schedule_info, ScheduleInfo)
assert schedule_info.numCodePaths == 2
assert schedule_info.numMfma == 128
assert kernel["UsePLRPack"]
if transA and transB: # isTT
assert kernel["SwapGlobalReadOrder"]
assert 'PackB0' in schedule_info.optSchedule
assert 'PackA0' not in schedule_info.optSchedule
else: # isNN
assert not kernel["SwapGlobalReadOrder"]
assert 'PackA0' in schedule_info.optSchedule
assert 'PackB0' not in schedule_info.optSchedule

def test_schedule_256x256x128_8bit_TN(self):
"""Tests the 256x256x128 8-bit TN schedule."""
kernel = create_base_kernel()
dtype_8bit = _mock_dtype(is_8bit=True, num_bytes=1)
kernel["ProblemType"].update({
"DataType": dtype_8bit, "DataTypeA": dtype_8bit, "DataTypeB": dtype_8bit,
"TransposeA": True, "TransposeB": False
})
kernel.update({
"MacroTile0": 256, "MacroTile1": 256, "DepthU": 128,
"PrefetchGlobalRead": 2, "PrefetchLocalRead": 0, "DirectToLds": True,
"GlobalReadVectorWidthA": 16, "GlobalReadVectorWidthB": 16, "LocalReadVectorWidth": 16,
"MatrixInstruction": [16,16,128,1], "MIWaveGroup": [2,2], "TransposeLDS": 1
})

has_schedule, schedule_info = hasCustomSchedule(kernel)

assert has_schedule
assert isinstance(schedule_info, ScheduleInfo)
assert schedule_info.numCodePaths == 1
assert schedule_info.numMfma == 64
assert len(schedule_info.mfmaReorder) > 0

def test_schedule_192x256x64_16bit_NN(self):
"""Tests the 192x256x64 16-bit NN schedule."""
kernel = create_base_kernel()
dtype_16bit = _mock_dtype(is_16bit=True, num_bytes=2)
kernel["ProblemType"].update({
"DataType": dtype_16bit, "DataTypeA": dtype_16bit, "DataTypeB": dtype_16bit,
"TransposeA": False, "TransposeB": False
})
kernel.update({
"MacroTile0": 192, "MacroTile1": 256, "DepthU": 64,
"PrefetchGlobalRead": 2, "PrefetchLocalRead": 1, "DirectToLds": True,
"GlobalReadVectorWidthA": 8, "GlobalReadVectorWidthB": 8, "LocalReadVectorWidth": 8,
"MatrixInstruction": [16,16,32,1], "MIWaveGroup": [2,2],
"LDSTrInst": True, "TransposeLDS": 1
})

has_schedule, schedule_info = hasCustomSchedule(kernel)

assert has_schedule
assert isinstance(schedule_info, ScheduleInfo)
assert schedule_info.numCodePaths == 2
assert schedule_info.numMfma == 96
assert kernel["SwapGlobalReadOrder"]
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def test_convert_9_item_custom_kernel_config():
assert outputConf["MIInputPerThreadB"] == 5
assert outputConf["MIInputPerThreadMetadata"] == 5
assert outputConf["ThreadTile"] == [1, 1]
assert outputConf["Sparse"] == 0
assert outputConf["WorkGroup"] == [128, 3, 1]
assert outputConf["WavefrontSize"] == 48
assert outputConf["ISA"] == isa
Expand Down Expand Up @@ -201,7 +200,6 @@ def testConvert9ItemCustomKernelConfig():
assert outputConf["MIInputPerThreadB"] == 5
assert outputConf["MIInputPerThreadMetadata"] == 5
assert outputConf["ThreadTile"] == [1, 1]
assert outputConf["Sparse"] == 0
assert outputConf["WorkGroup"] == [1280, 2, 6] # Why do we change the workgroup here?
assert outputConf["WavefrontSize"] == 48
assert outputConf["ISA"] == isa
Expand Down
15 changes: 12 additions & 3 deletions projects/hipblaslt/tensilelite/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,26 @@ deps =
invoke
setenv =
TENSILE_CLIENT_STATIC = {env:TENSILE_CLIENT_STATIC:}
PYTHONPATH = {envdir}/build_tmp/tensilelite/rocisa/lib
PYTHONPATH = {toxinidir}/build_tmp/tensilelite/rocisa/lib
TENSILELITE_CLIENT_ARGS = {env:TENSILELITE_CLIENT_ARGS:}
commands =
pip install --upgrade pip
pip install pytest-cov
invoke build-client --build-dir {envdir}/build_tmp {env:TENSILELITE_CLIENT_ARGS}
pytest -v --basetemp={envtmpdir} --junit-xml={toxinidir}/python_tests.xml --junit-prefix={envname} --color=yes -n 4 --prebuilt-client={envdir}/build_tmp/tensilelite/client/tensilelite-client {posargs}
invoke build-client --build-dir {toxinidir}/build_tmp {env:TENSILELITE_CLIENT_ARGS}
pytest -v --basetemp={envtmpdir} --junit-xml={toxinidir}/python_tests.xml --junit-prefix={envname} --color=yes -n 4 --prebuilt-client={toxinidir}/build_tmp/tensilelite/client/tensilelite-client {posargs}
allowlist_externals =
mkdir
sh
cmake

[testenv:unit]
description = "Runs Python unit tests quickly, skipping the client build. Assumes a build has run before."
basepython = python3
# This environment inherits 'deps' and 'setenv' from [testenv]
commands =
pytest -v --basetemp={envtmpdir} {posargs}


[testenv:lint]
basepython = python3
deps =
Expand Down
Loading