Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/test/unit/language/test_tensor_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import triton
import triton.language as tl
from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
from typing import Optional
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
Expand Down Expand Up @@ -1474,6 +1474,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
@pytest.mark.parametrize("y", [0, 32, 48])
@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
@pytest.mark.skipif(is_sm12x(), reason="TMA Scatter is not supported on sm120")
def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
if BLOCK_X > X or y + BLOCK_Y > Y:
pytest.skip()
Expand Down
4 changes: 4 additions & 0 deletions python/triton/_internal_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def is_hopper():
return is_cuda() and torch.cuda.get_device_capability()[0] == 9


def is_sm12x():
    """Tell whether the active backend is CUDA on a capability-12.x device.

    The result of ``is_cuda()`` is returned unchanged when it is falsy, so
    the device capability is only queried on a CUDA backend. Otherwise the
    first element of ``torch.cuda.get_device_capability()`` is compared
    against 12.
    """
    on_cuda = is_cuda()
    if not on_cuda:
        # Short-circuit: never touch torch.cuda when CUDA is not the backend.
        return on_cuda
    capability = torch.cuda.get_device_capability()
    return capability[0] == 12


def is_hip():
    """Tell whether the current target's backend is ``"hip"``.

    Returns ``False`` when ``get_current_target()`` yields ``None``;
    otherwise returns the result of comparing the target's ``backend``
    attribute with ``"hip"``.
    """
    current = get_current_target()
    if current is None:
        # No target available, so it cannot be HIP.
        return False
    return current.backend == "hip"
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/tma_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ tt.func @tma_gather_simple(%arg0: !tt.tensordesc<tensor<1x128xbf16, #shared1>>,

// CHECK: [[OFFSET0:%.*]] = zext nneg i32 [[WARP_STRIDE]] to i64
// CHECK: [[BASEPTR0:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET0]]
// CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r"
// CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r"
// CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR0]], ptr nonnull %0, i32 [[Y0]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]])

// CHECK: [[BASEPTR1:%.*]] = getelementptr i8, ptr addrspace(3) [[BASEPTR0]], i64 4096
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1717,7 +1717,7 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite(
auto callback = [&](Value pred, Value shMemPtr, Value yOffset,
ArrayRef<Value> xOffsets) {
std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared"
"::cluster.global.mbarrier::complete_tx::bytes "
"::cta.global.mbarrier::complete_tx::bytes "
"[$1], [$2, {$3, $4, $5, $6, $7}], [$8];";

PTXBuilder ptxBuilder;
Expand Down
Loading