Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions test/Conversion/amd/tritongpu_tdm_stride_order.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_load_runtime_strides(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires at least one dimension to have stride 1}}
// expected-error @+2 {{last dimension must have stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : <f16>, <64x64xf16, #shared>
tt.return
Expand All @@ -17,12 +15,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) {
// CHECK-LABEL: tdm_1x1_tensor
tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %stride1: i64) {
%c_shape = arith.constant 1 : i32
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires at least one dimension to have stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : <f16>, <64x64xf16, #shared>
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %c_stride1] : <f16>, <64x64xf16, #shared>
tt.return
}
}
Expand All @@ -34,7 +31,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
tt.func public @tdm_wrong_stride_order(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_shape = arith.constant 128 : i32
%c_stride1 = arith.constant 1 : i64
// expected-error @+2 {{requires shared order [rank-2, rank-1, rank-3, rank-4, ..., 0] because dim[rank-2] has stride 1}}
// expected-error @+2 {{last dimension must have stride 1}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : <f16>, <64x64xf16, #shared>
tt.return
Expand Down Expand Up @@ -103,13 +100,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr

#shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @tdm_1x1xi(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
tt.func public @tdm_1xix1(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %runtime_stride: i64) {
%c_stride1 = arith.constant 1 : i64
%c_shape = arith.constant 1 : i32
%c_shape2 = arith.constant 128 : i32
// expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}}
// expected-error @+1 {{failed to legalize operation}}
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : <f16>, <1x1x1xf16, #shared>
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %runtime_stride, %c_stride1] : <f16>, <1x1x1xf16, #shared>
tt.return
}
}
9 changes: 1 addition & 8 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() {
return emitOpError("TDM store only supports single interval paddings.");

auto shapePerCTA = triton::gpu::getShapePerCTA(paddedEnc, blockShape);
if (intervals[0] != shapePerCTA[paddedEnc.getOrder().front()])
if (intervals[0] != shapePerCTA.back())
return emitOpError("TDM store padding is only supported when padding "
"interval equals the innermost block dimension (got "
"padInterval=")
Expand Down Expand Up @@ -769,13 +769,6 @@ LogicalResult AsyncTDMScatterOp::verify() {
if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
if (sharedOrder[0] != (sharedOrder.size() - 1))
return emitOpError("TDM scatter only supports row-major shared order");

return success();
}

Expand Down
13 changes: 3 additions & 10 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,15 +1267,12 @@ struct AsyncTDMCopyGlobalToLocalOpConversion

auto shapePerCTA =
triton::gpu::getShapePerCTA(encoding, tensorDescTy.getShape());
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(encoding), shapePerCTA);
bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1);

mlir::LLVM::AMD::emitTDMLoadStore(
rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
padInterval, padAmount, offset, dstPtrs, op.getPred(), multicastMask,
elementType, barrierPtr, /*isLoad=*/true, sharedLayout, encoding, ctaId,
isRowMajor);
elementType, barrierPtr, /*isLoad=*/true, sharedLayout, encoding,
ctaId);

rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -1347,17 +1344,13 @@ struct AsyncTDMCopyLocalToGlobalOpConversion
auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);

auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy);
auto sharedOrder = triton::gpu::getOrder(
cast<triton::gpu::SharedEncodingTrait>(smemTy.getEncoding()),
shapePerCTA);
bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1);

Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
mlir::LLVM::AMD::emitTDMLoadStore(
rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps,
padInterval, padAmount, offset, srcPtrs, pred,
/*multicastMask=*/{}, elementType, barrierPtr,
/*isLoad=*/false, sharedLayout, encoding, ctaId, isRowMajor);
/*isLoad=*/false, sharedLayout, encoding, ctaId);

rewriter.eraseOp(op);
return success();
Expand Down
80 changes: 14 additions & 66 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,6 @@ SmallVector<Value> TDMDescriptor::getAllGroups() const {
return result;
}

// Exchanges the last two entries of `vec` in place; used for per-dimension
// vectors handed to TDM operations. `vec` must hold at least two entries
// (checked by assertion).
template <typename T> void swapTrailingDims(SmallVector<T> &vec) {
  const size_t rank = vec.size();
  assert(rank >= 2 && "need at least 2 dims to swap");
  T &secondToLast = vec[rank - 2];
  T &last = vec[rank - 1];
  std::swap(secondToLast, last);
}

// Decode a full TDM descriptor from all 4 group vectors for 3D-5D tensors
// Returns (base, tensorShape[], tensorStride[], blockShape[])
std::tuple<Value, SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
Expand Down Expand Up @@ -233,7 +227,7 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
unsigned padInterval, unsigned padAmount,
SmallVector<Value> tensorShape,
SmallVector<Value> tensorStride, Value srcPtr,
bool isRowMajor, Attribute sharedEncoding) {
Attribute sharedEncoding) {
size_t numDims = tensorShape.size();
assert(numDims >= 1 && numDims <= 5 && tensorStride.size() == numDims &&
"TDM only supported for 1D-5D tensors.");
Expand All @@ -242,12 +236,6 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
"Block/tensor/stride dim count must all be equal.");
auto b = TritonLLVMOpBuilder(loc, rewriter);

