diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d096a81150a0..9e5ff8a2ce37 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3082,7 +3082,7 @@ def convert_fp8_to_fp32(x, device, dtype_str): [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for input_precision in ["ieee" if is_hip() else "tf32"] + for input_precision in ["tf32" if is_cuda() else "ieee"] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', @@ -3338,7 +3338,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: - input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": if triton.runtime.driver.active.utils.get_device_properties( @@ -5465,7 +5465,7 @@ def maxnreg_noinline2(X): def test_maxnreg(device): assert not is_interpreter(), "this test won't work with the interpreter" - if is_hip(): + if not is_cuda(): pytest.skip('maxnreg only works on CUDA') # triton kernel