diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 006361d9fb47..3c223842646b 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -584,11 +584,12 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [ } def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> { - let summary = "wait until all the inputs are read."; - let arguments = (ins I32Attr:$pendings); + let summary = "wait for pending TMA store operations."; + let arguments = (ins I32Attr:$pendings, UnitAttr:$read_only); let description = [{ - Wait until all the read operations are done from the associated store operations. - This is needed before the shared memory can be written to. + Wait for the associated store operations to complete. When `read_only` is + set, only wait until their reads from shared memory have completed. This is + needed before shared memory can be written to again. }]; let assemblyFormat = "attr-dict"; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 7808f1a168fe..ae1f60bebccd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -54,7 +54,7 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, // Put wait before the local_store make the store truly async. We know // that we are the only user of the CopyLocalToGlobal. - ttng::TMAStoreWaitOp::create(builder, loc, 0); + ttng::TMAStoreWaitOp::create(builder, loc, 0, /*read_only=*/false); ttg::LocalStoreOp::create(builder, loc, store.src, alloc); ttng::FenceAsyncSharedOp::create(builder, loc, false); auto desc = store.desc; @@ -112,7 +112,8 @@ bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { // Deallocate shared memory buffers. OpBuilder builder(forOp); builder.setInsertionPointAfter(forOp); - ttng::TMAStoreWaitOp::create(builder, forOp->getLoc(), 0); + ttng::TMAStoreWaitOp::create(builder, forOp->getLoc(), 0, + /*read_only=*/false); for (auto it : storeToAlloc) { ttg::LocalDeallocOp::create(builder, forOp->getLoc(), it.second); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 2c004ee379c1..dc9680e63035 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -113,7 +113,8 @@ static void lowerTMAStore(Operation *op, mlir::TypedValue src, Value alloc = gpu::LocalAllocOp::create(rewriter, loc, memDescType, src); triton::nvidia_gpu::FenceAsyncSharedOp::create(rewriter, loc, false); createStore(desc, alloc); - triton::nvidia_gpu::TMAStoreWaitOp::create(rewriter, loc, 0); + triton::nvidia_gpu::TMAStoreWaitOp::create(rewriter, loc, 0, + /*read_only=*/false); rewriter.eraseOp(op); } diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 4d98119243a2..64772b756970 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -981,8 +981,8 @@ void init_gluon_ir(py::module &&m) { self.create(kind, descPtr, coord, src); }) .def("create_async_tma_store_wait", - [](GluonOpBuilder &self, int pendings) { - self.create(pendings); + [](GluonOpBuilder &self, int pendings, bool readOnly) { + self.create(pendings, readOnly); }) .def( "create_async_tma_gather", diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 1144802fabd6..7c8a5df6f4ed 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -4161,6 +4161,7 @@ def nv_tma_descriptor_store_kernel(input_ptr): smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) tma.async_copy_shared_to_global(input_desc, [0, 0], smem) tma.store_wait(0) + tma.store_wait(0, read_only=True) ptr = MockTensor(ttgl.float32) module = run_parser(nv_tma_descriptor_store_kernel, *make_args(ptr), target) @@ -4180,6 +4181,7 @@ def nv_tma_descriptor_store_kernel(input_ptr): %c0_i32_1 = arith.constant 0 : i32 ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc<128x128xf32, #shared>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable> ttng.async_tma_store_wait {pendings = 0 : i32} + ttng.async_tma_store_wait {pendings = 0 : i32, read_only} tt.return } } diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index fad68fc357ba..888630198046 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -384,9 +384,10 @@ def async_atomic_xor(tensor_desc, coord, src, _semantic=None): @builtin -def store_wait(pendings, _semantic=None): +def store_wait(pendings, read_only=False, _semantic=None): pendings = _unwrap_if_constexpr(pendings) - _semantic.builder.create_async_tma_store_wait(pendings) + read_only = _unwrap_if_constexpr(read_only) + _semantic.builder.create_async_tma_store_wait(pendings, read_only) @builtin diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index ccf2064eba81..20ddc8920357 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -358,7 +358,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: async_tma_store_wait - // CHECK: nvvm.cp.async.bulk.wait_group 0 {read} + // CHECK: nvvm.cp.async.bulk.wait_group 0{{$}} tt.func @async_tma_store_wait() { ttng.async_tma_store_wait {pendings = 0 : i32} tt.return @@ -367,6 +367,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: async_tma_store_wait_read_only + // CHECK: nvvm.cp.async.bulk.wait_group 0 {read} + tt.func @async_tma_store_wait_read_only() { + ttng.async_tma_store_wait {pendings = 0 : i32, read_only} + tt.return + } +} + +// ----- + #shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index 2ceb3fb14d31..5a2478b265f1 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -96,7 +96,7 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tenso // CHECK-NEXT: [[SRC:%.*]] = ttg.local_alloc %arg3 // CHECK-NEXT: ttng.fence_async_shared {bCluster = false} // CHECK-NEXT: ttng.async_tma_scatter %arg0[%arg1, %arg2] [[SRC]] - // CHECK-NEXT: ttng.async_tma_store_wait + // CHECK-NEXT: ttng.async_tma_store_wait {pendings = 0 : i32} tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #offsets>, i32, tensor<32x128xbf16, #blocked1> tt.return } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index c7c0f583b7f5..3123f222caeb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1832,10 +1832,8 @@ struct TMAStoreWaitOpConversion LogicalResult matchAndRewrite(triton::nvidia_gpu::TMAStoreWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto ctx = op.getContext(); - auto isRead = UnitAttr::get(ctx); rewriter.replaceOpWithNewOp( - op, op.getPendingsAttr(), isRead); + op, op.getPendingsAttr(), op.getReadOnlyAttr()); return success(); } };