if (!isRowMajor) {
swapTrailingDims(blockShape);
swapTrailingDims(tensorStride);
swapTrailingDims(tensorShape);
}

// Define common values for better readability
Value v16 = b.i32_val(16);
Value v32 = b.i64_val(32);
Expand Down Expand Up @@ -508,29 +496,6 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
return TDMDescriptor{group0, group1, group2, group3};
}

// Builds a layout equal to `layout` except that the meanings of output
// dimensions `dimA` and `dimB` are exchanged, i.e.
// result.apply(x)[dimA] == layout.apply(x)[dimB] and the other way around.
// Both names must be output dimensions of `layout` (checked by assertion).
// The returned layout lists its output dimensions in the same order as
// `layout` does.
static triton::LinearLayout
swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA,
                    StringAttr dimB) {
  assert(layout.hasOutDim(dimA));
  assert(layout.hasOutDim(dimB));

  // Keep every output dimension at its current position (so the bases still
  // line up) but exchange the two names.
  SmallVector<std::pair<StringAttr, int32_t>> relabeledOutDims;
  for (auto [outName, outSize] : layout.getOutDims()) {
    StringAttr newName = outName;
    if (outName == dimA)
      newName = dimB;
    else if (outName == dimB)
      newName = dimA;
    relabeledOutDims.push_back({newName, outSize});
  }

  triton::LinearLayout relabeled(layout.getBases(), relabeledOutDims,
                                 /*requireSurjective=*/false);
  // Relabeling left the two names in each other's slots; transpose so the
  // output dimension names appear in the original order again.
  auto originalOrder = llvm::to_vector(layout.getOutDimNames());
  return relabeled.transposeOuts(originalOrder);
}

// Fill TDM descriptor for regular load/store operations (1D-5D tensors)
void fillTDMDescriptor(
RewriterBase &rewriter, Location loc,
Expand All @@ -542,7 +507,7 @@ void fillTDMDescriptor(
SmallVector<Value> offset, ArrayRef<Value> dstPtrs, Value pred,
Value multicastMask, Value barrierPtr,
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore,
bool isRowMajor, ArrayRef<unsigned> warpsPerCTA) {
ArrayRef<unsigned> warpsPerCTA) {
size_t numDims = offset.size();
assert(numDims >= 1 && numDims <= 5 && "TDM supports 1D to 5D tensors.");
assert(!dstPtrs.empty() && "dstPtrs cannot be empty");
Expand All @@ -553,22 +518,6 @@ void fillTDMDescriptor(
Type globalPtrTy = ptr_ty(ctx, 1);
Type sharedPtrTy = ptr_ty(ctx, 3);

// For col-major tensors the TDM descriptor was created with the trailing two
// dimensions swapped. Swap shapePerCTA and offset to match that hardware
// view, and rename the same two dims in the shared layout to align out dims.
std::optional<triton::LinearLayout> adjustedSharedLayout;
if (!isRowMajor) {
swapTrailingDims(shapePerCTA);
swapTrailingDims(offset);
if (numDims >= 2) {
auto dimN_2 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 2));
auto dimN_1 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 1));
adjustedSharedLayout = swapOutDimSemantics(sharedLayout, dimN_2, dimN_1);
}
}
const auto &tdmViewSharedLayout =
adjustedSharedLayout ? *adjustedSharedLayout : sharedLayout;

// Decode the full TDM descriptor to get all values
auto [srcPtr, tensorShape, tensorStride, decodedBlockShape] =
decodeTDMDescriptorFull(
Expand All @@ -588,7 +537,7 @@ void fillTDMDescriptor(
auto kPartition = str_attr("partition");

auto cgaLayout = triton::gpu::SharedLinearEncodingAttr::get(
ctx, tdmViewSharedLayout, /*layoutAlignment=*/16)
ctx, sharedLayout, /*layoutAlignment=*/16)
.getCGALayout()
.getLinearLayout();

Expand All @@ -615,7 +564,7 @@ void fillTDMDescriptor(
}
srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset);

