Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,11 @@ SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset);

// Emits IR computing the number of padding elements to add to `smemOffset`
// for a padded shared-memory layout. Each (interval, padding) pair of the
// layout contributes (smemOffset >> log2(interval)) << log2(padding);
// both interval and padding are assumed to be powers of two.
Value emitPadding(Location loc, RewriterBase &rewriter,
                  triton::gpu::PaddedSharedEncodingAttr layout,
                  Value smemOffset);

// Emits IR to load data from shared memory into registers, or to store data
// from registers into shared memory.
//
Expand Down
23 changes: 16 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,21 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
return ret;
}

Value emitPadding(Location loc, RewriterBase &rewriter,
                  triton::gpu::PaddedSharedEncodingAttr layout,
                  Value smemOffset) {
  TritonLLVMOpBuilder b(loc, rewriter);

  // Sum one shift-based term per (interval, padding) pair of the layout:
  //   (smemOffset >> log2(interval)) << log2(padding)
  // Intervals and paddings are assumed to be powers of two, so the
  // division/multiplication reduce to shifts.
  Value totalPad = b.i32_val(0);
  auto intervals = layout.getIntervals();
  auto paddings = layout.getPaddings();
  for (auto [interval, padAmount] : llvm::zip_equal(intervals, paddings)) {
    Value intervalShift = b.i32_val(llvm::Log2_32(interval));
    Value paddingShift = b.i32_val(llvm::Log2_32(padAmount));
    Value term = b.shl(b.ashr(smemOffset, intervalShift), paddingShift);
    totalPad = b.add(totalPad, term);
  }
  return totalPad;
}

namespace {

Value getSmemVecAddr(const LinearLayout &regLayout,
Expand Down Expand Up @@ -488,13 +503,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
if (auto paddedLayout =
dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
// Apply the offset needed for padding.
Value padOffset = b.i32_val(0);
for (auto [interval, padding] : llvm::zip_equal(
paddedLayout.getIntervals(), paddedLayout.getPaddings())) {
Value iVal = b.i32_val(llvm::Log2_32(interval));
Value pVal = b.i32_val(llvm::Log2_32(padding));
padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
}
Value padOffset = emitPadding(loc, rewriter, paddedLayout, smemOffset);
smemOffset = b.add(smemOffset, padOffset);
}
} else { // Case 2 -> rank-reduced swizzling
Expand Down
7 changes: 7 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,13 @@ struct MemDescSubviewOpConversion
.second;
}

if (auto paddedLayout = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
srcTy.getEncoding())) {
// Apply padding based on the computed offset
Value padOffset = emitPadding(loc, rewriter, paddedLayout, offset);
offset = b.add(offset, padOffset);
}

auto base = smemObj.getBase();
auto elemPtrTy = base.getType();
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
Expand Down
34 changes: 34 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,40 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

// -----

// Checks that memdesc_subview on a #ttg.padded_shared layout emits the
// padding-offset computation — one ((offset >> log2(interval)) << log2(pad))
// term per interval:padding pair — summed into the offset fed to the final
// getelementptr. Layout [128:+4, 256:+8] yields shifts >>7/<<2 and >>8/<<3.
// CHECK-LABEL: padded_shared_layout_subview
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @padded_shared_layout_subview(%arg0: !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable>) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    // The CHECK-DAG constants below intentionally skip the constants emitted
    // for the subview's stride calculation; only the padding shift amounts
    // (2, 7, 8, 3) and the zero accumulator seed are pinned here.

    // CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
    // CHECK-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32)
    // CHECK-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32)
    // CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32)
    // CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32)

    // CHECK: %[[SHR0:.+]] = llvm.ashr %[[XOR:.+]], %[[CST7]] : i32
    // CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32
    // CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32
    // CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[XOR]], %[[CST8]] : i32
    // CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST3]] : i32
    // CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32
    // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[XOR]], %[[ADD1]] : i32
    // CHECK-NEXT: llvm.getelementptr %{{.+}}[%[[ADD2]]]

    %1 = ttg.memdesc_subview %arg0[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
    %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

// GFX950-LABEL: reduce_32x32
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
Expand Down
Loading