diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 029b726e67e6..ce453ad3816b 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy +from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from typing import Optional from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3 @@ -1474,6 +1474,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl. @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) @pytest.mark.parametrize("y", [0, 32, 48]) @pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") +@pytest.mark.skipif(is_sm12x(), reason="TMA Scatter is not supported on sm120") def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device): if BLOCK_X > X or y + BLOCK_Y > Y: pytest.skip() diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index ba8621812022..562e8df721ea 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -54,6 +54,10 @@ def is_hopper(): return is_cuda() and torch.cuda.get_device_capability()[0] == 9 +def is_sm12x(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 12 + + def is_hip(): target = get_current_target() return False if target is None else target.backend == "hip" diff --git a/test/Conversion/tma_to_llvm.mlir b/test/Conversion/tma_to_llvm.mlir index 1f43127e571c..8552bebf0228 100644 --- a/test/Conversion/tma_to_llvm.mlir +++ b/test/Conversion/tma_to_llvm.mlir @@ -65,7 +65,7 @@ tt.func @tma_gather_simple(%arg0: !tt.tensordesc>, // CHECK: [[OFFSET0:%.*]] = zext nneg i32 [[WARP_STRIDE]] to i64 // CHECK: [[BASEPTR0:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET0]] - // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r" + // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r" // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR0]], ptr nonnull %0, i32 [[Y0]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]]) // CHECK: [[BASEPTR1:%.*]] = getelementptr i8, ptr addrspace(3) [[BASEPTR0]], i64 4096 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index c860a05fac64..0611453915b5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1717,7 +1717,7 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( auto callback = [&](Value pred, Value shMemPtr, Value yOffset, ArrayRef xOffsets) { std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared" - "::cluster.global.mbarrier::complete_tx::bytes " + "::cta.global.mbarrier::complete_tx::bytes " "[$1], [$2, {$3, $4, $5, $6, $7}], [$8];"; PTXBuilder ptxBuilder;