diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index f453e6bdbb79..41f47bff8121 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3381,39 +3381,6 @@ def kernel(
     np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2)
 
 
-@pytest.mark.interpreter
-def test_max_num_imprecise_acc(device):
-
-    if not hasattr(torch, 'float8_e5m2'):
-        pytest.skip(f"torch {torch.__version__} does not support float8_e5m2")
-
-    if is_cuda():
-        capability = torch.cuda.get_device_capability()
-        if capability != (9, 0):
-            return
-
-    @triton.jit
-    def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-               MAX_NUM_IMPRECISE_ACC: tl.constexpr):
-        off_m = tl.arange(0, BLOCK_M)
-        off_n = tl.arange(0, BLOCK_N)
-        off_k = tl.arange(0, BLOCK_K)
-        x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :])
-        y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :])
-        z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :])
-        z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC)
-        tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z)
-
-    M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64
-    x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device)
-    y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device)
-    z = torch.zeros((M, N), dtype=torch.float32, device=device)
-    h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps)
-    if not is_cuda():
-        return
-    assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC)
-
-
 @pytest.mark.parametrize('in_dtype', ['float32'])
 def test_dot_mulbroadcasted(in_dtype, device):
     if is_cuda():
@@ -3698,7 +3665,7 @@ def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.co
     torch.testing.assert_close(output, reference_out)
 
 
-# Testing masked loads with an intermate copy to shared memory run.
+# Testing masked loads with a copy to shared memory.
 # FIXME: Shape too small for ldmatrix when num_ctas=4
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -5145,27 +5112,29 @@ def matmul_kernel(
 
 
 # @pytest.mark.interpreter
+@pytest.mark.parametrize("M, N, K", [(128, 256, 256)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)])
 @pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15'])
 @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
-def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
-    if is_hip():
-        pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.')
+def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device):
     if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc[0] >= 9 and in_type_str == "float8e4b15":
             pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90")
+    elif is_hip():
+        if in_type_str != 'float8e5':
+            pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.')
+
     check_type_supported(in_type_str, device)
-    M, N, K = 128, 256, 256
-    BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128
     A = numpy_random((M, K), dtype_str=in_type_str)
     B = numpy_random((K, N), dtype_str=in_type_str)
     C = torch.empty((M, N), dtype=torch.float32, device=device)
     num_warps = 8
     a = to_triton(A, device=device, dst_type=in_type_str)
     b = to_triton(B, device=device, dst_type=in_type_str)
-    grid = (triton.cdiv(M, BLOCK_M), 1)
-    matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1),
-                        BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps)
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
+                            C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps)
     torch_a = torch.from_numpy(A).to(device=device)
     th_a = f8_to_f16(torch_a, in_type_str)
    torch_b = torch.from_numpy(B).to(device=device)
@@ -5173,10 +5142,10 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
     ref_out = torch.matmul(th_a, th_b).to(torch.float32)
     if in_type_str == 'float8e4nv':
         torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01)
-    elif low_precision_acc > 32:
-        torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     else:
-        torch.testing.assert_close(ref_out, C)
+        torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
+    if is_cuda() and low_precision_acc > 0 and torch.cuda.get_device_capability()[0] >= 9:
+        assert h.asm["ptx"].count("add.f32") == (BLOCK_M * BLOCK_N) // (32 * num_warps) * (BLOCK_K // low_precision_acc)
 
 
 # -----------------------
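Note (not part of the patch): the new add.f32 assertion encodes how the max_num_imprecise_acc argument splits the fp8 dot accumulation. The sketch below restates that accounting in plain Python, assuming one PTX add.f32 per fp32-accumulator update per thread-owned output element; the helper name expected_add_f32_count is illustrative and does not exist in the test suite.

def expected_add_f32_count(block_m, block_n, block_k, num_warps, low_precision_acc):
    # Each of the 32 * num_warps threads owns (block_m * block_n) // (32 * num_warps)
    # output elements, and the fp32 accumulator is updated once every
    # low_precision_acc slices along K, i.e. block_k // low_precision_acc times.
    elements_per_thread = (block_m * block_n) // (32 * num_warps)
    fp32_updates_per_element = block_k // low_precision_acc
    return elements_per_thread * fp32_updates_per_element


# Example for the new (64, 64, 64) block shape with num_warps = 8 and
# low_precision_acc = 32: (64 * 64) // 256 * (64 // 32) = 16 * 2 = 32.
assert expected_add_f32_count(64, 64, 64, 8, 32) == 32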