Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions test/Conversion/amd/tritongpu_tdm_stride_order.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --verify-diagnostics | FileCheck %s

// Negative test: both strides handed to tt.make_tensor_descriptor are runtime
// function arguments (%stride0, %stride1), so no stride operand is the
// constant 1. The conversion is expected to emit the two diagnostics below.
// NOTE(review): %c_stride0 and %c_stride1 are defined but not used by the op.
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_load_runtime_strides(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires at least one dimension to have stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : <f16>, <tensor<64x64xf16, #shared>>
tt.return
}
}

// -----

// Negative test: the descriptor's global shape is 1x1 and both stride operands
// are the constant 1, yet the expected diagnostic still reports that no
// dimension has stride 1.
// NOTE(review): presumably dimensions of size 1 are ignored when the lowering
// looks for the stride-1 dimension -- confirm against the conversion pattern.
// The function arguments %stride0 and %stride1 are not used by the op.
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) {
%c_shape = arith.constant 1 : i32
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires at least one dimension to have stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : <f16>, <tensor<64x64xf16, #shared>>
tt.return
}
}

// -----

#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_wrong_stride_order(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_shape = arith.constant 128 : i32
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires shared order [rank-2, rank-1, rank-3, rank-4, ..., 0] because dim[rank-2] has stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : <f16>, <tensor<64x64xf16, #shared>>
Copy link
Copy Markdown
Contributor

@peterbell10 peterbell10 Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In NVIDIA TMA support we limit the descriptor to only the logical order that is handled by the hardware, and transposed loads are handled by putting a transpose (view) after the load in the program. Is there a compelling reason why AMD should be different?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points, Peter! This is sync'ing out some internal changes from our initial impl. We are taking another look at this based on the pointers. :)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done via #10078

tt.return
}
}

// -----

// Negative test: the strides are [%runtime_stride, 1], i.e. the last dimension
// has constant stride 1, but the padded shared encoding declares
// order = [0, 1]. The expected diagnostic requires the shared order to be
// [rank-1, rank-2, ..., 0] in this situation.
#shared = #ttg.padded_shared<[32:+4] {order = [0, 1], shape = [64, 64]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_wrong_smem_order(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_shape = arith.constant 128 : i32
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires shared order [rank-1, rank-2, ..., 0]}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%runtime_stride, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
tt.return
}
}

// -----

// Positive test case for 1x1x1 tensor
// All three shape operands and all three stride operands are the constant 1.
// No diagnostics are expected here; the test only verifies that the function
// label is still present in the converted output.
#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_1x1x1
tt.func public @tdm_1x1x1(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%c_stride1 = arith.constant 1 : i64
%c_shape = arith.constant 1 : i32
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape], [%c_stride1, %c_stride1, %c_stride1] : <f16>, <tensor<1x1x1xf16, #shared>>
tt.return
}
}

// -----

// Positive test case for Xx1x1 tensor
// The leading dimension has size 128 with a runtime stride; the two trailing
// dimensions have size 1 and constant stride 1. No diagnostics are expected;
// the test only verifies that the function label is still present in the
// converted output.
#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_Xx1x1
tt.func public @tdm_Xx1x1(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_stride1 = arith.constant 1 : i64
%c_shape = arith.constant 1 : i32
%c_shape2 = arith.constant 128 : i32
%0 = tt.make_tensor_descriptor %arg0, [%c_shape2, %c_shape, %c_shape], [%runtime_stride, %c_stride1, %c_stride1] : <f16>, <tensor<1x1x1xf16, #shared>>
tt.return
}
}

// -----

// Negative test for a 1xNx1 tensor: the strides are [1, %runtime_stride, 1],
// so the constant-1 strides are separated by a runtime stride in the middle
// dimension. The expected diagnostic requires all stride-1 dimensions to be
// consecutive starting from the last dimension.
#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_1xix1(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_stride1 = arith.constant 1 : i64
%c_shape = arith.constant 1 : i32
%c_shape2 = arith.constant 128 : i32
// expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape2, %c_shape], [%c_stride1, %runtime_stride, %c_stride1] : <f16>, <tensor<1x1x1xf16, #shared>>
tt.return
}
}

