From a539692828fa223266067b3619bec685fb430ea2 Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic Date: Fri, 17 Apr 2026 15:10:28 +0000 Subject: [PATCH 1/2] Enable LDS transpose path for tensors with PartitionedSharedEncodingAttr --- .../Conversion/TritonGPUToLLVM/Utility.h | 5 + lib/Conversion/TritonGPUToLLVM/Utility.cpp | 30 +++--- test/Conversion/amd/ds_transpose_gfx1250.mlir | 36 +++++++ .../lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 98 +++++++++++++------ .../amd/python/test/test_gluon_gfx1250.py | 95 ++++++++++++------ 5 files changed, 190 insertions(+), 74 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 1ca12f323b92..067e3c0463aa 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -456,6 +456,11 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Type elemTy, RewriterBase &rewriter); +// Build a vector of shared-memory base pointers for dynamic partition +// indexing (expects at least two bases). +Value buildBasePtrVector(Location loc, RewriterBase &rewriter, + ArrayRef smemBases); + // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. SmallVector delinearize(RewriterBase &rewriter, Location loc, diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8975bba0f894..77d8d16c4686 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -616,20 +616,6 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt, warpId, rewriter, targetInfo, maybeMaxVecElems, emitLdSt); } -// Build a vector containing multiple base pointers for dynamic indexing. -static Value buildBasePtrVector(Location loc, RewriterBase &rewriter, - ArrayRef smemBases) { - assert(smemBases.size() > 1 && "Need multiple bases to build a vector"); - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto ptrTy = smemBases[0].getType(); - auto vecTy = VectorType::get({static_cast(smemBases.size())}, ptrTy); - Value basesVec = b.undef(vecTy); - for (size_t i = 0; i < smemBases.size(); ++i) { - basesVec = b.insert_element(basesVec, smemBases[i], b.i32_val(i)); - } - return basesVec; -} - SmallVector lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt, ArrayRef valsArray, // Input for store, output for load @@ -669,7 +655,7 @@ lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt, Value basesVec; if (isPartitioned) { partitionLayout = cvt.sublayout(inDimNames, {kPartition}); - basesVec = buildBasePtrVector(loc, rewriter, smemBases); + basesVec = LLVM::buildBasePtrVector(loc, rewriter, smemBases); } // Strip kPartition output for vectorization analysis. @@ -1285,6 +1271,20 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, /*offsets=*/{elems.begin() + numBases, elems.end()}}; } +// Build a vector containing multiple base pointers for dynamic indexing. +Value buildBasePtrVector(Location loc, RewriterBase &rewriter, + ArrayRef smemBases) { + assert(smemBases.size() > 1 && "Need multiple bases to build a vector"); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ptrTy = smemBases[0].getType(); + auto vecTy = VectorType::get({static_cast(smemBases.size())}, ptrTy); + Value basesVec = b.undef(vecTy); + for (size_t i = 0; i < smemBases.size(); ++i) { + basesVec = b.insert_element(basesVec, smemBases[i], b.i32_val(i)); + } + return basesVec; +} + Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { // See NOTE: [Additional Function Arguments] if (!isKernel(funcOp)) { diff --git a/test/Conversion/amd/ds_transpose_gfx1250.mlir b/test/Conversion/amd/ds_transpose_gfx1250.mlir index 98b737fc9b5e..ef7dc8054726 100644 --- a/test/Conversion/amd/ds_transpose_gfx1250.mlir +++ b/test/Conversion/amd/ds_transpose_gfx1250.mlir @@ -12,6 +12,12 @@ #padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}> #smem = #ttg.shared_memory +// Partitioned shared: inner padded tiles. +#inner_ps_dim0 = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [64, 64]}> +#partitioned_dim0 = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 0, partitionLayout = #inner_ps_dim0}> +#inner_ps_dim1 = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 32]}> +#partitioned_dim1 = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 1, partitionLayout = #inner_ps_dim1}> + #linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> #linear_ds_tr_tile_invalid = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [4, 0], [2, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> @@ -121,4 +127,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.store %ptr1, %1 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> tt.return } + + // WMMA dot path from partitioned shared (partitionDim = 0): multiple LDS bases + ds transpose loads. + // CHECK-LABEL: partitioned_shared_ds_transpose_dot_op0_dim0 + tt.func @partitioned_shared_ds_transpose_dot_op0_dim0(%arg0: !ttg.memdesc<128x64xf16, #partitioned_dim0, #smem, mutable>, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + // CHECK: vector<2x!llvm.ptr<3>> + // CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>> + // CHECK: llvm.extractelement %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16> + // CHECK-NOT: ds.load.tr16.b128 + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #partitioned_dim0, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + + %ptr1 = tt.splat %arg2 : !tt.ptr -> tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + tt.store %ptr1, %1 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + tt.return + } + + // Same tile shape with partitionDim = 1 (column partitions). + // CHECK-LABEL: partitioned_shared_ds_transpose_dot_op0_dim1 + tt.func @partitioned_shared_ds_transpose_dot_op0_dim1(%arg0: !ttg.memdesc<128x64xf16, #partitioned_dim1, #smem, mutable>, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + // CHECK: vector<2x!llvm.ptr<3>> + // CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>> + // CHECK: llvm.extractelement %{{.*}}[%{{.*}} : i32] : vector<2x!llvm.ptr<3>> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16> + // CHECK-NOT: ds.load.tr16.b128 + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #partitioned_dim1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + + %ptr1 = tt.splat %arg2 : !tt.ptr -> tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + tt.store %ptr1, %1 : tensor<128x64x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>> + tt.return + } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index 323900cf9df6..05c4fba3a4d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -31,10 +31,6 @@ class TransLocalLoadOpConversion MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - // Partitioned tensors have multiple bases; fall back to generic lowering. - if (isa(srcTy.getEncoding())) { - return failure(); - } auto typeConverter = this->getTypeConverter(); auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); unsigned bitWidth = llvmElemTy.getIntOrFloatBitWidth(); @@ -48,18 +44,17 @@ class TransLocalLoadOpConversion if (!ldsParams) return failure(); - auto paddedEnc = - dyn_cast(srcTy.getEncoding()); - LinearLayout cvtDstLL = LinearLayout::empty(); - if (paddedEnc) { - const auto &sharedLL = paddedEnc.getLinearComponent(); - cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); - if (paddedEnc.getMinInterval() < ldsParams->tileSize) + LinearLayout sharedLL; + if (triton::gpu::isPaddedEncoding(srcTy.getEncoding())) { + sharedLL = triton::gpu::paddedLinearLayout(srcTy); + if (triton::gpu::getMinInterval(srcTy.getEncoding()) < + ldsParams->tileSize) return failure(); } else { - auto sharedLL = triton::gpu::toLinearLayout(srcTy); - cvtDstLL = triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); + sharedLL = triton::gpu::toLinearLayout(srcTy); } + LinearLayout cvtDstLL = + triton::gpu::toLinearLayout(dstTy).invertAndCompose(sharedLL); auto kBlock = StringAttr::get(ctx, "block"); auto maybeSublayout = cvtDstLL.quotient({kBlock}); if (!maybeSublayout) { @@ -68,7 +63,8 @@ class TransLocalLoadOpConversion cvtDstLL = maybeSublayout.value(); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - auto smemBase = smemObj.getBase(); + SmallVector smemBases(smemObj.getBases().begin(), + smemObj.getBases().end()); auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy); auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy); auto paddingShifts = getPaddedSharedShifts(srcTy.getEncoding(), @@ -77,7 +73,7 @@ class TransLocalLoadOpConversion llvm::SmallVector values; auto result = lowerDsReadTr( - op, ldsParams.value(), loc, cvtDstLL, values, smemBase, affineOffset, + op, ldsParams.value(), loc, cvtDstLL, values, smemBases, affineOffset, maskSpanAffineOffset, paddingShifts, llvmElemTy, rewriter, targetInfo); if (failed(result)) { return failure(); @@ -92,15 +88,15 @@ class TransLocalLoadOpConversion } private: - LogicalResult lowerDsReadTr( - triton::gpu::LocalLoadOp op, - ::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, Location loc, - LinearLayout cvt, - SmallVector &vals, // Input for stmatrix, output for ldmatrix - Value smemBase, Value affineOffset, uint64_t maskSpanAffineOffset, - ArrayRef> paddingShifts, Type llvmElemTy, - ConversionPatternRewriter &rewriter, - const ::triton::AMD::TargetInfo &targetInfo) const { + LogicalResult + lowerDsReadTr(triton::gpu::LocalLoadOp op, + ::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, + Location loc, LinearLayout cvt, SmallVector &vals, + ArrayRef smemBases, Value affineOffset, + uint64_t maskSpanAffineOffset, + ArrayRef> paddingShifts, + Type llvmElemTy, ConversionPatternRewriter &rewriter, + const ::triton::AMD::TargetInfo &targetInfo) const { auto b = TritonLLVMOpBuilder(loc, rewriter); auto *ctx = rewriter.getContext(); @@ -111,9 +107,30 @@ class TransLocalLoadOpConversion auto kWarp = S("warp"); auto kOffset = S("offset"); auto kAddr = S("addr"); + auto kPartition = S("partition"); auto smemPtrTy = ptr_ty(ctx, 3); auto bitWidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); + assert(!smemBases.empty() && "expected at least one smem base"); + LinearLayout cvtLayout = cvt; + LinearLayout partitionLayout; + Value basesVec; + const bool isPartitioned = smemBases.size() > 1; + + if (isPartitioned) { + assert(cvtLayout.hasOutDim(kPartition) && + cvtLayout.getOutDimSize(kPartition) == + static_cast(smemBases.size()) && + "smemBases size must match partition dimension size"); + auto inDimNames = llvm::to_vector(cvtLayout.getInDimNames()); + partitionLayout = cvtLayout.sublayout(inDimNames, {kPartition}); + SmallVector outDims = + llvm::to_vector(cvtLayout.getOutDimNames()); + llvm::erase(outDims, kPartition); + cvtLayout = cvtLayout.sublayout(inDimNames, outDims); + basesVec = LLVM::buildBasePtrVector(loc, rewriter, smemBases); + } + // Map onto offsets (contiguous part) and addr (non-contiguous part) LinearLayout fullTile; // Contiguous tile @@ -170,11 +187,11 @@ class TransLocalLoadOpConversion // Add warp dimension so we can invert and compose with reps later fullTile *= LinearLayout::identity1D(1, kWarp, kAddr); - if (cvt.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) { + if (cvtLayout.getInDimSize(kReg) < fullTile.getInDimSize(kReg)) { return failure(); } - auto maybeQuot = divideLeft(cvt, tile); + auto maybeQuot = divideLeft(cvtLayout, tile); if (!maybeQuot.has_value()) { return failure(); } @@ -209,6 +226,22 @@ class TransLocalLoadOpConversion auto [nAdditive, permStrides] = actionAdditiveStrides(reps, addrLayout, maskSpanAffineOffset); reps = permStrides.apply(reps); + if (isPartitioned) { + partitionLayout = permStrides.apply(partitionLayout); + + // One ds_read_tr* instruction produces `fullTile.getInDimSize(kReg)` + // consecutive register values from a single LDS base pointer. We only + // select a partition once per instruction, so all of those register + // positions must map to the same partition. For a LinearLayout that holds + // iff the low log2(elemsPerInstr) register bases contribute 0 to + // kPartition. Bail out if not, so a generic lowering can take over. + const unsigned numInstrRegBits = + llvm::Log2_32(fullTile.getInDimSize(kReg)); + for (unsigned pos = 0; pos < numInstrRegBits; ++pos) { + if (partitionLayout.getBasis(kReg, pos, kPartition) != 0) + return failure(); + } + } // Perform computation in bytes, LLVM optimises this better assert(bitWidth >= 8); @@ -282,7 +315,7 @@ class TransLocalLoadOpConversion auto elemsPerInstr = fullTile.getInDimSize(kReg); auto elemsPerVec = ldsParams.instBitWidth / bitWidth; auto vecTy = vec_ty(llvmElemTy, elemsPerVec); - for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) { + for (int i = 0; i < cvtLayout.getInDimSize(kReg); i += nAdditive) { auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; auto regIdxI8 = regIdx * (bitWidth / 8); Value offset = b.xor_(regBase, b.i32_val(regIdxI8)); @@ -297,14 +330,21 @@ class TransLocalLoadOpConversion // separately. regIdxAddI8 = applyPadding(regIdxAddI8, paddingShifts); Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8)); - auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset, + Value smemBaseVal = smemBases[0]; + if (isPartitioned) { + auto partOut = applyLinearLayout( + loc, rewriter, partitionLayout, + {{kReg, b.i32_val(i + i2)}, {kLane, laneId}, {kWarp, warpId}}); + smemBaseVal = b.extract_element(basesVec, partOut[0].second); + } + auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBaseVal, innerOffset, LLVM::GEPNoWrapFlags::inbounds); llvm::append_range(vals, lowerInst(rewriter, loc, vecAddr, i + i2, vecTy)); } } // apply all the inverse permutations in the reverse order - assert(vals.size() == cvt.getInDimSize(kReg)); + assert(vals.size() == cvtLayout.getInDimSize(kReg)); vals = permStrides.inverse().apply(vals); return success(); diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index e008f14f5b62..83ce27101825 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -1435,7 +1435,7 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, # NUM_PARTITIONS: ttgl.constexpr, NUM_GROUPS: ttgl.constexpr, # PARTITION_DIM: ttgl.constexpr): - """TDM load with PartitionedSharedLayout, then store via registers.""" + """TDM load with PartitionedSharedLayout; LDS read uses WMMA dot layout (transpose path).""" num_warps: ttgl.constexpr = ttgl.num_warps() if PARTITION_DIM == 0: @@ -1450,6 +1450,8 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # smem_layout: ttgl.constexpr = PartitionedSharedLayout(NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM, inner_layout) block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [num_warps, 1], [1, 0]) + WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32]) + OPERAND_LAYOUT: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) pid_m = ttgl.program_id(axis=0) pid_n = ttgl.program_id(axis=1) @@ -1463,7 +1465,8 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # ttgl.amd.gfx1250.tdm.async_load(a_desc, [idx_m, idx_n], a_buffer) ttgl.amd.gfx1250.tdm.async_wait(0) - a = a_buffer.load(layout=block_layout) + a_dot = a_buffer.load(layout=OPERAND_LAYOUT) + a = ttgl.convert_layout(a_dot, block_layout) offs_bm = idx_m + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, block_layout)) offs_bn = idx_n + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, block_layout)) @@ -1472,36 +1475,37 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # ttgl.store(b_ptr + offs_b, a, mask=b_mask) +_PARTITIONED_TDM_PARAMS = [ + # # --- partitionDim = 0 (rows) --- + # 2 partitions x 1 group = 2 pieces along dim0 -> 4 warps covers it + (64, 32, 2, 1, 0), + # 2 partitions x 2 groups = 4 pieces along dim0 -> 4 warps covers it + (64, 32, 2, 2, 0), + # 2 partitions x 4 groups = 8 pieces along dim0 -> 4 warps < 8 -> 2 instructions + (64, 32, 2, 4, 0), + # 2 partitions x 8 groups = 16 pieces along dim0 -> 4 warps < 16 -> 4 instructions + (128, 64, 2, 2, 0), + # --- partitionDim = 1 (cols) --- + # 2 partitions x 1 group = 2 pieces along dim1 + (32, 64, 2, 1, 1), + # 2 partitions x 2 groups = 4 pieces along dim1 -> 4 warps covers it + (64, 64, 2, 2, 1), + # 2 partitions x 4 groups = 8 pieces along dim1 -> 4 warps < 8 -> 2 instructions + (64, 128, 2, 4, 1), + # 2 partitions x 8 groups = 16 pieces along dim1 -> 4 warps < 16 -> 4 instructions + (64, 256, 2, 8, 1), +] + + @pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250") -@pytest.mark.parametrize( - "BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", - [ - # # --- partitionDim = 0 (rows) --- - # 2 partitions x 1 group = 2 pieces along dim0 -> 4 warps covers it - (64, 32, 2, 1, 0), - # 2 partitions x 2 groups = 4 pieces along dim0 -> 4 warps covers it - (64, 32, 2, 2, 0), - # 2 partitions x 4 groups = 8 pieces along dim0 -> 4 warps < 8 -> 2 instructions - (64, 32, 2, 4, 0), - # 2 partitions x 8 groups = 16 pieces along dim0 -> 4 warps < 16 -> 4 instructions - (128, 64, 2, 2, 0), - # --- partitionDim = 1 (cols) --- - # 2 partitions x 1 group = 2 pieces along dim1 - (32, 64, 2, 1, 1), - # 2 partitions x 2 groups = 4 pieces along dim1 -> 4 warps covers it - (64, 64, 2, 2, 1), - # 2 partitions x 4 groups = 8 pieces along dim1 -> 4 warps < 8 -> 2 instructions - (64, 128, 2, 4, 1), - # 2 partitions x 8 groups = 16 pieces along dim1 -> 4 warps < 16 -> 4 instructions - (64, 256, 2, 8, 1), - ], -) -@pytest.mark.parametrize("num_warps", [4]) -@pytest.mark.parametrize("M,N", [(256, 256)]) -def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM, num_warps, M, N): - """Test TDM async_load with PartitionedSharedLayout (global -> LDS).""" +@pytest.mark.parametrize("BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", _PARTITIONED_TDM_PARAMS) +def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM): + """Test correctness of TDM async_load + transpose LDS load with PartitionedSharedEncodingAttr.""" + M, N = 256, 256 + num_warps = 4 + torch.manual_seed(42) - a = torch.randint(0x0, 0xFFFF, (M, N), dtype=torch.uint16) + a = torch.randn((M, N), dtype=torch.float16) b = torch.zeros_like(a) a_device = a.cuda() @@ -1519,6 +1523,37 @@ def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROU f"partitionDim={PARTITION_DIM}, numPartitions={NUM_PARTITIONS}, numGroups={NUM_GROUPS}") +@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250") +@pytest.mark.parametrize("BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", _PARTITIONED_TDM_PARAMS) +def test_compile_partitioned_tdm_transpose_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM): + """Check that LDS load from PartitionedSharedEncodingAttr lowers to ds_load_tr16 when appropriate.""" + num_warps = 4 + signature = { + "a_ptr": "*fp16", + "b_ptr": "*fp16", + "M": "i32", + "N": "i32", + "BLOCK_M": "constexpr", + "BLOCK_N": "constexpr", + "NUM_PARTITIONS": "constexpr", + "NUM_GROUPS": "constexpr", + "PARTITION_DIM": "constexpr", + } + constexprs = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "NUM_PARTITIONS": NUM_PARTITIONS, + "NUM_GROUPS": NUM_GROUPS, + "PARTITION_DIM": PARTITION_DIM, + } + k = triton.compile( + src=gluon._runtime.GluonASTSource(partitioned_tdm_copy_kernel, signature, constexprs), + target=GPUTarget("hip", 'gfx1250', 32), + options={"num_warps": num_warps}, + ) + assert re.search(r"ds_load_tr16", k.asm["amdgcn"]), "expected transpose LDS loads (ds_load_tr16) in amdgcn" + + @gluon.jit def tensor_device_tdm_multi_cta_load_and_store_kernel(a_ptr, b_ptr, M, N, # BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, From 7e89e338cd6c1144d93d8b853b53f6c088a5ebfa Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic Date: Mon, 20 Apr 2026 11:02:59 +0000 Subject: [PATCH 2/2] Address review comments --- third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp | 3 +-- third_party/amd/python/test/test_gluon_gfx1250.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp index 05c4fba3a4d6..a2bb38f38892 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -63,8 +63,7 @@ class TransLocalLoadOpConversion cvtDstLL = maybeSublayout.value(); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - SmallVector smemBases(smemObj.getBases().begin(), - smemObj.getBases().end()); + SmallVector smemBases = llvm::to_vector(smemObj.getBases()); auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy); auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy); auto paddingShifts = getPaddedSharedShifts(srcTy.getEncoding(), diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index 83ce27101825..bc38874d46c1 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -1451,7 +1451,7 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [num_warps, 1], [1, 0]) WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32]) - OPERAND_LAYOUT: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) + DOT_RHS_LAYOUT: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) pid_m = ttgl.program_id(axis=0) pid_n = ttgl.program_id(axis=1) @@ -1465,7 +1465,7 @@ def partitioned_tdm_copy_kernel(a_ptr, b_ptr, M, N, # ttgl.amd.gfx1250.tdm.async_load(a_desc, [idx_m, idx_n], a_buffer) ttgl.amd.gfx1250.tdm.async_wait(0) - a_dot = a_buffer.load(layout=OPERAND_LAYOUT) + a_dot = a_buffer.load(layout=DOT_RHS_LAYOUT) a = ttgl.convert_layout(a_dot, block_layout) offs_bm = idx_m + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, block_layout)) @@ -1523,7 +1523,6 @@ def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROU f"partitionDim={PARTITION_DIM}, numPartitions={NUM_PARTITIONS}, numGroups={NUM_GROUPS}") -@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250") @pytest.mark.parametrize("BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", _PARTITIONED_TDM_PARAMS) def test_compile_partitioned_tdm_transpose_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM): """Check that LDS load from PartitionedSharedEncodingAttr lowers to ds_load_tr16 when appropriate."""