Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,12 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [
}

def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> {
let summary = "wait until all the inputs are read.";
let arguments = (ins I32Attr:$pendings);
let summary = "wait for pending TMA store operations.";
let arguments = (ins I32Attr:$pendings, UnitAttr:$read_only);
let description = [{
Wait until all the read operations are done from the associated store operations.
This is needed before the shared memory can be written to.
Wait for the associated store operations to complete. When `read_only` is
set, only wait until their reads from shared memory have completed. This is
needed before shared memory can be written to again.
}];

let assemblyFormat = "attr-dict";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,

// Put wait before the local_store make the store truly async. We know
// that we are the only user of the CopyLocalToGlobal.
ttng::TMAStoreWaitOp::create(builder, loc, 0);
ttng::TMAStoreWaitOp::create(builder, loc, 0, /*read_only=*/false);
ttg::LocalStoreOp::create(builder, loc, store.src, alloc);
ttng::FenceAsyncSharedOp::create(builder, loc, false);
auto desc = store.desc;
Expand Down Expand Up @@ -112,7 +112,8 @@ bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) {
// Deallocate shared memory buffers.
OpBuilder builder(forOp);
builder.setInsertionPointAfter(forOp);
ttng::TMAStoreWaitOp::create(builder, forOp->getLoc(), 0);
ttng::TMAStoreWaitOp::create(builder, forOp->getLoc(), 0,
/*read_only=*/false);
for (auto it : storeToAlloc) {
ttg::LocalDeallocOp::create(builder, forOp->getLoc(), it.second);
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ static void lowerTMAStore(Operation *op, mlir::TypedValue<RankedTensorType> src,
Value alloc = gpu::LocalAllocOp::create(rewriter, loc, memDescType, src);
triton::nvidia_gpu::FenceAsyncSharedOp::create(rewriter, loc, false);
createStore(desc, alloc);
triton::nvidia_gpu::TMAStoreWaitOp::create(rewriter, loc, 0);
triton::nvidia_gpu::TMAStoreWaitOp::create(rewriter, loc, 0,
/*read_only=*/false);
rewriter.eraseOp(op);
}

Expand Down
4 changes: 2 additions & 2 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,8 +981,8 @@ void init_gluon_ir(py::module &&m) {
self.create<ttng::AsyncTMAReduceOp>(kind, descPtr, coord, src);
})
.def("create_async_tma_store_wait",
[](GluonOpBuilder &self, int pendings) {
self.create<ttng::TMAStoreWaitOp>(pendings);
[](GluonOpBuilder &self, int pendings, bool readOnly) {
self.create<ttng::TMAStoreWaitOp>(pendings, readOnly);
})
.def(
"create_async_tma_gather",
Expand Down
2 changes: 2 additions & 0 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4161,6 +4161,7 @@ def nv_tma_descriptor_store_kernel(input_ptr):
smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout)
tma.async_copy_shared_to_global(input_desc, [0, 0], smem)
tma.store_wait(0)
tma.store_wait(0, read_only=True)

ptr = MockTensor(ttgl.float32)
module = run_parser(nv_tma_descriptor_store_kernel, *make_args(ptr), target)
Expand All @@ -4180,6 +4181,7 @@ def nv_tma_descriptor_store_kernel(input_ptr):
%c0_i32_1 = arith.constant 0 : i32
ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc<128x128xf32, #shared>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
ttng.async_tma_store_wait {pendings = 0 : i32}
ttng.async_tma_store_wait {pendings = 0 : i32, read_only}
tt.return
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,10 @@ def async_atomic_xor(tensor_desc, coord, src, _semantic=None):


@builtin
def store_wait(pendings, _semantic=None):
def store_wait(pendings, read_only=False, _semantic=None):
pendings = _unwrap_if_constexpr(pendings)
_semantic.builder.create_async_tma_store_wait(pendings)
read_only = _unwrap_if_constexpr(read_only)
_semantic.builder.create_async_tma_store_wait(pendings, read_only)


@builtin
Expand Down
13 changes: 12 additions & 1 deletion test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} {

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: async_tma_store_wait
// CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
// CHECK: nvvm.cp.async.bulk.wait_group 0{{$}}
tt.func @async_tma_store_wait() {
ttng.async_tma_store_wait {pendings = 0 : i32}
tt.return
Expand All @@ -367,6 +367,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: async_tma_store_wait_read_only
// CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
tt.func @async_tma_store_wait_read_only() {
ttng.async_tma_store_wait {pendings = 0 : i32, read_only}
tt.return
}
}

// -----

#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
Expand Down
2 changes: 1 addition & 1 deletion test/TritonNvidiaGPU/tma_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<1x128xbf16, #nvmma_128>, %arg1: tenso
// CHECK-NEXT: [[SRC:%.*]] = ttg.local_alloc %arg3
// CHECK-NEXT: ttng.fence_async_shared {bCluster = false}
// CHECK-NEXT: ttng.async_tma_scatter %arg0[%arg1, %arg2] [[SRC]]
// CHECK-NEXT: ttng.async_tma_store_wait
// CHECK-NEXT: ttng.async_tma_store_wait {pendings = 0 : i32}
tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<1x128xbf16, #nvmma_128>, tensor<32xi32, #offsets>, i32, tensor<32x128xbf16, #blocked1>
tt.return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1832,10 +1832,8 @@ struct TMAStoreWaitOpConversion
LogicalResult
matchAndRewrite(triton::nvidia_gpu::TMAStoreWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto ctx = op.getContext();
auto isRead = UnitAttr::get(ctx);
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkWaitGroupOp>(
op, op.getPendingsAttr(), isRead);
op, op.getPendingsAttr(), op.getReadOnlyAttr());
return success();
}
};
Expand Down
Loading