// -----

// Negative test for a 1x1xN tensor: the strides are [1, 1, %runtime_stride],
// so the last dimension carries the runtime stride while the two leading
// dimensions have constant stride 1. The expected diagnostic requires all
// stride-1 dimensions to be consecutive starting from the last dimension.
#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_1x1xi(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_stride1 = arith.constant 1 : i64
%c_shape = arith.constant 1 : i32
%c_shape2 = arith.constant 128 : i32
// expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : <f16>, <tensor<1x1x1xf16, #shared>>
tt.return
}
}
16 changes: 15 additions & 1 deletion third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() {
if (intervals.size() != 1)
return emitOpError("TDM store only supports single interval paddings.");

if (intervals[0] != blockShape.back())
if (intervals[0] != blockShape[paddedEnc.getOrder().front()])
return emitOpError("TDM store padding is only supported when padding "
"interval equals the innermost block dimension (got "
"padInterval=")
Expand Down Expand Up @@ -739,6 +739,13 @@ LogicalResult AsyncTDMScatterOp::verify() {
if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
if (sharedOrder[0] != (sharedOrder.size() - 1))
return emitOpError("TDM scatter only supports row-major shared order");

return success();
}

Expand Down Expand Up @@ -784,6 +791,13 @@ LogicalResult AsyncTDMGatherOp::verify() {
if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
if (sharedOrder[0] != (sharedOrder.size() - 1))
return emitOpError("TDM gather only supports row-major shared order");

return success();
}

Expand Down
15 changes: 13 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,10 +1249,16 @@ struct AsyncTDMCopyGlobalToLocalOpConversion
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1);

mlir::LLVM::AMD::emitTDMLoadStore(
rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
padInterval, padAmount, offset, dstPtrs, op.getPred(), multicastMask,
elementType, barrierPtr, /*isLoad=*/true, sharedLayout, ctaId);
elementType, barrierPtr, /*isLoad=*/true, sharedLayout, ctaId,
isRowMajor);

rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -1327,12 +1333,17 @@ struct AsyncTDMCopyLocalToGlobalOpConversion
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1);

Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
mlir::LLVM::AMD::emitTDMLoadStore(
rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
padInterval, padAmount, offset, srcPtrs, pred,
/*multicastMask=*/{}, elementType, barrierPtr,
/*isLoad=*/false, sharedLayout, ctaId);
/*isLoad=*/false, sharedLayout, ctaId, isRowMajor);

rewriter.eraseOp(op);
return success();
Expand Down
71 changes: 62 additions & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ SmallVector<Value> TDMDescriptor::getAllGroups() const {
return result;
}

// Exchanges the last two entries of `vec` in place so that the trailing two
// dimensions appear in swapped order for TDM operations. `vec` must hold at
// least two entries.
template <typename T> void swapTrailingDims(SmallVector<T> &vec) {
  const size_t rank = vec.size();
  assert(rank >= 2 && "need at least 2 dims to swap");
  std::swap(vec[rank - 2], vec[rank - 1]);
}

// Decode a full TDM descriptor from all 4 group vectors for 3D-5D tensors
// Returns (base, tensorShape[], tensorStride[], blockShape[])
std::tuple<Value, SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
Expand Down Expand Up @@ -143,8 +149,8 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
SmallVector<int64_t> blockShape, int numWarps,
unsigned padInterval, unsigned padAmount,
SmallVector<Value> tensorShape,
SmallVector<Value> tensorStride,
Value srcPtr) {
SmallVector<Value> tensorStride, Value srcPtr,
bool isRowMajor) {
size_t numDims = tensorShape.size();
assert(numDims >= 1 && numDims <= 5 && tensorStride.size() == numDims &&
"TDM only supported for 1D-5D tensors.");
Expand All @@ -154,6 +160,12 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
auto ctx = rewriter.getContext();
auto b = TritonLLVMOpBuilder(loc, rewriter);

if (!isRowMajor) {
swapTrailingDims(blockShape);
swapTrailingDims(tensorStride);
swapTrailingDims(tensorShape);
}

// Define common values for better readability
Value v16 = b.i32_val(16);
Value v32 = b.i64_val(32);
Expand Down Expand Up @@ -406,6 +418,29 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
return TDMDescriptor{group0, group1, group2, group3};
}