auto tdmToShared = tdmLayout.invertAndCompose(tdmViewSharedLayout);
auto tdmToShared = tdmLayout.invertAndCompose(sharedLayout);
auto sharedOffsets = applyLinearLayout(
loc, rewriter, tdmToShared,
{{kMessage, b.i32_val(0)}, {kWarp, warpId}, {kBlock, ctaId}});
Expand Down Expand Up @@ -956,7 +905,7 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
SmallVector<Value> globalOffset, ArrayRef<Value> instrDstPtrs,
Value pred, Value multicastMask, Value barrier,
const triton::LinearLayout &instrSharedLayout, Value ctaId,
bool isLoad, bool isRowMajor, ArrayRef<unsigned> warpsPerCTA) {
bool isLoad, ArrayRef<unsigned> warpsPerCTA) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto v8i32Ty = VectorType::get(8, rewriter.getI32Type());
Value group4Zero = LLVM::ZeroOp::create(rewriter, loc, v8i32Ty);
Expand All @@ -972,7 +921,7 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
group0Vec, group1Vec, std::ref(group2Vec),
std::ref(group3Vec), globalOffset, instrDstPtrs, pred,
multicastMask, barrier, instrSharedLayout, ctaId, !isLoad,
isRowMajor, warpsPerCTA);
warpsPerCTA);

auto group0 = packLLVector(loc, group0Vec, rewriter);
auto group1 = packLLVector(loc, group1Vec, rewriter);
Expand All @@ -988,11 +937,11 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc,
auto group0Vec = SmallVector<Value>(desc.begin(), desc.begin() + 4);
auto group1Vec = SmallVector<Value>(desc.begin() + 4, desc.end());

fillTDMDescriptor(
rewriter, loc, typeConverter, elementType, effectiveBlockShape,
numWarps, padInterval, padAmount, group0Vec, group1Vec, std::nullopt,
std::nullopt, globalOffset, instrDstPtrs, pred, multicastMask, barrier,
instrSharedLayout, ctaId, !isLoad, isRowMajor, warpsPerCTA);
fillTDMDescriptor(rewriter, loc, typeConverter, elementType,
effectiveBlockShape, numWarps, padInterval, padAmount,
group0Vec, group1Vec, std::nullopt, std::nullopt,
globalOffset, instrDstPtrs, pred, multicastMask, barrier,
instrSharedLayout, ctaId, !isLoad, warpsPerCTA);

auto group0 = packLLVector(loc, group0Vec, rewriter);
auto group1 = packLLVector(loc, group1Vec, rewriter);
Expand Down Expand Up @@ -1024,7 +973,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
Value pred, Value multicastMask, Type elementType,
Value barrierPtr, bool isLoad,
const triton::LinearLayout &sharedLayout,
Attribute encoding, Value ctaId, bool isRowMajor) {
Attribute encoding, Value ctaId) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
size_t numDims = blockShape.size();
assert(numDims <= 5);
Expand All @@ -1040,8 +989,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
emitTDMIntrinsic(rewriter, loc, typeConverter, desc, numDims, elementType,
to_vector(blockShape), numWarps, padInterval, padAmount,
to_vector(offset), dstPtrs, pred, multicastMask,
barrierPtr, sharedLayout, ctaId, isLoad, isRowMajor,
warpsPerCTA);
barrierPtr, sharedLayout, ctaId, isLoad, warpsPerCTA);
return;
}

Expand Down Expand Up @@ -1105,7 +1053,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
emitTDMIntrinsic(rewriter, loc, typeConverter, desc, numDims, elementType,
effectiveBlockShape, numWarps, padInterval, padAmount,
globalOffset, instrDstPtrs, pred, multicastMask, barrier,
sliceLayout, ctaId, isLoad, isRowMajor, warpsPerCTA);
sliceLayout, ctaId, isLoad, warpsPerCTA);
}
}

Expand Down
6 changes: 3 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc,
unsigned padInterval, unsigned padAmount,
SmallVector<Value> tensorShape,
SmallVector<Value> tensorStride, Value srcPtr,
bool isRowMajor, Attribute sharedEncoding);
Attribute sharedEncoding);

// Update the global memory address with offset, and fill the shared memory
// address and pred in a given TDM descriptor for regular load/store (1D-5D).
Expand All @@ -53,7 +53,7 @@ void fillTDMDescriptor(
SmallVector<Value> offset, ArrayRef<Value> dstPtrs, Value pred,
Value multicastMask, Value barrierPtr,
const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore,
bool isRowMajor, ArrayRef<unsigned> warpsPerCTA);
ArrayRef<unsigned> warpsPerCTA);

// Fill TDM descriptor for gather/scatter operations (2D only).
// Gather reads from non-contiguous rows in global memory to LDS.
Expand Down Expand Up @@ -94,7 +94,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc,
Value pred, Value multicastMask, Type elementType,
Value barrierPtr, bool isLoad,
const triton::LinearLayout &sharedLayout,
Attribute encoding, Value ctaId, bool isRowMajor);
Attribute encoding, Value ctaId);

// Returns (warpsPerCTA, numTDMInstructions) for a given shared encoding.
// For PartitionedSharedEncodingAttr, computes a partition-aligned warp
Expand Down
Loading
Loading