Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Type elemTy,
RewriterBase &rewriter);

// Build a vector of shared-memory base pointers for dynamic partition
// indexing. Element `i` of the returned vector value is `smemBases[i]`, and
// the vector's element type is the type of `smemBases[0]`.
// Asserts that at least two bases are supplied.
Value buildBasePtrVector(Location loc, RewriterBase &rewriter,
ArrayRef<Value> smemBases);

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Expand Down
30 changes: 15 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,20 +616,6 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
warpId, rewriter, targetInfo, maybeMaxVecElems, emitLdSt);
}

// Packs several base pointers into a single vector value so that one of them
// can later be selected through dynamic indexing. Requires more than one base.
static Value buildBasePtrVector(Location loc, RewriterBase &rewriter,
                                ArrayRef<Value> smemBases) {
  assert(smemBases.size() > 1 && "Need multiple bases to build a vector");
  TritonLLVMOpBuilder b(loc, rewriter);
  const auto numBases = static_cast<int64_t>(smemBases.size());
  // All elements share the type of the first base.
  auto basesVecTy = VectorType::get({numBases}, smemBases.front().getType());
  // Start from an undef vector and write base `idx` into element `idx`.
  Value packed = b.undef(basesVecTy);
  for (size_t idx = 0; idx != smemBases.size(); ++idx)
    packed = b.insert_element(packed, smemBases[idx], b.i32_val(idx));
  return packed;
}

SmallVector<Value>
lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
ArrayRef<Value> valsArray, // Input for store, output for load
Expand Down Expand Up @@ -669,7 +655,7 @@ lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
Value basesVec;
if (isPartitioned) {
partitionLayout = cvt.sublayout(inDimNames, {kPartition});
basesVec = buildBasePtrVector(loc, rewriter, smemBases);
basesVec = LLVM::buildBasePtrVector(loc, rewriter, smemBases);
}

// Strip kPartition output for vectorization analysis.
Expand Down Expand Up @@ -1285,6 +1271,20 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
/*offsets=*/{elems.begin() + numBases, elems.end()}};
}

// Collects multiple base pointers into one vector value for dynamic indexing.
// The i-th element of the result is `smemBases[i]`; at least two bases are
// required.
Value buildBasePtrVector(Location loc, RewriterBase &rewriter,
                         ArrayRef<Value> smemBases) {
  assert(smemBases.size() > 1 && "Need multiple bases to build a vector");
  TritonLLVMOpBuilder b(loc, rewriter);
  // The element type is taken from the first base.
  auto elemPtrTy = smemBases.front().getType();
  auto resultTy =
      VectorType::get({static_cast<int64_t>(smemBases.size())}, elemPtrTy);
  Value result = b.undef(resultTy);
  size_t slot = 0;
  for (Value base : smemBases) {
    result = b.insert_element(result, base, b.i32_val(slot));
    ++slot;
  }
  return result;
}

Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) {
// See NOTE: [Additional Function Arguments]
if (!isKernel(funcOp)) {
Expand Down
36 changes: 36 additions & 0 deletions test/Conversion/amd/ds_transpose_gfx1250.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}>
#smem = #ttg.shared_memory

// Partitioned shared: inner padded tiles.
#inner_ps_dim0 = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [64, 64]}>
#partitioned_dim0 = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 0, partitionLayout = #inner_ps_dim0}>
#inner_ps_dim1 = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 32]}>
#partitioned_dim1 = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 1, partitionLayout = #inner_ps_dim1}>

#linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
#linear_ds_tr_tile_invalid = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [4, 0], [2, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>

Expand Down Expand Up @@ -121,4 +127,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
tt.return
}

// WMMA dot path from partitioned shared (partitionDim = 0): multiple LDS bases + ds transpose loads.
// The source memdesc uses #partitioned_dim0, which declares numPartitions = 2,
// so the expected output builds a 2-element vector of LDS pointers and
// extracts one of them with an i32 index.
// CHECK-LABEL: partitioned_shared_ds_transpose_dot_op0_dim0
tt.func @partitioned_shared_ds_transpose_dot_op0_dim0(%arg0: !ttg.memdesc<128x64xf16, #partitioned_dim0, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: vector<2x!llvm.ptr<3>>
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>>
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>>
// Exactly 16 transpose-load intrinsic calls are expected; the negative
// directive after the count rejects any further one.
// CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16>
// CHECK-NOT: ds.load.tr16.b128
%1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #partitioned_dim0, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>

// Store the loaded tensor through a splatted pointer so the load has a use.
%ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
tt.return
}