// Builds a layout equivalent to `layout` except that the roles of the output
// dimensions `dimA` and `dimB` are exchanged, i.e.
// new.apply(x)[dimA] == old.apply(x)[dimB] and vice versa. Both dimensions
// must already be output dimensions of `layout`. The order in which output
// dimensions are listed is the same as in `layout`.
static triton::LinearLayout
swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA,
                    StringAttr dimB) {
  assert(layout.hasOutDim(dimA));
  assert(layout.hasOutDim(dimB));

  // Re-label the out dims: dimA takes dimB's slot and dimB takes dimA's; every
  // other dimension keeps its own name. Sizes stay attached to their slots.
  SmallVector<std::pair<StringAttr, int32_t>> relabeledOutDims;
  for (auto [dimName, dimSize] : layout.getOutDims()) {
    StringAttr newName = dimName;
    if (dimName == dimA)
      newName = dimB;
    else if (dimName == dimB)
      newName = dimA;
    relabeledOutDims.push_back({newName, dimSize});
  }

  // The relabeling permuted the names; transpose the outputs back so they are
  // listed in the same order as in the input layout.
  triton::LinearLayout relabeled(layout.getBases(), relabeledOutDims,
                                 /*requireSurjective=*/false);
  return relabeled.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
}

// Fill TDM descriptor for regular load/store operations (1D-5D tensors)
void fillTDMDescriptor(
RewriterBase &rewriter, Location loc,
Expand All @@ -416,7 +451,8 @@ void fillTDMDescriptor(
std::optional<std::reference_wrapper<SmallVector<Value>>> group3,
SmallVector<Value> offset, ArrayRef<Value> dstPtrs, Value pred,
Value multicastMask, Value barrierPtr,
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore) {
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore,
bool isRowMajor) {
size_t numDims = offset.size();
assert(numDims >= 1 && numDims <= 5 && "TDM supports 1D to 5D tensors.");
assert(!dstPtrs.empty() && "dstPtrs cannot be empty");
Expand All @@ -427,6 +463,22 @@ void fillTDMDescriptor(
Type globalPtrTy = ptr_ty(ctx, 1);
Type sharedPtrTy = ptr_ty(ctx, 3);

// For col-major tensors the TDM descriptor was created with the trailing two
// dimensions swapped. Swap shapePerCTA and offset to match that hardware
// view, and rename the same two dims in the shared layout to align out dims.
std::optional<triton::LinearLayout> adjustedSharedLayout;
if (!isRowMajor) {
swapTrailingDims(shapePerCTA);
swapTrailingDims(offset);
if (numDims >= 2) {
auto dimN_2 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 2));
auto dimN_1 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 1));
adjustedSharedLayout = swapOutDimSemantics(sharedLayout, dimN_2, dimN_1);
}
}
const auto &tdmViewSharedLayout =
adjustedSharedLayout ? *adjustedSharedLayout : sharedLayout;

// Decode the full TDM descriptor to get all values
auto [srcPtr, tensorShape, tensorStride, decodedBlockShape] =
decodeTDMDescriptorFull(
Expand All @@ -446,7 +498,7 @@ void fillTDMDescriptor(
auto kPartition = str_attr("partition");

auto cgaLayout = triton::gpu::SharedLinearEncodingAttr::get(
ctx, sharedLayout, /*layoutAlignment=*/16)
ctx, tdmViewSharedLayout, /*layoutAlignment=*/16)
.getCGALayout()
.getLinearLayout();

Expand Down Expand Up @@ -474,7 +526,7 @@ void fillTDMDescriptor(
}
srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset);

