diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py
index 7062093aef6d..23065953d65b 100644
--- a/python/test/unit/hopper/test_experimental_tma.py
+++ b/python/test/unit/hopper/test_experimental_tma.py
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl
 from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor)
-from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma
+from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg
 
 from typing import Optional
 
@@ -29,9 +29,11 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper):
 tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
 
 
-@requires_tma
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimetal_descriptor_load(byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     SIZE = 128
 
@@ -82,11 +84,13 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
 
 
-@requires_tma
 @pytest.mark.parametrize("num_stages", [1, 4])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     M, N, K = 8192, 8192, 1024
     torch.manual_seed(42)
diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py
index fa5df4f865d5..5ba0be1e34f9 100644
--- a/python/triton/_internal_testing.py
+++ b/python/triton/_internal_testing.py
@@ -4,6 +4,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton.backends.nvidia.compiler import _path_to_binary
 import pytest
 
 from numpy.random import RandomState
@@ -140,8 +141,19 @@ def to_numpy(x):
     raise ValueError(f"Not a triton-compatible tensor: {x}")
 
 
-def supports_tma():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+def supports_tma(byval_only=False):
+    _, cuda_version = _path_to_binary("ptxas")
+    min_cuda_version = (12, 0) if byval_only else (12, 3)
+    cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+    assert len(cuda_version_tuple) == 2, cuda_version_tuple
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+def tma_skip_msg(byval_only=False):
+    if byval_only:
+        return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+    else:
+        return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"
 
 
-requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)")
+requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())