From 958d4f9ddc786fa7f23ff3fe52fb5c6f0db078ec Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Thu, 21 Nov 2024 12:18:27 -0800 Subject: [PATCH 1/2] Gate TMA tests on cuda toolchain version --- python/test/unit/hopper/test_experimental_tma.py | 10 +++++++--- python/triton/_internal_testing.py | 16 +++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 7062093aef6d..6015a9e35808 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -4,7 +4,7 @@ import triton import triton.language as tl from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor) -from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma +from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg from typing import Optional @@ -29,9 +29,11 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper): tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}) -@requires_tma @pytest.mark.parametrize("byval_tma", [True, False]) def test_experimetal_descriptor_load(byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + return device = "cuda" SIZE = 128 @@ -82,11 +84,13 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) -@requires_tma @pytest.mark.parametrize("num_stages", [1, 4]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) @pytest.mark.parametrize("byval_tma", [True, False]) def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + return device = "cuda" M, N, K = 8192, 8192, 1024 torch.manual_seed(42) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index fa5df4f865d5..9ebda3784108 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -4,6 +4,7 @@ import torch import triton import triton.language as tl +from triton.backends.nvidia.compiler import _path_to_binary import pytest from numpy.random import RandomState @@ -140,8 +141,17 @@ def to_numpy(x): raise ValueError(f"Not a triton-compatible tensor: {x}") -def supports_tma(): - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 +def supports_tma(byval_only=False): + _, cuda_version = _path_to_binary("ptxas") + min_cuda_version = 12.0 if byval_only else 12.3 + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and float(cuda_version) >= min_cuda_version + + +def tma_skip_msg(byval_only=False): + if byval_only: + return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)" + else: + return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)" -requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)") +requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg()) From 971eadbf2b3ac3ea9933999785ad7ce69bf69f28 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Thu, 21 Nov 2024 15:58:31 -0800 Subject: [PATCH 2/2] nits --- python/test/unit/hopper/test_experimental_tma.py | 4 ++-- python/triton/_internal_testing.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 6015a9e35808..23065953d65b 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -33,7 +33,7 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper): def test_experimetal_descriptor_load(byval_tma): if not supports_tma(byval_tma): pytest.skip(tma_skip_msg(byval_tma)) - return + device = "cuda" SIZE = 128 @@ -90,7 +90,7 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): if not supports_tma(byval_tma): pytest.skip(tma_skip_msg(byval_tma)) - return + device = "cuda" M, N, K = 8192, 8192, 1024 torch.manual_seed(42) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 9ebda3784108..5ba0be1e34f9 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -143,8 +143,10 @@ def to_numpy(x): def supports_tma(byval_only=False): _, cuda_version = _path_to_binary("ptxas") - min_cuda_version = 12.0 if byval_only else 12.3 - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and float(cuda_version) >= min_cuda_version + min_cuda_version = (12, 0) if byval_only else (12, 3) + cuda_version_tuple = tuple(map(int, cuda_version.split("."))) + assert len(cuda_version_tuple) == 2, cuda_version_tuple + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version def tma_skip_msg(byval_only=False):