auto tdmToShared = tdmLayout.invertAndCompose(sharedLayout);
auto tdmToShared = tdmLayout.invertAndCompose(tdmViewSharedLayout);
auto sharedOffsets = applyLinearLayout(
loc, rewriter, tdmToShared,
{{kMessage, b.i32_val(0)}, {kWarp, warpId}, {kBlock, ctaId}});
Expand Down Expand Up @@ -750,7 +802,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
ArrayRef<Value> offset, ArrayRef<Value> dstPtrs,
Value pred, Value multicastMask, Type elementType,
Value barrierPtr, bool isLoad,
const triton::LinearLayout &sharedLayout, Value ctaId) {
const triton::LinearLayout &sharedLayout, Value ctaId,
bool isRowMajor) {
auto b = TritonLLVMOpBuilder(loc, rewriter);

assert(shapePerCTA.size() <= 5);
Expand All @@ -768,7 +821,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
to_vector(shapePerCTA), numWarps, padInterval, padAmount,
group0Vec, group1Vec, std::ref(group2Vec),
std::ref(group3Vec), to_vector(offset), dstPtrs, pred,
multicastMask, barrierPtr, sharedLayout, ctaId, !isLoad);
multicastMask, barrierPtr, sharedLayout, ctaId, !isLoad,
isRowMajor);

auto group0 = packLLVector(loc, group0Vec, rewriter);
auto group1 = packLLVector(loc, group1Vec, rewriter);
Expand All @@ -788,7 +842,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
to_vector(shapePerCTA), numWarps, padInterval, padAmount,
group0Vec, group1Vec, std::nullopt, std::nullopt,
to_vector(offset), dstPtrs, pred, multicastMask,
barrierPtr, sharedLayout, ctaId, !isLoad);
barrierPtr, sharedLayout, ctaId, !isLoad, isRowMajor);

auto group0 = packLLVector(loc, group0Vec, rewriter);
auto group1 = packLLVector(loc, group1Vec, rewriter);
Expand Down Expand Up @@ -1028,5 +1082,4 @@ SmallVector<Value> emitTDMPrefetch(RewriterBase &rewriter, Location loc,
}
return offsets;
}

} // namespace mlir::LLVM::AMD
10 changes: 6 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
SmallVector<int64_t> blockShape, int numWarps,
unsigned padInterval, unsigned padAmount,
SmallVector<Value> tensorShape,
SmallVector<Value> tensorStride,
Value srcPtr);
SmallVector<Value> tensorStride, Value srcPtr,
bool isRowMajor);

// Update the global memory address with offset, and fill the shared memory
// address and pred in a given TDM descriptor for regular load/store (1D-5D).
Expand All @@ -47,7 +47,8 @@ void fillTDMDescriptor(
std::optional<std::reference_wrapper<SmallVector<Value>>> group3,
SmallVector<Value> offset, ArrayRef<Value> dstPtrs, Value pred,
Value multicastMask, Value barrierPtr,
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore);
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore,
bool isRowMajor);

// Fill TDM descriptor for gather/scatter operations (2D only).
// Gather reads from non-contiguous rows in global memory to LDS.
Expand Down Expand Up @@ -80,7 +81,8 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
ArrayRef<Value> offset, ArrayRef<Value> dstPtrs,
Value pred, Value multicastMask, Type elementType,
Value barrierPtr, bool isLoad,
const triton::LinearLayout &sharedLayout, Value ctaId);
const triton::LinearLayout &sharedLayout, Value ctaId,
bool isRowMajor);

// Calculate the number of TDM gather/scatter instructions needed.
// - numIndices: number of row indices
Expand Down
Loading
Loading