// Same tile shape with partitionDim = 1 (column partitions).
// The source memdesc uses #partitioned_dim1, which declares numPartitions = 2;
// the expected output has the same shape as in the partitionDim = 0 test:
// a 2-element LDS pointer vector, an insert, an extract, then the loads.
// CHECK-LABEL: partitioned_shared_ds_transpose_dot_op0_dim1
tt.func @partitioned_shared_ds_transpose_dot_op0_dim1(%arg0: !ttg.memdesc<128x64xf16, #partitioned_dim1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: vector<2x!llvm.ptr<3>>
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>>
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>>
// Exactly 16 transpose-load intrinsic calls are expected; the negative
// directive after the count rejects any further one.
// CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16>
// CHECK-NOT: ds.load.tr16.b128
%1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #partitioned_dim1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>

// Store the loaded tensor through a splatted pointer so the load has a use.
%ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
tt.return
}
}
97 changes: 68 additions & 29 deletions third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ class TransLocalLoadOpConversion
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();

// Partitioned tensors have multiple bases; fall back to generic lowering.
if (isa<triton::gpu::PartitionedSharedEncodingAttr>(srcTy.getEncoding())) {
return failure();
}
auto typeConverter = this->getTypeConverter();
auto llvmElemTy = typeConverter->convertType(dstTy.getElementType());
unsigned bitWidth = llvmElemTy.getIntOrFloatBitWidth();
Expand All @@ -48,18 +44,17 @@ class TransLocalLoadOpConversion
if (!ldsParams)
return failure();

