From 098b4f020ba93c613cf361d658f977316e9170d1 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 7 Apr 2026 16:26:15 +0200 Subject: [PATCH 1/6] [BACKEND] Tighten TMA multicta layouts As per the PTX docs, TMAs have a very specific behaviour when executed in a 2CTA kernel: >.cta_group::1 : The mbarrier signal is also multicasted to the same offset as mbar in the shared memory of the destination CTA. .cta_group::2 : The mbarrier signal is multicasted either to all the odd numbered CTAs or the even numbered CTAs within the corresponding CTA-Pair. For each destination CTA specified in the ctaMask, the mbarrier signal is sent either to the destination CTA or its peer-CTA based on CTAs %cluster_ctarank parity of shared memory where the mbarrier object mbar resides. As such, we require these CTA layouts in TMA barriers. --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 81 +++++++++++++------- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 4 +- test/TritonNvidiaGPU/invalid.mlir | 42 +++++++++- test/TritonNvidiaGPU/membar-cluster.mlir | 38 ++++----- 4 files changed, 113 insertions(+), 52 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 5bd29e234fa8..7a2c1f491c91 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -290,6 +290,57 @@ LogicalResult ClusterBarrierOp::verify() { } // -- TMA operation verifiers -- +static std::string formatCGALayout(CGAEncodingAttr cgaLayout) { + std::string str; + llvm::raw_string_ostream os(str); + auto kBlock = StringAttr::get(cgaLayout.getContext(), "block"); + os << "["; + llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock), + os, [&](const auto &basis) { + os << "["; + llvm::interleaveComma(basis, os); + os << "]"; + }); + os << "]"; + return os.str(); +} + +static LogicalResult verifyBarrierCGALayout(Operation *op, Value barrier, + CGAEncodingAttr expectedCGALayout, + StringRef barrierName) { + auto barrierTy = 
cast<MemDescType>(barrier.getType()); + auto actualCGALayout = getCGALayout(barrierTy.getEncoding()); + if (actualCGALayout != expectedCGALayout) + return op->emitOpError() << barrierName << " cga_layout must be " + << formatCGALayout(expectedCGALayout) << ", got " + << formatCGALayout(actualCGALayout); + return success(); +} + +static LogicalResult verifyCompletionBarrierLayout(Operation *op, + Value barrier) { + auto expectedCGALayout = + CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op)); + return verifyBarrierCGALayout(op, barrier, expectedCGALayout, + "completion barrier"); +} + +static LogicalResult verifyTMABarrierLayout(Operation *op, Value barrier) { + auto ctx = op->getContext(); + int numCTAs = gpu::lookupNumCTAs(op); + CGAEncodingAttr expectedCGALayout; + if (getModuleTwoCTAs(op)) { + auto kBlock = StringAttr::get(ctx, "block"); + auto dim = standardOutDimNames(ctx, /*rank=*/1)[0]; + auto layout = LinearLayout::zeros1D(2, kBlock, dim) * + LinearLayout::identity1D(numCTAs / 2, kBlock, dim); + expectedCGALayout = CGAEncodingAttr::get(ctx, std::move(layout)); + } else { + expectedCGALayout = CGAEncodingAttr::get1DLayout(ctx, numCTAs); + } + return verifyBarrierCGALayout(op, barrier, expectedCGALayout, "TMA barrier"); +} + static LogicalResult verifyTMAEncoding(Operation *op, TensorDescInterface desc, Attribute enc) { auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(enc); @@ -318,6 +369,8 @@ static LogicalResult verifyAsyncTMALoadOp(Operation *op, MemDescType resultType) { if (failed(verifyBarrierType(op, barrier.getType()))) return failure(); + if (failed(verifyTMABarrierLayout(op, barrier))) + return failure(); if (!resultType.getMutableMemory()) return op->emitOpError("cannot store into immutable memory"); if (failed(verifyTMAEncoding(op, desc, resultType.getEncoding()))) @@ -599,34 +652,6 @@ static LogicalResult verifyMMADType(Operation *op, Type a, Type b, Type d) { return success(); } -static std::string formatCGALayout(CGAEncodingAttr cgaLayout) { - std::string str; - 
llvm::raw_string_ostream os(str); - auto kBlock = StringAttr::get(cgaLayout.getContext(), "block"); - os << "["; - llvm::interleaveComma(cgaLayout.getLinearLayout().getBases().lookup(kBlock), - os, [&](const auto &basis) { - os << "["; - llvm::interleaveComma(basis, os); - os << "]"; - }); - os << "]"; - return os.str(); -} - -static LogicalResult verifyCompletionBarrierLayout(Operation *op, - Value barrier) { - auto barrierTy = cast<MemDescType>(barrier.getType()); - auto expectedCGALayout = - CGAEncodingAttr::get1DLayout(op->getContext(), gpu::lookupNumCTAs(op)); - auto actualCGALayout = getCGALayout(barrierTy.getEncoding()); - if (actualCGALayout != expectedCGALayout) - return op->emitOpError("completion barrier cga_layout must be ") - << formatCGALayout(expectedCGALayout) << ", got " - << formatCGALayout(actualCGALayout); - return success(); -} - LogicalResult TCGen5MMAOp::verify() { if (!getIsAsync() && !getBarriers().empty()) { return emitOpError("The op is synchronous but a barrier is present."); diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index b745b8df87ed..6a5170f3afb7 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -167,14 +167,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0_cluster = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #shared1_cga = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CGALayout = [[0, 0]]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true} { // CHECK-LABEL: tma_copy_barrier_mask_nonzero // Barrier pointer is modified when barrier mask != 0 // CHECK: llvm.ptrtoint // CHECK: llvm.and // CHECK: llvm.inttoptr // TMA uses shared::cluster when barrier mask is 
non-zero - // CHECK: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes + // CHECK: cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes // CHECK-NOT: cp.async.bulk.tensor.2d.shared::cta.global.mbarrier tt.func @tma_copy_barrier_mask_nonzero(%tma: !tt.tensordesc<128x128xf32, #shared1_cga>, %alloc: !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0_cluster, #smem>, %pred: i1) { ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.tensordesc<128x128xf32, #shared1_cga>, !ttg.memdesc<1xi64, #shared0_cluster, #smem> -> !ttg.memdesc<128x128xf32, #shared1_cga, #smem, mutable> diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index 94d96fb9b22b..d20d8beedc2b 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -168,6 +168,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ } // ----- +#nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_copy_global_to_local_requires_1d_barrier_layout( + %arg0: !tt.tensordesc<64x128xf16, #nvmma>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrier, #smem, mutable> + // expected-error @below {{TMA barrier cga_layout must be [[1]], got [[0]]}} + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrier, #smem, 
mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + tt.return + } +} + +// ----- + +#nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0], [0, 1]]}> +#barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1], [2]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @async_tma_copy_global_to_local_requires_two_cta_barrier_layout( + %arg0: !tt.tensordesc<64x128xf16, #nvmma>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4xi64, #barrier, #smem, mutable> + // expected-error @below {{TMA barrier cga_layout must be [[0], [1]], got [[1], [2]]}} + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true : !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<4xi64, #barrier, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> @@ -602,16 +638,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #nvmma_no_broadcast = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> -#shared_bar = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#shared_bar = #ttg.swizzled_shared<{vec = 
1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @async_tma_copy_multicast_requires_broadcast(%arg0: !tt.tensordesc<64x128xf16, #nvmma_no_broadcast>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> - %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #shared_bar, #smem, mutable> // expected-error @below {{multicast requires the shared layout to broadcast across CTAs}} - ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true {multicast} : !tt.tensordesc<64x128xf16, #nvmma_no_broadcast>, !ttg.memdesc<1xi64, #shared_bar, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> + ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %1, %true {multicast} : !tt.tensordesc<64x128xf16, #nvmma_no_broadcast>, !ttg.memdesc<2xi64, #shared_bar, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma_no_broadcast, #smem, mutable> tt.return } } diff --git a/test/TritonNvidiaGPU/membar-cluster.mlir b/test/TritonNvidiaGPU/membar-cluster.mlir index c49082d9481f..28259ea2a430 100644 --- a/test/TritonNvidiaGPU/membar-cluster.mlir +++ b/test/TritonNvidiaGPU/membar-cluster.mlir @@ -369,7 +369,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blockedTmaSrc = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #blockedTmaDst = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #nvmmaTma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 
0]]}> -#barrierEncTma = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrierEncTma = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { @@ -386,17 +386,17 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.func @convert_layout_trivial_then_tma_multicast_cluster_barrier(%input: tensor<64x128xf16, #blockedTmaSrc>, %desc: !tt.tensordesc<64x128xf16, #nvmmaTma>) -> tensor<64x128xf16, #blockedTmaDst> { %c0 = arith.constant 0 : i32 %true = arith.constant true - %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> - ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #barrierEncTma, #smem, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, #barrierEncTma, #smem, mutable> %cvt = ttg.convert_layout %input : tensor<64x128xf16, #blockedTmaSrc> -> tensor<64x128xf16, #blockedTmaDst> %dst = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %dst, %barrier, %true {multicast} : - !tt.tensordesc<64x128xf16, #nvmmaTma>, !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmmaTma>, !ttg.memdesc<2xi64, #barrierEncTma, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %dst : - !ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierEncTma, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> ttg.local_dealloc %dst : !ttg.memdesc<64x128xf16, #nvmmaTma, #smem, mutable> - ttg.local_dealloc %barrier : 
!ttg.memdesc<1xi64, #barrierEncTma, #smem, mutable> + ttg.local_dealloc %barrier : !ttg.memdesc<2xi64, #barrierEncTma, #smem, mutable> tt.return %cvt : tensor<64x128xf16, #blockedTmaDst> } } @@ -478,7 +478,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.two-ctas" = true, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // Multicast TMA still needs init sync even if the barrier allocation shape // looks per-CTA. // CHECK-LABEL: @cluster_tma_multicast_with_per_cta_barrier @@ -537,7 +537,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- #nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> -#barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #smem = #ttg.shared_memory @@ -560,12 +560,12 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %true = arith.constant true %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked> %buf = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> - %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> - ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> + %barrier = ttg.local_alloc : () -> 
!ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf, %barrier, %true {multicast} : - !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %buf : - !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttg.local_store %cst, %buf : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> %ld = ttg.local_load %buf : !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> -> tensor<64x128xf16, #blocked> @@ -686,7 +686,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttng.tw // ----- #nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> -#barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #smem = #ttg.shared_memory @@ -707,14 +707,14 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0 = arith.constant 0 : i32 %true = arith.constant true - %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> - ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, 
#barrierEnc, #smem, mutable> // a lifetime start %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %a, %barrier, %true {multicast} : - !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %a : - !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> %t = ttg.local_load %a : !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> -> tensor<64x128xf16, #blocked> ttg.local_dealloc %a : !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> @@ -723,9 +723,9 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // b lifetime start %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %b, %barrier, %true {multicast} : - !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> + !tt.tensordesc<64x128xf16, #nvmma>, !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable> -> !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> ttng.wait_barrier %barrier, %c0 deps %b : - !ttg.memdesc<1xi64, #barrierEnc, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierEnc, #smem, mutable>, !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> %t2 = ttg.local_load %b : !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> -> tensor<64x128xf16, #blocked> ttg.local_dealloc %b : !ttg.memdesc<64x128xf16, #nvmma, #smem, mutable> From 8adaa55fb90f92292fa6291e9846f5e4218ca0b8 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 7 Apr 2026 16:31:07 +0200 Subject: [PATCH 
2/6] add small comment to tutorial --- python/tutorials/gluon/14-multicta.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tutorials/gluon/14-multicta.py b/python/tutorials/gluon/14-multicta.py index 1f7936a750eb..51777d9baeda 100644 --- a/python/tutorials/gluon/14-multicta.py +++ b/python/tutorials/gluon/14-multicta.py @@ -477,8 +477,9 @@ def broadcast(b): # every CTA in the multicast group atomically, so the wait side does not need a # different API. # -# The only new ingredient is the layout. The TMA destination must use a -# broadcast `cga_layout`, so that both CTAs view the same shared-memory tile. +# The TMA destination must use a broadcast `cga_layout`, so that both CTAs +# receive the same shared-memory tile. The barrier stays a regular 1D TMA +# barrier unless the kernel is in 2CTA mode. # # The example below keeps things intentionally simple: it multicasts one tile # into shared memory and then materializes that same tile back to global memory. @@ -489,6 +490,7 @@ def tma_multicast_copy_kernel(in_desc, out_desc): gl.static_assert(gl.num_ctas() == 2) smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout) + # This kernel is not in 2CTA mode, so the TMA barrier is per-CTA. 
bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) From b364481259d479923584c70809807fbd12088a99 Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 7 Apr 2026 17:48:29 +0200 Subject: [PATCH 3/6] don't verify until we have the module two-ctas attr --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 7 ++++++- test/TritonNvidiaGPU/invalid.mlir | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 7a2c1f491c91..c8777a801d88 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -326,10 +326,15 @@ static LogicalResult verifyCompletionBarrierLayout(Operation *op, } static LogicalResult verifyTMABarrierLayout(Operation *op, Value barrier) { + auto twoCTAsAttr = + op->getParentOfType<ModuleOp>()->getAttrOfType<BoolAttr>(AttrTwoCTAsName); + if (!twoCTAsAttr) + return success(); + auto ctx = op->getContext(); int numCTAs = gpu::lookupNumCTAs(op); CGAEncodingAttr expectedCGALayout; - if (getModuleTwoCTAs(op)) { + if (twoCTAsAttr.getValue()) { auto kBlock = StringAttr::get(ctx, "block"); auto dim = standardOutDimNames(ctx, /*rank=*/1)[0]; auto layout = LinearLayout::zeros1D(2, kBlock, dim) * diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir index d20d8beedc2b..84bd0d8191e3 100644 --- a/test/TritonNvidiaGPU/invalid.mlir +++ b/test/TritonNvidiaGPU/invalid.mlir @@ -171,7 +171,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> #barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, 
"ttng.two-ctas" = false, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @async_tma_copy_global_to_local_requires_1d_barrier_layout( %arg0: !tt.tensordesc<64x128xf16, #nvmma>) { %true = arith.constant true From 69ec58d4168b7d70586b39bac65a9c455759a2f9 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 8 Apr 2026 16:13:00 +0200 Subject: [PATCH 4/6] fix --- python/test/gluon/test_consan.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 0996548a110c..33f60b9ae07e 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -741,21 +741,24 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)), ) - bar = mbarrier.allocate_mbarrier(batch=2) + mma_bar = mbarrier.allocate_mbarrier() acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) - mbarrier.init(bar.index(0), count=1) - mbarrier.init(bar.index(1), count=1) + mbarrier.init(mma_bar, count=1) + if MEM_ACCESS_KIND == "tma_cp": + tma_bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTAS) + mbarrier.init(tma_bar, count=1) blackwell.tcgen05_mma(smemA, smemB, acc) - blackwell.tcgen05_commit(bar.index(0)) + blackwell.tcgen05_commit(mma_bar) if not FAILURE: - mbarrier.wait(bar.index(0), 0) + mbarrier.wait(mma_bar, 0) if MEM_ACCESS_KIND == "tma_cp": - mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar.index(1), smemA) - mbarrier.wait(bar.index(1), 0) + mbarrier.expect(tma_bar, input_desc.nbytes_per_cta) + tma.async_copy_global_to_shared(input_desc, [0, 0], tma_bar, smemA) + mbarrier.wait(tma_bar, 0) + mbarrier.invalidate(tma_bar) elif MEM_ACCESS_KIND == "local_store": smemA.store(ttgl.full([block_m, XBLOCK], 42, 
ttgl.float16, smem_a_blocked_layout)) elif MEM_ACCESS_KIND == "tmem_load": @@ -769,8 +772,7 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt elif MEM_ACCESS_KIND == "tmem_store": acc.store(ttgl.full([block_m, block_n], 42, ttgl.float32, acc_blocked_layout)) - mbarrier.invalidate(bar.index(0)) - mbarrier.invalidate(bar.index(1)) + mbarrier.invalidate(mma_bar) block_m = mma_block_m(num_ctas) block_n = mma_block_n(num_ctas) From 135915cbc6b59e1f436c12110ae35e340a89012a Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 9 Apr 2026 19:40:40 +0200 Subject: [PATCH 5/6] Update ConSan multicast TMA lit descriptor syntax Co-authored-by: Codex --- test/TritonGPU/consan.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 7d908a9e2e56..9ef0928b574c 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -264,7 +264,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: arith.shrui // CHECK-LABEL: @outstanding_commits_multicast_tma_recipients tt.func public @outstanding_commits_multicast_tma_recipients( - %desc: !tt.tensordesc<tensor<32x32xf32, #shared>>, + %desc: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<32x32x!tt.ptr, #blocked>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 @@ -286,7 +286,7 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: %[[RECIPIENTS:.*]] = arith.shli %[[PATTERN]], // CHECK: tt.call @__triton_consan_check_outstanding_commits{{.*}}({{.*}}, %[[RECIPIENTS]]) // CHECK: ttng.async_tma_copy_global_to_local - ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %shmem, %bar, %true {multicast} : !tt.tensordesc<tensor<32x32xf32, #shared>>, !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0_i32, %c0_i32] %shmem, %bar, %true {multicast} : !tt.tensordesc<32x32xf32, #shared>, 
!ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } From c9cf9d98444a9f7ca8858ec9db3ec76273712d82 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 10 Apr 2026 08:12:13 +0200 Subject: [PATCH 6/6] kill unnecessary tests --- python/test/gluon/test_consan.py | 76 +++++--------------------------- 1 file changed, 12 insertions(+), 64 deletions(-) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 33f60b9ae07e..48a609200f2a 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -158,13 +158,10 @@ def test_consan_uses_profile_scratch(device, fresh_knobs, num_ctas): @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) -@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True]) -def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas): - if TWO_CTA_BARRIER and num_ctas == 1: - pytest.skip("Need at least 2 CTAs for a two-CTA barrier") +def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): if run_wrapper: - result = run_in_process(test_async_tma_kernel, (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas)) - if FAILURE or TWO_CTA_BARRIER: + result = run_in_process(test_async_tma_kernel, (FAILURE, device, False, monkeypatch, num_ctas)) + if FAILURE: assert_expected_cuda_failure(result.exc) assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output else: @@ -177,13 +174,13 @@ def test_async_tma_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeyp knobs.refresh_knobs() @gluon.jit - def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr): + def kernel(input_desc, out, FAILURE: ttgl.constexpr): block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() cga_layout: ttgl.constexpr = 
default_cga_layout(ttgl.num_ctas(), 2) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) - bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER) + bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) @@ -203,19 +200,17 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=default_cga_layout(num_ctas, 2)) input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout) - kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas) + kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) -@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True]) -def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas): +def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: pytest.skip("Need at least 2 CTAs for multicast in this test") if run_wrapper: - result = run_in_process(test_async_tma_multicast_kernel, - (FAILURE, TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas)) - if FAILURE or TWO_CTA_BARRIER: + result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas)) + if FAILURE: assert_expected_cuda_failure(result.exc) assert "Buffer being accessed has outstanding writes" in 
result.driver_stderr_output else: @@ -228,12 +223,12 @@ def test_async_tma_multicast_kernel(FAILURE, TWO_CTA_BARRIER, device, run_wrappe knobs.refresh_knobs() @gluon.jit - def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.constexpr): + def kernel(input_desc, out, FAILURE: ttgl.constexpr): cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout) - bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER) + bar = mbarrier.allocate_mbarrier() mbarrier.init(bar, count=1) mbarrier.expect(bar, input_desc.nbytes_per_cta) tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True) @@ -252,54 +247,7 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr, TWO_CTA_BARRIER: ttgl.const shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=multicast_cga_layout(num_ctas, 2)) input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout) - kernel[(1, )](input_desc, output, FAILURE=FAILURE, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas) - - -@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") -@pytest.mark.parametrize("TWO_CTA_BARRIER", [False, True]) -def test_async_tma_multicast_kernel_two_cta_barrier(TWO_CTA_BARRIER, device, run_wrapper, monkeypatch, num_ctas): - if num_ctas != 2: - pytest.skip("This test covers a single 2-CTA multicast group") - if run_wrapper: - result = run_in_process(test_async_tma_multicast_kernel_two_cta_barrier, - (TWO_CTA_BARRIER, device, False, monkeypatch, num_ctas)) - if TWO_CTA_BARRIER: - assert_expected_cuda_failure(result.exc) - assert "Buffer being accessed has 
outstanding writes" in result.driver_stderr_output - else: - assert result.exc is None - assert result.driver_stderr_output == "" - return - - monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") - monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") - knobs.refresh_knobs() - - @gluon.jit - def kernel(input_desc, out, TWO_CTA_BARRIER: ttgl.constexpr): - cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 2) - blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], - warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) - smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout) - bar = mbarrier.allocate_mbarrier(two_ctas=TWO_CTA_BARRIER) - mbarrier.init(bar, count=1) - mbarrier.expect(bar, input_desc.nbytes_per_cta) - tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem, multicast=True) - mbarrier.wait(bar, 0, deps=[smem]) - val = smem.load(blocked_layout) - mbarrier.invalidate(bar) - - out_m = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, blocked_layout))[:, None] - out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :] - out_ptr = out + out_m * XBLOCK + out_n - ttgl.store(out_ptr, val) - - input = torch.randn((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16) - output = torch.empty((XBLOCK.value, XBLOCK.value), device=device, dtype=torch.float16) - shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, - cga_layout=multicast_cga_layout(num_ctas, 2)) - input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [XBLOCK.value, XBLOCK.value], shared_layout) - kernel[(1, )](input_desc, output, TWO_CTA_BARRIER=TWO_CTA_BARRIER, num_warps=4, num_ctas=num_ctas) + kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")