Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3082,7 +3082,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
[(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack)
for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4],
[32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]]
for input_precision in ["ieee" if is_hip() else "tf32"]
for input_precision in ["tf32" if is_cuda() else "ieee"]
for col_a in [True, False]
for col_b in [True, False]
for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32',
Expand Down Expand Up @@ -3338,7 +3338,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
if out_dtype_str == "float16":
pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
else:
input_precision = "tf32" if in_dtype_str == 'float32' else "ieee"
input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee"

if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
if triton.runtime.driver.active.utils.get_device_properties(
Expand Down Expand Up @@ -5465,7 +5465,7 @@ def maxnreg_noinline2(X):

def test_maxnreg(device):
assert not is_interpreter(), "this test won't work with the interpreter"
if is_hip():
if not is_cuda():
pytest.skip('maxnreg only works on CUDA')

# triton kernel
Expand Down