From 5a72311abe10a255090bf4bae01a43f83d832c59 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 13 Mar 2024 23:49:48 -0500 Subject: [PATCH 1/2] Skip nv float8 type in test_dot --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4115bec88a96..f1f7c57d942d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2793,6 +2793,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o pytest.skip("Only test out_dtype=float16 on devices with sm >=80") if capability[0] < 9 and in_dtype == 'float8e4nv': pytest.skip("float8e4nv not supported on sm <= 80") + if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): + pytest.skip("float8e4nv and float8e5 not supported on HIP") if is_interpreter() and in_dtype == 'int8': pytest.skip( "numpy.dot with int8 inputs will overflow while tl.dot doesn't because MMA instruction's accumulator is 32-bit" From 130f5813395c0ac3661532bef887c46d77098ff7 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 14 Mar 2024 00:06:21 -0500 Subject: [PATCH 2/2] Make target a list for HIP backend --- third_party/amd/backend/compiler.py | 6 +++--- third_party/amd/backend/driver.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index d11b0883410f..eee4fd77721e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -82,12 +82,12 @@ def hash(self): class HIPBackend(BaseBackend): @staticmethod - def supports_target(target: tuple): + def supports_target(target: list): return target[0] == 'hip' - def __init__(self, target: tuple) -> None: + def __init__(self, target: list) -> None: super().__init__(target) - assert isinstance(target, tuple) and len(target) == 3 + assert isinstance(target, list) and len(target) == 3 assert isinstance(target[1], str) self.binary_ext = "hsaco" diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 391512a721ec..1247c91df4ed 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -272,4 +272,4 @@ def get_current_target(self): device_properties = self.utils.get_device_properties(device) arch = device_properties['arch'] warpSize = device_properties['warpSize'] - return ("hip", arch.split(':')[0], warpSize) + return ["hip", arch.split(':')[0], warpSize]