From c90b7b227828d11204e789b409f80151d69caad7 Mon Sep 17 00:00:00 2001 From: "Weinrauch, Alexander" Date: Fri, 20 Feb 2026 17:24:26 +0000 Subject: [PATCH 1/4] [AMD][BACKEND] Verify TDM strides and shared memory order TDM requires contiguous data so one of the strides has to be 1. Triton implicitly annotates kernel arguments with value 1 as constexpr so we can check which stride is 1. Currently the lowering doesn't support reordering so this PR adds a strict check that the last dim is the fastest one. A follow-up PR will allow some reordering. See this [ticket](https://github.com/ROCm/triton-internal/issues/1658). --- .../amd/tritongpu_tdm_stride_order.mlir | 100 ++++++++++++++++++ .../TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp | 45 ++++++++ .../amd/python/test/test_gluon_gfx1250.py | 7 +- 3 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 test/Conversion/amd/tritongpu_tdm_stride_order.mlir diff --git a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir new file mode 100644 index 000000000000..2edb311b8e4e --- /dev/null +++ b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir @@ -0,0 +1,100 @@ +// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --verify-diagnostics | FileCheck %s + +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_load_runtime_strides(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) { + %c_shape = arith.constant 128 : i32 + %c_stride0 = arith.constant 128 : i64 + %c_stride1 = arith.constant 1 : i64 + // expected-error @+2 {{requires at least one dimension to have stride 1}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : , > + tt.return + } 
+} + +// ----- + +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) { + %c_shape = arith.constant 1 : i32 + %c_stride1 = arith.constant 1 : i64 + // expected-error @+2 {{requires at least one dimension to have stride 1}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : , > + tt.return + } +} + +// ----- + +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_wrong_stride_order(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + %c_shape = arith.constant 128 : i32 + %c_stride1 = arith.constant 1 : i64 + // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : , > + tt.return + } +} + +// ----- + +#shared = #ttg.padded_shared<[32:+4] {order = [0, 1], shape = [64, 64]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_wrong_smem_order(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + %c_shape = arith.constant 128 : i32 + %c_stride1 = arith.constant 1 : i64 + // expected-error @+2 {{requires shared order [rank-1, rank-2, ..., 0]}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%runtime_stride, %c_stride1] : , > + tt.return + } +} + +// ----- + +// Positive test case for 1x1x1 tensor +#shared = 
#ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tdm_1x1x1 + tt.func public @tdm_1x1x1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c_stride1 = arith.constant 1 : i64 + %c_shape = arith.constant 1 : i32 + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape], [%c_stride1, %c_stride1, %c_stride1] : , > + tt.return + } +} + +// ----- + +// Positive test case for Xx1x1 tensor +#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tdm_Xx1x1 + tt.func public @tdm_Xx1x1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + %c_stride1 = arith.constant 1 : i64 + %c_shape = arith.constant 1 : i32 + %c_shape2 = arith.constant 128 : i32 + %0 = tt.make_tensor_descriptor %arg0, [%c_shape2, %c_shape, %c_shape], [%runtime_stride, %c_stride1, %c_stride1] : , > + tt.return + } +} + +// ----- + +#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_1xix1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + %c_stride1 = arith.constant 1 : i64 + %c_shape = arith.constant 1 : i32 + %c_shape2 = arith.constant 128 : i32 + // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape2, %c_shape], [%c_stride1, %runtime_stride, %c_stride1] : , > + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 
ed23598fb45a..d78cfa63a00b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -12,6 +12,46 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; namespace { +// Validates that the tensor descriptor's strides and shared layout are +// compatible with TDM. The shared order must be [rank-1, rank-2, ..., 0], +// and all stride-1 dimensions are consecutive at the end (trailing dims). +// TODO: this is too strict but the lowering doesn't support reordering +// dimensions yet; once it does we can relax these checks. +LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, + Attribute sharedEnc, + ArrayRef shape, + ValueRange strides) { + int rank = shape.size(); + auto sharedOrder = triton::gpu::getOrder( + cast(sharedEnc), shape); + + for (int i = 0; i < rank; ++i) { + if (sharedOrder[i] != rank - 1 - i) + return op.emitError() << "requires shared order [rank-1, rank-2, ..., 0]"; + } + + SmallVector strideOneDims; + for (auto [dim, strideVal] : llvm::enumerate(strides)) { + auto cst = getConstantIntValue(getAsOpFoldResult(strideVal)); + if (cst.value_or(0) == 1) { + strideOneDims.push_back(dim); + } + } + + if (strideOneDims.empty()) + return op.emitError() << "requires at least one dimension to have stride 1"; + + int expectedDim = rank - 1; + for (auto it = strideOneDims.rbegin(); it != strideOneDims.rend(); ++it) { + if (*it != expectedDim--) { + return op.emitError() << "requires all stride 1 dimensions to be " + "consecutive starting from the last dimension"; + } + } + + return success(); +} + // Collects all users of the value beyond the basic block boundaries // defining a given value. 
void collectUsers(Value value, llvm::SetVector &users) { @@ -109,6 +149,11 @@ struct MakeTensorDescOpConversion int numWarps = lookupNumWarps(op); auto shapePerCTA = triton::gpu::getShapePerCTA(sharedEnc, blockShape); + if (failed(validateStridesAndSharedOrder(op, sharedEnc, shapePerCTA, + tensorStride))) { + return failure(); + } + // Create TDM descriptor for 2D-5D tensors auto tdmDesc = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, shapePerCTA, numWarps, diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index f0b66fb23a9b..a87fc1895e0a 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -1704,6 +1704,7 @@ def test_compile_tensor_descriptor_prefetch_nd(dtype, ndim, INNER_BLOCK, SPECULA order=[ndim - 1 - i for i in range(ndim)]) BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:] + STRIDES = (1, 2, 4, 8, 16)[:ndim][::-1] shape_str = ", ".join(str(s) for s in BLOCK_SHAPE) if TDM_TYPE == "DEVICE_TDM": @@ -1711,13 +1712,15 @@ def test_compile_tensor_descriptor_prefetch_nd(dtype, ndim, INNER_BLOCK, SPECULA signature = { "a_ptr": f"*{dtype}", "shape": tuple("i32" for _ in range(ndim)), - "strides": tuple("i32" for _ in range(ndim)), + "strides": tuple("constexpr" for _ in range(ndim)), "BLOCK_SHAPE": tuple("constexpr" for _ in range(ndim)), "SHARED_LAYOUT": "constexpr", "PREFETCH_SPECULATIVE": "constexpr", } constexprs = { - # For tuples we need to specifiy the parameter index (BLOCK_SHAPE is the 3rd argument) + # For tuples we need to specifiy the parameter index + **{(2, i): STRIDES[i] + for i in range(ndim)}, **{(3, i): BLOCK_SHAPE[i] for i in range(ndim)}, "SHARED_LAYOUT": SHARED_LAYOUT, From 7492d45ce9cbf735be7d8c547237cdaa9961385b Mon Sep 17 00:00:00 2001 From: "Weinrauch, Alexander" Date: Tue, 3 Mar 2026 10:36:33 +0000 Subject: [PATCH 2/4] [AMD][BACKEND] TDM allow col-major tensors This PR allows TDM 
load and store with column-major (order=[0,1]) tensors. For this to work we need to swap the dimensions in the TDM descriptor to ensure the stride==1 dimension from Triton is the first dim in the HW TDM descriptor. Note that the order of dimensions between Triton and our HW is reversed. For the >2D case we only allow swapping the last two dimensions; this means batch dims are not allowed to be the fastest dim. I adjusted the run_tensor_descriptor_load_store_test lit tests to not test uint dtypes because they do not affect the test case but make the torch handling trickier and I was unable to make the transpose work. Note that we cannot do this for gather and scatter because it would reverse the meaning of the indices (from rows to columns). --- .../amd/tritongpu_tdm_stride_order.mlir | 17 ++- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 16 ++- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 15 ++- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp | 36 +++++- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.h | 14 ++- .../TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp | 52 ++++++--- .../amd/python/test/test_gluon_gfx1250.py | 105 +++++++++++------- 7 files changed, 181 insertions(+), 74 deletions(-) diff --git a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir index 2edb311b8e4e..b77aa125acba 100644 --- a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir +++ b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir @@ -34,7 +34,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.func public @tdm_wrong_stride_order(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { %c_shape = arith.constant 128 : i32 %c_stride1 = arith.constant 1 : i64 - // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} + // expected-error @+2 {{requires shared order [rank-2, rank-1, rank-3, rank-4, ..., 0] because dim[rank-2] has stride 1}} // 
expected-error @+1 {{failed to legalize operation}} %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : , > tt.return @@ -98,3 +98,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.return } } + +// ----- + +#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tdm_1x1xi(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + %c_stride1 = arith.constant 1 : i64 + %c_shape = arith.constant 1 : i32 + %c_shape2 = arith.constant 128 : i32 + // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} + // expected-error @+1 {{failed to legalize operation}} + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : , > + tt.return + } +} diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 517b5ce66227..243727b00c5c 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -685,7 +685,7 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() { if (intervals.size() != 1) return emitOpError("TDM store only supports single interval paddings."); - if (intervals[0] != blockShape.back()) + if (intervals[0] != blockShape[paddedEnc.getOrder().front()]) return emitOpError("TDM store padding is only supported when padding " "interval equals the innermost block dimension (got " "padInterval=") @@ -739,6 +739,13 @@ LogicalResult AsyncTDMScatterOp::verify() { if (!paddedEnc && !swizzledEnc) return emitOpError("Invalid shared memory layout for TDM"); + auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); + auto sharedOrder = triton::gpu::getOrder( + cast(smemTy.getEncoding()), + 
shapePerCTA); + if (sharedOrder[0] != (sharedOrder.size() - 1)) + return emitOpError("TDM scatter only supports row-major shared order"); + return success(); } @@ -784,6 +791,13 @@ LogicalResult AsyncTDMGatherOp::verify() { if (!paddedEnc && !swizzledEnc) return emitOpError("Invalid shared memory layout for TDM"); + auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); + auto sharedOrder = triton::gpu::getOrder( + cast(smemTy.getEncoding()), + shapePerCTA); + if (sharedOrder[0] != (sharedOrder.size() - 1)) + return emitOpError("TDM gather only supports row-major shared order"); + return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index aae3ea21c36e..ff16c2bd771d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1249,10 +1249,16 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); + auto sharedOrder = triton::gpu::getOrder( + cast(smemTy.getEncoding()), + shapePerCTA); + bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); + mlir::LLVM::AMD::emitTDMLoadStore( rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps, padInterval, padAmount, offset, dstPtrs, op.getPred(), multicastMask, - elementType, barrierPtr, /*isLoad=*/true, sharedLayout, ctaId); + elementType, barrierPtr, /*isLoad=*/true, sharedLayout, ctaId, + isRowMajor); rewriter.eraseOp(op); return success(); @@ -1327,12 +1333,17 @@ struct AsyncTDMCopyLocalToGlobalOpConversion auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); + auto sharedOrder = triton::gpu::getOrder( + cast(smemTy.getEncoding()), + shapePerCTA); + bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); + Value pred = 
arith::ConstantIntOp::create(rewriter, loc, 1, 32); mlir::LLVM::AMD::emitTDMLoadStore( rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps, padInterval, padAmount, offset, srcPtrs, pred, /*multicastMask=*/{}, elementType, barrierPtr, - /*isLoad=*/false, sharedLayout, ctaId); + /*isLoad=*/false, sharedLayout, ctaId, isRowMajor); rewriter.eraseOp(op); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index 1577d03e1dc8..4e573f376744 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -44,6 +44,12 @@ SmallVector TDMDescriptor::getAllGroups() const { return result; } +// Swap the trailing two dimensions of a vector for TDM operations. +template void swapTrailingDims(SmallVector &vec) { + assert(vec.size() >= 2 && "need at least 2 dims to swap"); + std::swap(vec[vec.size() - 2], vec[vec.size() - 1]); +} + // Decode a full TDM descriptor from all 4 group vectors for 3D-5D tensors // Returns (base, tensorShape[], tensorStride[], blockShape[]) std::tuple, SmallVector, SmallVector> @@ -143,8 +149,8 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, SmallVector blockShape, int numWarps, unsigned padInterval, unsigned padAmount, SmallVector tensorShape, - SmallVector tensorStride, - Value srcPtr) { + SmallVector tensorStride, Value srcPtr, + bool isRowMajor) { size_t numDims = tensorShape.size(); assert(numDims >= 1 && numDims <= 5 && tensorStride.size() == numDims && "TDM only supported for 1D-5D tensors."); @@ -154,6 +160,12 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, auto ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!isRowMajor) { + swapTrailingDims(blockShape); + swapTrailingDims(tensorStride); + swapTrailingDims(tensorShape); + } + // Define common values for better readability Value v16 = b.i32_val(16); 
Value v32 = b.i64_val(32); @@ -416,7 +428,8 @@ void fillTDMDescriptor( std::optional>> group3, SmallVector offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Value barrierPtr, - const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore) { + const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore, + bool isRowMajor) { size_t numDims = offset.size(); assert(numDims >= 1 && numDims <= 5 && "TDM supports 1D to 5D tensors."); assert(!dstPtrs.empty() && "dstPtrs cannot be empty"); @@ -439,6 +452,11 @@ void fillTDMDescriptor( : std::nullopt, numDims); + if (!isRowMajor) { + swapTrailingDims(decodedBlockShape); + swapTrailingDims(offset); + } + auto kMessage = str_attr("message"); auto kWarp = str_attr("warp"); auto kBlock = str_attr("block"); @@ -750,7 +768,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, ArrayRef offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Type elementType, Value barrierPtr, bool isLoad, - const triton::LinearLayout &sharedLayout, Value ctaId) { + const triton::LinearLayout &sharedLayout, Value ctaId, + bool isRowMajor) { auto b = TritonLLVMOpBuilder(loc, rewriter); assert(shapePerCTA.size() <= 5); @@ -768,7 +787,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, to_vector(shapePerCTA), numWarps, padInterval, padAmount, group0Vec, group1Vec, std::ref(group2Vec), std::ref(group3Vec), to_vector(offset), dstPtrs, pred, - multicastMask, barrierPtr, sharedLayout, ctaId, !isLoad); + multicastMask, barrierPtr, sharedLayout, ctaId, !isLoad, + isRowMajor); auto group0 = packLLVector(loc, group0Vec, rewriter); auto group1 = packLLVector(loc, group1Vec, rewriter); @@ -788,7 +808,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, to_vector(shapePerCTA), numWarps, padInterval, padAmount, group0Vec, group1Vec, std::nullopt, std::nullopt, to_vector(offset), dstPtrs, pred, multicastMask, - barrierPtr, sharedLayout, ctaId, !isLoad); + barrierPtr, sharedLayout, ctaId, !isLoad, 
isRowMajor); auto group0 = packLLVector(loc, group0Vec, rewriter); auto group1 = packLLVector(loc, group1Vec, rewriter); @@ -1029,4 +1049,8 @@ SmallVector emitTDMPrefetch(RewriterBase &rewriter, Location loc, return offsets; } +bool needsTrailingDimSwapForTDM(ArrayRef sharedOrder) { + return sharedOrder[0] != (sharedOrder.size() - 1); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h index 716f5c88b514..80e189bb2180 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h @@ -31,8 +31,8 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, SmallVector blockShape, int numWarps, unsigned padInterval, unsigned padAmount, SmallVector tensorShape, - SmallVector tensorStride, - Value srcPtr); + SmallVector tensorStride, Value srcPtr, + bool isRowMajor); // Update the global memory address with offset, and fill the shared memory // address and pred in a given TDM descriptor for regular load/store (1D-5D). @@ -47,7 +47,8 @@ void fillTDMDescriptor( std::optional>> group3, SmallVector offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Value barrierPtr, - const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore); + const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore, + bool isRowMajor); // Fill TDM descriptor for gather/scatter operations (2D only). // Gather reads from non-contiguous rows in global memory to LDS. @@ -80,7 +81,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, ArrayRef offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Type elementType, Value barrierPtr, bool isLoad, - const triton::LinearLayout &sharedLayout, Value ctaId); + const triton::LinearLayout &sharedLayout, Value ctaId, + bool isRowMajor); // Calculate the number of TDM gather/scatter instructions needed. 
// - numIndices: number of row indices @@ -124,6 +126,10 @@ SmallVector emitTDMPrefetch(RewriterBase &rewriter, Location loc, Type elementType, Value laneId, Value warpId, Value ctaId, bool isSpeculative); +// Returns true if the shared memory encoding is not row-major, requiring +// the trailing two dimensions to be swapped for TDM. +bool needsTrailingDimSwapForTDM(ArrayRef sharedOrder); + } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TDMUTILITY_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index d78cfa63a00b..3410db682c0c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -15,8 +15,9 @@ namespace { // Validates that the tensor descriptor's strides and shared layout are // compatible with TDM. The shared order must be [rank-1, rank-2, ..., 0], // and all stride-1 dimensions are consecutive at the end (trailing dims). -// TODO: this is too strict but the lowering doesn't support reordering -// dimensions yet; once it does we can relax these checks. +// Additionally, if there is a single stride-1 dimension we allow it to be at +// position rank-2; the lowering will swap the dimensions. However, this also +// requires the shared order to have rank-2 and rank-1 swapped. 
LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, Attribute sharedEnc, ArrayRef shape, @@ -25,27 +26,39 @@ LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, auto sharedOrder = triton::gpu::getOrder( cast(sharedEnc), shape); - for (int i = 0; i < rank; ++i) { - if (sharedOrder[i] != rank - 1 - i) - return op.emitError() << "requires shared order [rank-1, rank-2, ..., 0]"; - } - SmallVector strideOneDims; for (auto [dim, strideVal] : llvm::enumerate(strides)) { - auto cst = getConstantIntValue(getAsOpFoldResult(strideVal)); - if (cst.value_or(0) == 1) { + if (getConstantIntValue(getAsOpFoldResult(strideVal)).value_or(0) == 1) strideOneDims.push_back(dim); - } } if (strideOneDims.empty()) return op.emitError() << "requires at least one dimension to have stride 1"; - int expectedDim = rank - 1; - for (auto it = strideOneDims.rbegin(); it != strideOneDims.rend(); ++it) { - if (*it != expectedDim--) { - return op.emitError() << "requires all stride 1 dimensions to be " - "consecutive starting from the last dimension"; + // If the only stride-1 dim is the second-to-last dimension (col-major) we can + // safely reorder the dimensions during lowering. 
+ bool isColMajor = + strideOneDims.size() == 1 && strideOneDims.front() == rank - 2; + + SmallVector expectedOrder(llvm::reverse(llvm::seq(rank))); + if (isColMajor) + std::swap(expectedOrder[0], expectedOrder[1]); + + if (sharedOrder != ArrayRef(expectedOrder)) { + if (isColMajor) + return op.emitError() + << "requires shared order [rank-2, rank-1, rank-3, " + "rank-4, ..., 0] because dim[rank-2] has stride 1"; + return op.emitError() << "requires shared order [rank-1, rank-2, ..., 0]"; + } + + if (strideOneDims.size() > 1) { + unsigned k = strideOneDims.size(); + unsigned numStride1Dims = strideOneDims.size(); + for (unsigned i = 0; i < numStride1Dims; ++i) { + if (strideOneDims[i] != rank - numStride1Dims + i) + return op.emitError() << "requires all stride 1 dimensions to be " + "consecutive starting from the last dimension"; } } @@ -120,8 +133,8 @@ struct MakeTensorDescOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto basePtr = adaptor.getBase(); - auto tensorShape = adaptor.getShape(); - auto tensorStride = adaptor.getStrides(); + auto tensorShape = llvm::to_vector(adaptor.getShape()); + auto tensorStride = llvm::to_vector(adaptor.getStrides()); auto result = op.getResult(); auto tensorDescTy = result.getType(); @@ -153,11 +166,14 @@ struct MakeTensorDescOpConversion tensorStride))) { return failure(); } + auto sharedOrder = triton::gpu::getOrder( + cast(sharedEnc), shapePerCTA); + bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); // Create TDM descriptor for 2D-5D tensors auto tdmDesc = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, shapePerCTA, numWarps, - padInterval, padAmount, tensorShape, tensorStride, basePtr); + padInterval, padAmount, tensorShape, tensorStride, basePtr, isRowMajor); SmallVector groups = tdmDesc.getAllGroups(); diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index 
a87fc1895e0a..ff05e3236114 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -14,7 +14,7 @@ import triton import triton.language as tl from triton.backends.compiler import GPUTarget -from triton._internal_testing import is_hip_gfx1250, str_to_triton_dtype, numpy_random, to_triton, unwrap_tensor, dtypes_with_bfloat16, uint_dtypes +from triton._internal_testing import is_hip_gfx1250, str_to_triton_dtype, numpy_random, to_triton, unwrap_tensor, float_dtypes, int_dtypes, uint_dtypes from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from triton.experimental import gluon import triton.experimental.gluon.language as ttgl @@ -327,7 +327,7 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr, # - NUM_BUFFERS: ttgl.constexpr, USE_TDM: ttgl.constexpr): + NUM_BUFFERS: ttgl.constexpr, USE_TDM: ttgl.constexpr, IS_B_K_CONTIG: ttgl.constexpr): a_dtype: ttgl.constexpr = a_ptr.type.element_ty b_dtype: ttgl.constexpr = b_ptr.type.element_ty ttgl.static_assert(a_dtype.is_fp16() or a_dtype.is_bf16(), "Only fp16/bf16 supported for A") @@ -335,11 +335,19 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # ttgl.static_assert(NUM_BUFFERS >= 2, "NUM_BUFFERS must be at least 2") BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + if IS_B_K_CONTIG: + BLOCKED_LAYOUT_B: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + else: + BLOCKED_LAYOUT_B: ttgl.constexpr = ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1]) WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32]) SHARED_LAYOUT_A: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_K, 8]], [BLOCK_M, BLOCK_K], [1, 0]) - SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_N, 8]], [BLOCK_K, 
BLOCK_N], - [1, 0]) + if IS_B_K_CONTIG: + SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_N, 8]], [BLOCK_K, BLOCK_N], + [1, 0]) + else: + SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_K, 8]], [BLOCK_K, BLOCK_N], + [0, 1]) OPERAND_LAYOUT_A: ttgl.constexpr = ttgl.DotOperandLayout(0, WMMA_LAYOUT, 8) OPERAND_LAYOUT_B: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) @@ -367,8 +375,8 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # offs_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % M a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak - offs_bk = ttgl.arange(0, BLOCK_K, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT)) - offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))) % N + offs_bk = ttgl.arange(0, BLOCK_K, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT_B)) + offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT_B))) % N b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, shape=[NUM_BUFFERS] + a_desc.block_shape, layout=a_desc.layout) @@ -449,23 +457,27 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # @pytest.mark.parametrize("BLOCK_M,BLOCK_N,BLOCK_K", [(m, n, k) for (m, n) in [(32, 32), (64, 64)] \ for k in [32, 64]]) @pytest.mark.parametrize("NUM_BUFFERS", [2, 4]) +@pytest.mark.parametrize("IS_B_K_CONTIG", [True, False]) @pytest.mark.parametrize("ASYNC_LOAD_TYPE", ["ASYNC_COPY", "TDM"]) -def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, ASYNC_LOAD_TYPE): +def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, IS_B_K_CONTIG, ASYNC_LOAD_TYPE): # Inner strides need to be constexpr (1) to get contiguity. 
Note the compiler frontend does the same for normal dispatches signature = { "a_ptr": "*fp16", "b_ptr": "*fp16", "c_ptr": "*fp32", # "M": "i32", "N": "i32", "K": "i32", # "stride_am": "i32", "stride_ak": "constexpr", # - "stride_bk": "i32", "stride_bn": "constexpr", # + "stride_bk": "constexpr", "stride_bn": "constexpr", # "stride_cm": "i32", "stride_cn": "constexpr", # "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr", # - "NUM_BUFFERS": "constexpr", "USE_TDM": "constexpr" + "NUM_BUFFERS": "constexpr", "USE_TDM": "constexpr", "IS_B_K_CONTIG": "constexpr" } constexprs = { - "stride_ak": 1, "stride_bn": 1, "stride_cn": 1, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, - "NUM_BUFFERS": NUM_BUFFERS, "USE_TDM": ASYNC_LOAD_TYPE == "TDM" + "stride_ak": 1, "stride_cn": 1, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "NUM_BUFFERS": + NUM_BUFFERS, "USE_TDM": ASYNC_LOAD_TYPE == "TDM", "IS_B_K_CONTIG": IS_B_K_CONTIG } + constexprs["stride_bn"] = BLOCK_N if not IS_B_K_CONTIG else 1 + constexprs["stride_bk"] = BLOCK_K if IS_B_K_CONTIG else 1 + fn = gemm_async_pipelined_kernel # AsyncCopy requires >= 32 bits per lane so we have to pass divisibility for arguments used in pointer arithmetic @@ -485,7 +497,7 @@ def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, AS assert len(re.findall("tensor_load_to_lds", amdgcn)) == NUM_BUFFERS * 2 else: copy_instr_for_A = BLOCK_M // 4 // 4 - copy_isntr_for_B = BLOCK_K // 4 // 4 + copy_isntr_for_B = (BLOCK_K if IS_B_K_CONTIG else BLOCK_N) // 4 // 4 copy_instr_per_iter = copy_instr_for_A + copy_isntr_for_B for cnt in range(NUM_BUFFERS - 1, -1, -1): assert re.search(f"s_wait_asynccnt 0x{(cnt * copy_instr_per_iter):x}", amdgcn) @@ -497,8 +509,9 @@ def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, AS for k in [32, 64]]) @pytest.mark.parametrize("NUM_BUFFERS", [2, 4]) @pytest.mark.parametrize("M,N,K", [(256, 256, 512), (240, 240, 496), (250, 250, 
510)]) +@pytest.mark.parametrize("IS_B_K_CONTIG", [True, False]) @pytest.mark.parametrize("ASYNC_LOAD_TYPE", ["ASYNC_COPY", "TDM"]) -def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, N, K, ASYNC_LOAD_TYPE): +def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, N, K, IS_B_K_CONTIG, ASYNC_LOAD_TYPE): if triton.cdiv(K, BLOCK_K) < NUM_BUFFERS: pytest.skip("Skip tests where K/BLOCK_K < NUM_BUFFERS") @@ -510,13 +523,15 @@ def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, a = torch.randn((M, K), dtype=torch.float16) b = torch.randn((K, N), dtype=torch.float16) c = torch.zeros((M, N), dtype=torch.float32) - stride_am, stride_ak = a.stride(0), a.stride(1) - stride_bk, stride_bn = b.stride(0), b.stride(1) - stride_cm, stride_cn = c.stride(0), c.stride(1) a_device = a.cuda() - b_device = b.cuda() + b_device = b.cuda() if IS_B_K_CONTIG else b.data.T.contiguous().T.cuda() c_device = c.cuda() + + stride_am, stride_ak = a_device.stride(0), a_device.stride(1) + stride_bk, stride_bn = b_device.stride(0), b_device.stride(1) + stride_cm, stride_cn = c_device.stride(0), c_device.stride(1) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) gemm_async_pipelined_kernel[grid]( a_device, b_device, c_device, # @@ -525,7 +540,7 @@ def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, # - NUM_BUFFERS=NUM_BUFFERS, USE_TDM=ASYNC_LOAD_TYPE == "TDM") + NUM_BUFFERS=NUM_BUFFERS, USE_TDM=ASYNC_LOAD_TYPE == "TDM", IS_B_K_CONTIG=IS_B_K_CONTIG) c_triton = c_device.cpu() c_torch = a.to(torch.float32) @ b.to(torch.float32) @@ -1567,7 +1582,9 @@ def tensor_descriptor_load_store_nd_kernel_host_tdm(out_desc, inp_desc): ttgl.amd.gfx1250.tdm.async_wait(0) -def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT): +def 
_run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT, ROW_MAJOR): + if not ROW_MAJOR and TDM_TYPE == "HOST_TDM": + pytest.skip("NYI: Host TDM does not support non-row major layouts") """Utility function to run TDM load/store tests with a given shared layout.""" alloc_shape = [1, 1, 3, 7, INNER_BLOCK][-ndim:] @@ -1575,17 +1592,13 @@ def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYP inp = to_triton(numpy_random(alloc_shape, dtype_str), device="cpu", dst_type=dtype_str) inp.data = inp.data[..., :INNER_BLOCK - 3] out = inp.new_empty(BLOCK_SHAPE) - # uint_dtypes require special handling because PyTorch only has full native support - # for uint8. While PyTorch 2.1+ added limited support for uint16, uint32, and uint64, - # they still lack complete functionality across all PyTorch ops. They are stored as - # signed tensors with the same bit width and wrapped in TensorWrapper for reinterpretation - # to unsigned. The .base attribute accesses the underlying signed tensor for CUDA transfer. 
- if dtype_str in uint_dtypes: - inp.base = inp.base.cuda() - out.base = out.base.cuda() - else: - inp = inp.cuda() - out = out.cuda() + + if ndim > 1 and not ROW_MAJOR: + out = out.data.transpose(-2, -1).contiguous().transpose(-2, -1) + inp = inp.data.transpose(-2, -1).contiguous().transpose(-2, -1) + + inp = inp.cuda() + out = out.cuda() if TDM_TYPE == "DEVICE_TDM": constexpr_block_shape = tuple(ttgl.constexpr(v) for v in BLOCK_SHAPE) @@ -1616,19 +1629,25 @@ def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYP @pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5]) @pytest.mark.parametrize("INNER_BLOCK", [4, 8, 16, 32, 64, 128]) -@pytest.mark.parametrize("dtype_str", sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})) +@pytest.mark.parametrize("dtype_str", sorted(set(float_dtypes + int_dtypes))) @pytest.mark.parametrize("TDM_TYPE", ["DEVICE_TDM", "HOST_TDM"]) -def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK, TDM_TYPE): +@pytest.mark.parametrize("ROW_MAJOR", [False, True]) +def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, ROW_MAJOR): """Test TDM load/store with swizzled shared layout.""" - SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, - order=[ndim - 1 - i for i in range(ndim)]) - _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT) + + order = [ndim - 1 - i for i in range(ndim)] + if ndim > 1 and not ROW_MAJOR: + order[0], order[1] = order[1], order[0] + SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=order) + + _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT, ROW_MAJOR) @pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5]) @pytest.mark.parametrize("INNER_BLOCK", [16, 64]) @pytest.mark.parametrize("dtype_str", ["float16", "int32"]) -def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, 
ndim, INNER_BLOCK): +@pytest.mark.parametrize("ROW_MAJOR", [False, True]) +def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, ndim, INNER_BLOCK, ROW_MAJOR): """Test TDM load/store with padded shared memory layout. TDM store only supports padding when: 1. There is a single padding interval @@ -1636,12 +1655,14 @@ def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, ndim, INNER_BLO """ # Create padded shared layout where padding interval = innermost block dimension BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:] - PADDED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[INNER_BLOCK, 8]], BLOCK_SHAPE, - [ndim - 1 - i - for i in range(ndim)] # standard order - ) - - _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, "DEVICE_TDM", PADDED_LAYOUT) + order = [ndim - 1 - i for i in range(ndim)] + padding = [INNER_BLOCK, 8] + if ndim > 1 and not ROW_MAJOR: + order[0], order[1] = order[1], order[0] + padding = [8, INNER_BLOCK] + PADDED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([padding], BLOCK_SHAPE, order) + + _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, "DEVICE_TDM", PADDED_LAYOUT, ROW_MAJOR) def test_tensor_descriptor_load_store_invalid_blocksize(): From 7313d233c8de9085d2b27b3d8846c85d713029d4 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 13 Mar 2026 14:02:27 +0000 Subject: [PATCH 3/4] Fix swapTrailingDims --- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index 4e573f376744..03c86cabd984 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -418,6 +418,29 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, return TDMDescriptor{group0, group1, group2, 
group3}; } +// Returns a copy of `layout` where the semantics of dimA and dimB are +// exchanged: new.apply(x)[dimA] == old.apply(x)[dimB] and vice versa. The +// output dimension order is preserved. +static triton::LinearLayout +swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA, + StringAttr dimB) { + assert(layout.hasOutDim(dimA)); + assert(layout.hasOutDim(dimB)); + SmallVector> renamedOutDims; + for (auto [name, size] : layout.getOutDims()) { + if (name == dimA) + renamedOutDims.push_back({dimB, size}); + else if (name == dimB) + renamedOutDims.push_back({dimA, size}); + else + renamedOutDims.push_back({name, size}); + } + // Transpose to restore the original output dimension order. + return triton::LinearLayout(layout.getBases(), renamedOutDims, + /*requireSurjective=*/false) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); +} + // Fill TDM descriptor for regular load/store operations (1D-5D tensors) void fillTDMDescriptor( RewriterBase &rewriter, Location loc, @@ -440,6 +463,22 @@ void fillTDMDescriptor( Type globalPtrTy = ptr_ty(ctx, 1); Type sharedPtrTy = ptr_ty(ctx, 3); + // For col-major tensors the TDM descriptor was created with the trailing two + // dimensions swapped. Swap shapePerCTA and offset to match that hardware + // view, and rename the same two dims in the shared layout to align out dims. + std::optional adjustedSharedLayout; + if (!isRowMajor) { + swapTrailingDims(shapePerCTA); + swapTrailingDims(offset); + if (numDims >= 2) { + auto dimN_2 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 2)); + auto dimN_1 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 1)); + adjustedSharedLayout = swapOutDimSemantics(sharedLayout, dimN_2, dimN_1); + } + } + const auto &tdmViewSharedLayout = + adjustedSharedLayout ? 
*adjustedSharedLayout : sharedLayout; + // Decode the full TDM descriptor to get all values auto [srcPtr, tensorShape, tensorStride, decodedBlockShape] = decodeTDMDescriptorFull( @@ -452,11 +491,6 @@ void fillTDMDescriptor( : std::nullopt, numDims); - if (!isRowMajor) { - swapTrailingDims(decodedBlockShape); - swapTrailingDims(offset); - } - auto kMessage = str_attr("message"); auto kWarp = str_attr("warp"); auto kBlock = str_attr("block"); @@ -464,7 +498,7 @@ void fillTDMDescriptor( auto kPartition = str_attr("partition"); auto cgaLayout = triton::gpu::SharedLinearEncodingAttr::get( - ctx, sharedLayout, /*layoutAlignment=*/16) + ctx, tdmViewSharedLayout, /*layoutAlignment=*/16) .getCGALayout() .getLinearLayout(); @@ -492,7 +526,7 @@ void fillTDMDescriptor( } srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset); - auto tdmToShared = tdmLayout.invertAndCompose(sharedLayout); + auto tdmToShared = tdmLayout.invertAndCompose(tdmViewSharedLayout); auto sharedOffsets = applyLinearLayout( loc, rewriter, tdmToShared, {{kMessage, b.i32_val(0)}, {kWarp, warpId}, {kBlock, ctaId}}); From 8de5d4912c61ea374d3dbe8a32e73cf9919b6c0c Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 16 Mar 2026 11:44:45 +0000 Subject: [PATCH 4/4] Cleanup --- third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp | 5 ----- third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h | 4 ---- .../amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp | 10 +++++----- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index 03c86cabd984..7ce34eb31dbf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -1082,9 +1082,4 @@ SmallVector emitTDMPrefetch(RewriterBase &rewriter, Location loc, } return offsets; } - -bool needsTrailingDimSwapForTDM(ArrayRef sharedOrder) { - return sharedOrder[0] != 
(sharedOrder.size() - 1); -} - } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h index 80e189bb2180..1a5fd15e0435 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h @@ -126,10 +126,6 @@ SmallVector emitTDMPrefetch(RewriterBase &rewriter, Location loc, Type elementType, Value laneId, Value warpId, Value ctaId, bool isSpeculative); -// Returns true if the shared memory encoding has is not row-mjaor, requiring -// the trailing two dimensions to be swapped for TDM. -bool needsTrailingDimSwapForTDM(ArrayRef sharedOrder); - } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TDMUTILITY_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 3410db682c0c..aacaaba14582 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -13,11 +13,11 @@ using namespace mlir::triton::gpu; namespace { // Validates that the tensor descriptor's strides and shared layout are -// compatible with TDM. The shared order must be [rank-1, rank-2, ..., 0], -// and all stride-1 dimensions are consecutive at the end (trailing dims). -// Additionally if we have single stride-1 dimension we allow it to be in rank-2 -// position, the lowering will swap the dimensions. However, this also requires -// the shared order to have rank-2 and rank-1 to be swapped. +// compatible with TDM. Requirements: +// - The shared order must be [rank-1, rank-2, ..., 0]. +// - All stride-1 dimensions must be consecutive trailing dims. +// Additionally, a single stride-1 dimension may appear at the rank-2 +// position (col-major) if the shared order has rank-2 and rank-1 swapped. 
LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, Attribute sharedEnc, ArrayRef shape,