From 2e663f63f882af219c7165f43147ac24ff63bee3 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 22 Dec 2025 20:32:35 +0000 Subject: [PATCH] [Backend] Move TMA index translation from mid-end, to lowerings Currently the behavior of fp4_padded is different between `triton::Descriptor` ops and `AsyncTMA` ops. The former is indexed as if the data is int8, while the latter is indexed by individual fp4 elements, which is what the TMA hardware expects. This now gets leaked into gluon, which isn't ideal. So, this PR moves the translation into the lowerings. Along the way, this probably fixes quite a few bugs as there were several places the translation was missing. --- Makefile | 1 + .../TritonNvidiaGPU/Transforms/TMAUtilities.h | 4 --- .../Transforms/Pipeliner/LowerLoops.cpp | 24 ++++++------- .../Pipeliner/TMAStoresPipeline.cpp | 8 ----- .../Transforms/TMALowering.cpp | 32 +++++++---------- .../Transforms/TMAUtilities.cpp | 10 ------ .../tutorials/gluon/11-tcgen05-mma-scaled.py | 36 +++---------------- .../WarpSpecialization/WSLowerMem.cpp | 1 - .../lib/Dialect/NVWS/Transforms/LowerAref.cpp | 21 +---------- .../LoadStoreOpToLLVM.cpp | 15 +++++++- 10 files changed, 42 insertions(+), 110 deletions(-) diff --git a/Makefile b/Makefile index d8880d3ffd80..e91493c1e719 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,7 @@ test-distributed: all test-gluon: all $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py + $(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon .PHONY: test-regression test-regression: all diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h index 3cc613a791c8..7697be274646 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) { return mmaEnc && mmaEnc.getFp4Padded(); } -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices); - gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout, ArrayRef shape); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index fd650cacc50c..6d368b76a6d3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -243,18 +243,14 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, CoarseSchedule &schedule) { - return createTMAAsyncCopy( - forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier, - waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view, - Value pred) { - auto indices = ttng::translateTMAIndices( - builder, loadOp.getLoc(), - loadOp.getDesc().getType().getBlockType().getEncoding(), - loadOp.getIndices()); - ttng::AsyncTMACopyGlobalToLocalOp::create( - builder, loadOp.getLoc(), tmaPtr, indices, barrier, view, pred); - }); + return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, + extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value desc, + Value barrier, Value view, Value pred) { + ttng::AsyncTMACopyGlobalToLocalOp::create( + builder, loadOp.getLoc(), desc, + loadOp.getIndices(), barrier, view, pred); + }); } void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, @@ -263,10 +259,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, CoarseSchedule &schedule) { return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, insertIdx, extractIdx, barrier, waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, + [&](OpBuilderForStage &builder, Value desc, Value barrier, Value view, Value pred) { ttng::AsyncTMAGatherOp::create( - builder, gatherOp.getLoc(), tmaPtr, + builder, gatherOp.getLoc(), desc, gatherOp.getXOffsets(), gatherOp.getYOffset(), barrier, view, pred); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index dbc9130430bc..2b753f3c6b1a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, ttng::FenceAsyncSharedOp::create(builder, loc, false); auto desc = store.desc; if (auto storeOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, storeOp.getLoc(), - storeOp.getDesc().getType().getBlockType().getEncoding(), - storeOp.getIndices()); ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc, storeOp.getIndices(), alloc); } else if (auto reduceOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, reduceOp.getLoc(), - reduceOp.getDesc().getType().getBlockType().getEncoding(), - reduceOp.getIndices()); ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc, reduceOp.getIndices(), alloc); } else { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 28c83bbb162c..9c7769c7144d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -68,13 +68,11 @@ class TMALoadLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorLoadOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( - rewriter, op.getLoc(), tmaPtr, indices, barrierAlloc, alloc, pred); + rewriter, op.getLoc(), desc, op.getIndices(), barrierAlloc, alloc, + pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); return success(); @@ -86,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorGatherOp op, PatternRewriter &rewriter) const override { - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { triton::nvidia_gpu::AsyncTMAGatherOp::create( - rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(), + rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(), barrierAlloc, alloc, pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); @@ -122,12 +120,9 @@ struct TMAStoreLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorStoreOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create( - rewriter, op.getLoc(), tmaPtr, indices, alloc); + rewriter, op.getLoc(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -139,12 +134,9 @@ struct TMAReduceLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorReduceOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMAReduceOp::create( - rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc); + rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -156,9 +148,9 @@ struct TMAScatterLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorScatterOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), - tmaPtr, op.getXOffsets(), + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc, + op.getXOffsets(), op.getYOffset(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index 46d68c0b5559..034295db518f 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu; namespace mlir::triton::nvidia_gpu { -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices) { - if (isFp4Padded(encoding)) { - auto two = arith::ConstantIntOp::create(builder, loc, 2, 32); - indices.back() = arith::MulIOp::create(builder, loc, indices.back(), two); - } - return indices; -} - ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout, ArrayRef shape) { auto rank = shape.size(); diff --git a/python/tutorials/gluon/11-tcgen05-mma-scaled.py b/python/tutorials/gluon/11-tcgen05-mma-scaled.py index f35a9cb9dc12..bb981a0c336a 100644 --- a/python/tutorials/gluon/11-tcgen05-mma-scaled.py +++ b/python/tutorials/gluon/11-tcgen05-mma-scaled.py @@ -173,18 +173,6 @@ def simple_mma_scaled_kernel(a_desc, b_desc, c_desc, a_scale_ptr, a_scale_stride off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - # When issuing a TMA transaction to TMA tensor descriptors with fp4 padded operands, we need to multiply - # the offset along the contiguous dimension by 2 to account for the padding. This applies to async TMA - # loads, stores, gather, and scatter. Failing to do this can result in illegal instruction errors. If you - # catch the illegal instruction error inside `cuda-gdb`, it may point to the TMA instruction or the - # `mbarrier.wait` on the instruction completion barrier. When breaking on the illegal instruction error, - # you can use `x/i $pc` to print the instruction at the faulting address, and for example use `x/-50i $pc` - # to print the previous 50 instructions. - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - # Load the A and B tiles. mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -495,10 +483,6 @@ def mma_scaled_contig_kernel(a_desc, b_desc, c_desc, a_scale_ptr, b_scale_ptr, V for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -741,13 +725,9 @@ def mma_scaled_packed_block_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 # Index the K subtile along REP_K for each scale. - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1029,12 +1009,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1213,10 +1189,6 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale off_n_b_scale = pid_n * REP_N off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 off_k_a_scale = (k // BLOCK_K) * A_REP_K off_k_b_scale = (k // BLOCK_K) * B_REP_K diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index 5692f0310576..068263523d36 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -229,7 +229,6 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, builder.setInsertionPoint(tmaLoad); auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), buffer, bufferIdx, true); - // FIXME: translateTMAIndices copy = builder.createWithAsyncTaskIds( loc, tmaLoad.getDesc(), tmaLoad.getIndices(), prodBarrier, pipelineBuffer, pred); diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp index 0544394ad8fc..3a7ecbc6a275 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -286,27 +286,8 @@ getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter, void createTMALoad(triton::nvws::DescriptorLoadOp op, PatternRewriter &rewriter, Value barrierAlloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); - for (auto [newIdx, oldIdx] : llvm::zip(indices, op.getIndices())) { - // translateTMAIndices may create ops, we need to annotated them - if (newIdx != oldIdx) { - auto partitionIds = getPartitionWsTagIds(op); - auto stageCluster = getStageCluster(op); - assignStageCluster(newIdx.getDefiningOp(), partitionIds, stageCluster, - rewriter); - for (auto val : newIdx.getDefiningOp()->getOperands()) { - if (auto op = val.getDefiningOp()) { - if (!hasPartition(op)) { - assignStageCluster(op, partitionIds, stageCluster, rewriter); - } - } - } - } - } auto newLoadOp = triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( - rewriter, op.getLoc(), op.getDesc(), indices, barrierAlloc, + rewriter, op.getLoc(), op.getDesc(), op.getIndices(), barrierAlloc, op.getResult(), pred); assignStageCluster(newLoadOp, getPartitionWsTagIds(op), getStageCluster(op), rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a41290a38b19..48c75659399e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1412,10 +1412,16 @@ struct AsyncTMACopyGlobalToLocalOpConversion auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); int operandIdx = 3; + auto encoding = op.getDesc().getType().getBlockType().getEncoding(); + bool fp4Padded = nvidia_gpu::isFp4Padded(encoding); for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); + operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); tmaInst += "$" + std::to_string(operandIdx++); if (i != rank - 1) @@ -1496,8 +1502,12 @@ LogicalResult convertTMAStoreLikeOp(Operation *op, auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); + bool fp4Padded = nvidia_gpu::isFp4Padded(srcTy.getEncoding()); for (int i = 0; i < rank; i++) { Value coord = coords[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); @@ -1623,8 +1633,11 @@ static LogicalResult iterateGatherScatterIndices( return op->emitError("memdesc shape must match alloc shape"); // `NVMMASharedEncodingAttr` means the core matrix tiles are placed next to // each other in shared memory, which lines up with how `gather4` loads data. - if (!isa(smemType.getEncoding())) + auto enc = dyn_cast(smemType.getEncoding()); + if (!enc) return op->emitError("requires dst encoding NVMMASharedEncodingAttr"); + if (enc.getFp4Padded()) + yOffsetValue = b.mul(yOffsetValue, b.i32_val(2)); Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, smemObjValue,