diff --git a/Makefile b/Makefile index d8880d3ffd80..e91493c1e719 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,7 @@ test-distributed: all test-gluon: all $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py + $(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon .PHONY: test-regression test-regression: all diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h index 3cc613a791c8..7697be274646 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) { return mmaEnc && mmaEnc.getFp4Padded(); } -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices); - gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout, ArrayRef shape); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index fd650cacc50c..6d368b76a6d3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -243,18 +243,14 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, CoarseSchedule &schedule) { - return createTMAAsyncCopy( - forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier, - waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view, - Value pred) { - auto indices = ttng::translateTMAIndices( - builder, loadOp.getLoc(), - loadOp.getDesc().getType().getBlockType().getEncoding(), - loadOp.getIndices()); - ttng::AsyncTMACopyGlobalToLocalOp::create( - builder, loadOp.getLoc(), tmaPtr, indices, barrier, view, pred); - }); + return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, + extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value desc, + Value barrier, Value view, Value pred) { + ttng::AsyncTMACopyGlobalToLocalOp::create( + builder, loadOp.getLoc(), desc, + loadOp.getIndices(), barrier, view, pred); + }); } void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, @@ -263,10 +259,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, CoarseSchedule &schedule) { return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, insertIdx, extractIdx, barrier, waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, + [&](OpBuilderForStage &builder, Value desc, Value barrier, Value view, Value pred) { ttng::AsyncTMAGatherOp::create( - builder, gatherOp.getLoc(), tmaPtr, + builder, gatherOp.getLoc(), desc, gatherOp.getXOffsets(), gatherOp.getYOffset(), barrier, view, pred); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index dbc9130430bc..2b753f3c6b1a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, ttng::FenceAsyncSharedOp::create(builder, loc, false); auto desc = store.desc; if (auto storeOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, storeOp.getLoc(), - storeOp.getDesc().getType().getBlockType().getEncoding(), - storeOp.getIndices()); ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc, storeOp.getIndices(), alloc); } else if (auto reduceOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, reduceOp.getLoc(), - reduceOp.getDesc().getType().getBlockType().getEncoding(), - reduceOp.getIndices()); ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc, reduceOp.getIndices(), alloc); } else { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 28c83bbb162c..9c7769c7144d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -68,13 +68,11 @@ class TMALoadLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorLoadOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( - rewriter, op.getLoc(), tmaPtr, indices, barrierAlloc, alloc, pred); + rewriter, op.getLoc(), desc, op.getIndices(), barrierAlloc, alloc, + pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); return success(); @@ -86,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorGatherOp op, PatternRewriter &rewriter) const override { - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { triton::nvidia_gpu::AsyncTMAGatherOp::create( - rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(), + rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(), barrierAlloc, alloc, pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); @@ -122,12 +120,9 @@ struct TMAStoreLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorStoreOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create( - rewriter, op.getLoc(), tmaPtr, indices, alloc); + rewriter, op.getLoc(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -139,12 +134,9 @@ struct TMAReduceLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorReduceOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMAReduceOp::create( - rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc); + rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -156,9 +148,9 @@ struct TMAScatterLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorScatterOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), - tmaPtr, op.getXOffsets(), + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc, + op.getXOffsets(), op.getYOffset(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index 46d68c0b5559..034295db518f 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu; namespace mlir::triton::nvidia_gpu { -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices) { - if (isFp4Padded(encoding)) { - auto two = arith::ConstantIntOp::create(builder, loc, 2, 32); - indices.back() = arith::MulIOp::create(builder, loc, indices.back(), two); - } - return indices; -} - ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout, ArrayRef shape) { auto rank = shape.size(); diff --git a/python/tutorials/gluon/11-tcgen05-mma-scaled.py b/python/tutorials/gluon/11-tcgen05-mma-scaled.py index f35a9cb9dc12..bb981a0c336a 100644 --- a/python/tutorials/gluon/11-tcgen05-mma-scaled.py +++ b/python/tutorials/gluon/11-tcgen05-mma-scaled.py @@ -173,18 +173,6 @@ def simple_mma_scaled_kernel(a_desc, b_desc, c_desc, a_scale_ptr, a_scale_stride off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - # When issuing a TMA transaction to TMA tensor descriptors with fp4 padded operands, we need to multiply - # the offset along the contiguous dimension by 2 to account for the padding. This applies to async TMA - # loads, stores, gather, and scatter. Failing to do this can result in illegal instruction errors. If you - # catch the illegal instruction error inside `cuda-gdb`, it may point to the TMA instruction or the - # `mbarrier.wait` on the instruction completion barrier. When breaking on the illegal instruction error, - # you can use `x/i $pc` to print the instruction at the faulting address, and for example use `x/-50i $pc` - # to print the previous 50 instructions. - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - # Load the A and B tiles. mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -495,10 +483,6 @@ def mma_scaled_contig_kernel(a_desc, b_desc, c_desc, a_scale_ptr, b_scale_ptr, V for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -741,13 +725,9 @@ def mma_scaled_packed_block_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 # Index the K subtile along REP_K for each scale. - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1029,12 +1009,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1213,10 +1189,6 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale off_n_b_scale = pid_n * REP_N off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 off_k_a_scale = (k // BLOCK_K) * A_REP_K off_k_b_scale = (k // BLOCK_K) * B_REP_K diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index 5692f0310576..068263523d36 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -229,7 +229,6 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, builder.setInsertionPoint(tmaLoad); auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), buffer, bufferIdx, true); - // FIXME: translateTMAIndices copy = builder.createWithAsyncTaskIds( loc, tmaLoad.getDesc(), tmaLoad.getIndices(), prodBarrier, pipelineBuffer, pred); diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp index 0544394ad8fc..3a7ecbc6a275 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -286,27 +286,8 @@ getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter, void createTMALoad(triton::nvws::DescriptorLoadOp op, PatternRewriter &rewriter, Value barrierAlloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); - for (auto [newIdx, oldIdx] : llvm::zip(indices, op.getIndices())) { - // translateTMAIndices may create ops, we need to annotated them - if (newIdx != oldIdx) { - auto partitionIds = getPartitionWsTagIds(op); - auto stageCluster = getStageCluster(op); - assignStageCluster(newIdx.getDefiningOp(), partitionIds, stageCluster, - rewriter); - for (auto val : newIdx.getDefiningOp()->getOperands()) { - if (auto op = val.getDefiningOp()) { - if (!hasPartition(op)) { - assignStageCluster(op, partitionIds, stageCluster, rewriter); - } - } - } - } - } auto newLoadOp = triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( - rewriter, op.getLoc(), op.getDesc(), indices, barrierAlloc, + rewriter, op.getLoc(), op.getDesc(), op.getIndices(), barrierAlloc, op.getResult(), pred); assignStageCluster(newLoadOp, getPartitionWsTagIds(op), getStageCluster(op), rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a41290a38b19..48c75659399e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1412,10 +1412,16 @@ struct AsyncTMACopyGlobalToLocalOpConversion auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); int operandIdx = 3; + auto encoding = op.getDesc().getType().getBlockType().getEncoding(); + bool fp4Padded = nvidia_gpu::isFp4Padded(encoding); for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); + operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); tmaInst += "$" + std::to_string(operandIdx++); if (i != rank - 1) @@ -1496,8 +1502,12 @@ LogicalResult convertTMAStoreLikeOp(Operation *op, auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); + bool fp4Padded = nvidia_gpu::isFp4Padded(srcTy.getEncoding()); for (int i = 0; i < rank; i++) { Value coord = coords[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); @@ -1623,8 +1633,11 @@ static LogicalResult iterateGatherScatterIndices( return op->emitError("memdesc shape must match alloc shape"); // `NVMMASharedEncodingAttr` means the core matrix tiles are placed next to // each other in shared memory, which lines up with how `gather4` loads data. - if (!isa(smemType.getEncoding())) + auto enc = dyn_cast(smemType.getEncoding()); + if (!enc) return op->emitError("requires dst encoding NVMMASharedEncodingAttr"); + if (enc.getFp4Padded()) + yOffsetValue = b.mul(yOffsetValue, b.i32_val(2)); Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, smemObjValue,