auto paddedEnc =
dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(srcTy.getEncoding());
LinearLayout cvtDstLL = LinearLayout::empty();
if (paddedEnc) {
const auto &sharedLL = paddedEnc.getLinearComponent();
cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL);
if (paddedEnc.getMinInterval() < ldsParams->tileSize)
LinearLayout sharedLL;
if (triton::gpu::isPaddedEncoding(srcTy.getEncoding())) {
sharedLL = triton::gpu::paddedLinearLayout(srcTy);
if (triton::gpu::getMinInterval(srcTy.getEncoding()) <
ldsParams->tileSize)
return failure();
} else {
auto sharedLL = triton::gpu::toLinearLayout(srcTy);
cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL);
sharedLL = triton::gpu::toLinearLayout(srcTy);
}
LinearLayout cvtDstLL =
triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL);
auto kBlock = StringAttr::get(ctx, "block");
auto maybeSublayout = cvtDstLL.quotient({kBlock});
if (!maybeSublayout) {
Expand All @@ -68,7 +63,7 @@ class TransLocalLoadOpConversion
cvtDstLL = maybeSublayout.value();
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto smemBase = smemObj.getBase();
SmallVector<Value> smemBases = llvm::to_vector(smemObj.getBases());
auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy);
auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy);
auto paddingShifts = getPaddedSharedShifts(srcTy.getEncoding(),
Expand All @@ -77,7 +72,7 @@ class TransLocalLoadOpConversion

llvm::SmallVector<Value> values;
auto result = lowerDsReadTr(
op, ldsParams.value(), loc, cvtDstLL, values, smemBase, affineOffset,
op, ldsParams.value(), loc, cvtDstLL, values, smemBases, affineOffset,
maskSpanAffineOffset, paddingShifts, llvmElemTy, rewriter, targetInfo);
if (failed(result)) {
return failure();
Expand All @@ -92,15 +87,15 @@ class TransLocalLoadOpConversion
}

private:
LogicalResult lowerDsReadTr(
triton::gpu::LocalLoadOp op,
::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, Location loc,
LinearLayout cvt,
SmallVector<Value> &vals, // Input for stmatrix, output for ldmatrix
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this comment on purpose, as it seems to be an accidental copy/paste from the NV path.

Value smemBase, Value affineOffset, uint64_t maskSpanAffineOffset,
ArrayRef<std::pair<unsigned, unsigned>> paddingShifts, Type llvmElemTy,
ConversionPatternRewriter &rewriter,
const ::triton::AMD::TargetInfo &targetInfo) const {
LogicalResult
lowerDsReadTr(triton::gpu::LocalLoadOp op,
::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams,
Location loc, LinearLayout cvt, SmallVector<Value> &vals,
ArrayRef<Value> smemBases, Value affineOffset,
uint64_t maskSpanAffineOffset,
ArrayRef<std::pair<unsigned, unsigned>> paddingShifts,
Type llvmElemTy, ConversionPatternRewriter &rewriter,
const ::triton::AMD::TargetInfo &targetInfo) const {

auto b = TritonLLVMOpBuilder(loc, rewriter);
auto *ctx = rewriter.getContext();
Expand All @@ -111,9 +106,30 @@ class TransLocalLoadOpConversion
auto kWarp = S("warp");
auto kOffset = S("offset");
auto kAddr = S("addr");
auto kPartition = S("partition");
auto smemPtrTy = ptr_ty(ctx, 3);
auto bitWidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);

assert(!smemBases.empty() && "expected at least one smem base");
LinearLayout cvtLayout = cvt;
LinearLayout partitionLayout;
Value basesVec;
const bool isPartitioned = smemBases.size() > 1;

if (isPartitioned) {
assert(cvtLayout.hasOutDim(kPartition) &&
cvtLayout.getOutDimSize(kPartition) ==
static_cast<int32_t>(smemBases.size()) &&
"smemBases size must match partition dimension size");
auto inDimNames = llvm::to_vector(cvtLayout.getInDimNames());
partitionLayout = cvtLayout.sublayout(inDimNames, {kPartition});
SmallVector<StringAttr> outDims =
llvm::to_vector(cvtLayout.getOutDimNames());
llvm::erase(outDims, kPartition);
cvtLayout = cvtLayout.sublayout(inDimNames, outDims);
basesVec = LLVM::buildBasePtrVector(loc, rewriter, smemBases);
}

// Map onto offsets (contiguous part) and addr (non-contiguous part)
LinearLayout fullTile;
// Contiguous tile
Expand Down Expand Up @@ -170,11 +186,11 @@ class TransLocalLoadOpConversion
// Add warp dimension so we can invert and compose with reps later
fullTile *= LinearLayout::identity1D(1, kWarp, kAddr);

if (cvt.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) {
if (cvtLayout.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) {
return failure();
}

auto maybeQuot = divideLeft(cvt, tile);
auto maybeQuot = divideLeft(cvtLayout, tile);
if (!maybeQuot.has_value()) {
return failure();
}
Expand Down Expand Up @@ -209,6 +225,22 @@ class TransLocalLoadOpConversion
auto [nAdditive, permStrides] =
actionAdditiveStrides(reps, addrLayout, maskSpanAffineOffset);
reps = permStrides.apply(reps);
if (isPartitioned) {
partitionLayout = permStrides.apply(partitionLayout);

// One ds_read_tr* instruction produces `fullTile.getInDimSize(kReg)`
// consecutive register values from a single LDS base pointer. We only
// select a partition once per instruction, so all of those register
// positions must map to the same partition. For a LinearLayout that holds
// iff the low log2(elemsPerInstr) register bases contribute 0 to
// kPartition. Bail out if not, so a generic lowering can take over.
const unsigned numInstrRegBits =
llvm::Log2_32(fullTile.getInDimSize(kReg));
for (unsigned pos = 0; pos < numInstrRegBits; ++pos) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just checking partitionLayout.sublayoutIsZero({kReg}, {kPartition})?

Copy link
Copy Markdown
Contributor Author

@plognjen plognjen Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because that would check all of the register bases of the partition layout, which would include repetitions. The point is that we only want to check the first numInstrRegBits bases, which correspond to the registers of the fullTile layout, i.e. a single instruction. It's fine for different repetitions (instructions) to be in different partitions, but we need to detect the case where registers from a single instruction fall into different partitions.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, if you want to do this without looking at the bases, you can reshape kReg into two dimensions (one of dim numInstrRegBits and another for the rest) and check sublayoutIsZero there. But tbh I wouldn't rewrite it; the current solution seems alright.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah makes sense. Thanks for the explanation.

if (partitionLayout.getBasis(kReg, pos, kPartition) != 0)
return failure();
}
Comment on lines +231 to +242
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably add the same check in the regular lowering path as well.

}

// Perform computation in bytes, LLVM optimises this better
assert(bitWidth >= 8);
Expand Down Expand Up @@ -282,7 +314,7 @@ class TransLocalLoadOpConversion
auto elemsPerInstr = fullTile.getInDimSize(kReg);
auto elemsPerVec = ldsParams.instBitWidth / bitWidth;
auto vecTy = vec_ty(llvmElemTy, elemsPerVec);
for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) {
for (int i = 0; i < cvtLayout.getInDimSize(kReg); i += nAdditive) {
auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second;
auto regIdxI8 = regIdx * (bitWidth / 8);
Value offset = b.xor_(regBase, b.i32_val(regIdxI8));
Expand All @@ -297,14 +329,21 @@ class TransLocalLoadOpConversion
// separately.
regIdxAddI8 = applyPadding(regIdxAddI8, paddingShifts);
Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8));
auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset,
Value smemBaseVal = smemBases[0];
if (isPartitioned) {
auto partOut = applyLinearLayout(
loc, rewriter, partitionLayout,
{{kReg, b.i32_val(i + i2)}, {kLane, laneId}, {kWarp, warpId}});
smemBaseVal = b.extract_element(basesVec, partOut[0].second);
}
auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBaseVal, innerOffset,
LLVM::GEPNoWrapFlags::inbounds);
llvm::append_range(vals,
lowerInst(rewriter, loc, vecAddr, i + i2, vecTy));
}
}
// apply all the inverse permutations in the reverse order
assert(vals.size() == cvt.getInDimSize(kReg));
assert(vals.size() == cvtLayout.getInDimSize(kReg));
vals = permStrides.inverse().apply(vals);

return success();
Expand Down
Loading
Loading