diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 061647d437e9..ac4b43e3fbcf 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -516,6 +516,11 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);
 
+// Emits the required padding in elements for the given shared memory offset
+Value emitPadding(Location loc, RewriterBase &rewriter,
+                  triton::gpu::PaddedSharedEncodingAttr layout,
+                  Value smemOffset);
+
 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.
 //
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 058a6ceaa4f3..2218a1e5d477 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -398,6 +398,21 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   return ret;
 }
 
+Value emitPadding(Location loc, RewriterBase &rewriter,
+                  triton::gpu::PaddedSharedEncodingAttr layout,
+                  Value smemOffset) {
+  TritonLLVMOpBuilder b(loc, rewriter);
+
+  Value padOffset = b.i32_val(0);
+  for (auto [interval, padding] :
+       llvm::zip_equal(layout.getIntervals(), layout.getPaddings())) {
+    Value iVal = b.i32_val(llvm::Log2_32(interval));
+    Value pVal = b.i32_val(llvm::Log2_32(padding));
+    padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
+  }
+  return padOffset;
+}
+
 namespace {
 
 Value getSmemVecAddr(const LinearLayout &regLayout,
@@ -488,13 +503,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
     if (auto paddedLayout =
             dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
       // Apply the offset needed for padding.
-      Value padOffset = b.i32_val(0);
-      for (auto [interval, padding] : llvm::zip_equal(
-               paddedLayout.getIntervals(), paddedLayout.getPaddings())) {
-        Value iVal = b.i32_val(llvm::Log2_32(interval));
-        Value pVal = b.i32_val(llvm::Log2_32(padding));
-        padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
-      }
+      Value padOffset = emitPadding(loc, rewriter, paddedLayout, smemOffset);
       smemOffset = b.add(smemOffset, padOffset);
     }
   } else { // Case 2 -> rank-reduced swizzling
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 0d4034f319ed..8ff9be6d6c98 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -513,6 +513,13 @@ struct MemDescSubviewOpConversion
                    .second;
     }
 
+    if (auto paddedLayout = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+            srcTy.getEncoding())) {
+      // Apply padding based on the computed offset
+      Value padOffset = emitPadding(loc, rewriter, paddedLayout, offset);
+      offset = b.add(offset, padOffset);
+    }
+
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir
index 5ee33efdabf2..98408d0969a8 100644
--- a/test/Conversion/amd/tritongpu_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -413,6 +413,40 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
 
 // -----
 
+// CHECK-LABEL: padded_shared_layout_subview
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @padded_shared_layout_subview(%arg0: !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    // Skip two constants from the stride calculation
+
+    // CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
+    // CHECK-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32)
+    // CHECK-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32)
+    // CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32)
+    // CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32)
+
+    // CHECK: %[[SHR0:.+]] = llvm.ashr %[[XOR:.+]], %[[CST7]] : i32
+    // CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32
+    // CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32
+    // CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[XOR]], %[[CST8]] : i32
+    // CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST3]] : i32
+    // CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32
+    // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[XOR]], %[[ADD1]] : i32
+    // CHECK-NEXT: llvm.getelementptr %{{.+}}[%[[ADD2]]]
+
+    %1 = ttg.memdesc_subview %arg0[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
+    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 // GFX950-LABEL: reduce_32x32
 // GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {