From 56f67606ce7da70c6e86546ffdb263ec6f41808f Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 13 Apr 2026 09:19:46 +0200 Subject: [PATCH 1/2] [BACKEND] Model async TMA variants in ConSan We follow TMA load / store closely. --- .../IR/TritonNvidiaGPUOpInterfaces.td | 5 +++ .../Transforms/ConcurrencySanitizer.cpp | 18 +++++----- .../Transforms/ConSanNVIDIA.cpp | 33 ++++++++++--------- test/TritonGPU/consan.mlir | 28 +++++++++++++++- 4 files changed, 57 insertions(+), 27 deletions(-) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td index 4a1764ca8ab5..294db9f07033 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td @@ -42,6 +42,11 @@ def TMALoadLikeOpInterface : OpInterface<"TMALoadLikeOpInterface", [TMAOpInterfa /*retType=*/"::mlir::Value", /*methodName=*/"getPred", /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Return true if this load uses multicast", + /*retType=*/"bool", + /*methodName=*/"getMulticast", + /*args=*/(ins)>, ]; } diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index 67ee6265e3b3..d1498580274a 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -254,9 +254,9 @@ SmallVector getTensorCoreBarrierBroadcastMasks(Operation *op) { Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op); Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { - if (auto copyOp = dyn_cast(op)) { - if (copyOp.getMulticast()) - return getMulticastRecipientCTAs(b, copyOp.getResult()); + if (auto tmaLoad = dyn_cast(op)) { + if (tmaLoad.getMulticast()) + return getMulticastRecipientCTAs(b, tmaLoad.getResult()); return currentCTAMask(b); } if (isTensorCoreOp(op)) @@ -272,14 +272,12 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { return getLeaderCTA(b, arriveOp.getAlloc()); if (auto arriveOp = dyn_cast(op)) return getLeaderCTA(b, arriveOp.getBarrier()); - if (auto copyOp = dyn_cast(op)) { - if (copyOp.getMulticast()) - return getMulticastBarrierRecipientCTAs(b, copyOp.getResult(), - copyOp.getBarrier()); - return getLeaderCTA(b, copyOp.getBarrier()); - } - if (auto tmaLoad = dyn_cast(op)) + if (auto tmaLoad = dyn_cast(op)) { + if (tmaLoad.getMulticast()) + return getMulticastBarrierRecipientCTAs(b, tmaLoad.getResult(), + tmaLoad.getBarrier()); return getLeaderCTA(b, tmaLoad.getBarrier()); + } if (isTensorCoreOp(op)) return getRecipientCTAsForBroadcastMasks( diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index 0036f4c33425..395d4e90f4bd 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -28,6 +28,12 @@ Value getLeaderCTAPredicate(ImplicitLocOpBuilder &b, uint32_t broadcastMask) { arith::ConstantIntOp::create(b, 0, 32)); } +uint32_t getBlockBroadcastMask(Type type) { + auto memDescTy = cast(type); + auto kBlock = StringAttr::get(type.getContext(), "block"); + return toLinearLayout(memDescTy).getFreeVariableMasks().lookup(kBlock); +} + } // namespace class NVIDIAConSanHooks : public tti::ConSanTargetHooks { @@ -89,13 +95,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { mask = getBarrierMask(waitOp.getAlloc()); if (auto invalOp = dyn_cast(op)) mask = getBarrierMask(invalOp.getAlloc()); - if (auto copyOp = dyn_cast(op)) { - if (copyOp.getMulticast()) { - auto dstTy = cast(copyOp.getResult().getType()); - auto kBlock = StringAttr::get(op->getContext(), "block"); - mask = toLinearLayout(dstTy).getFreeVariableMasks().lookup(kBlock); - } + if (auto loadOp = dyn_cast(op)) { + if (loadOp.getMulticast()) + mask = getBlockBroadcastMask(loadOp.getResult().getType()); } + if (auto storeOp = dyn_cast(op)) + mask = getBlockBroadcastMask(storeOp.getSrc().getType()); // In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair // issues the op. @@ -205,16 +210,12 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; info->pred = loadOp.getPred(); int txCount = tti::getMemDescLength(loadOp.getResult()); - if (auto copyOp = dyn_cast(op); - copyOp && copyOp.getMulticast()) { - auto resultTy = cast(loadOp.getResult().getType()); - auto barrierTy = cast(loadOp.getBarrier().getType()); - auto kBlock = StringAttr::get(op->getContext(), "block"); - uint16_t resultMask = - toLinearLayout(resultTy).getFreeVariableMasks().lookup(kBlock); - uint16_t barrierMask = - toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock); - uint16_t collapsedMask = resultMask & barrierMask; + if (loadOp.getMulticast()) { + uint32_t resultMask = + getBlockBroadcastMask(loadOp.getResult().getType()); + uint32_t barrierMask = + getBlockBroadcastMask(loadOp.getBarrier().getType()); + uint32_t collapsedMask = resultMask & barrierMask; for (; collapsedMask; collapsedMask &= collapsedMask - 1) txCount *= 2; } diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 85e61e99bca5..1dfc9e3fdcba 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -479,9 +479,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.warp_group_dot %shmem, %shmem, %acc : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma> // CHECK: ttng.warp_group_dot + // CHECK: tt.call @__triton_consan_verify_write_visibility + // CHECK: tt.call @__triton_consan_check_outstanding_commits + // CHECK: tt.call @__triton_consan_stage_access_for_commit + // CHECK: tt.call @__triton_consan_commit_accesses + ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#smem = #ttg.shared_memory +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @async_tma_reduce + tt.func public @async_tma_reduce(%arg0: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>, %acc: tensor<128x128xf16, #mma>) { + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr, %shmem : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #smem, mutable> + // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: tt.call @__triton_consan_check_outstanding_commits - ttng.async_tma_scatter %arg0[%x_offsets, %c0_i32] %0 : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32>, i32, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + // CHECK: tt.call @__triton_consan_stage_access_for_commit + // CHECK: tt.call @__triton_consan_commit_accesses + ttng.async_tma_reduce add, %arg0[%c0_i32, %c0_i32] %0 : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } From bb795b9911abb425f012258a9460eca446eda3bd Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 14 Apr 2026 21:50:23 +0200 Subject: [PATCH 2/2] add lit test that checks multicast gather completion --- test/TritonGPU/consan.mlir | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 1dfc9e3fdcba..e5db7d7ba9fe 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -293,6 +293,40 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // ----- +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#smem = #ttg.shared_memory +#offset_parent = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 0]]}> +#offsets = #ttg.slice<{dim = 0, parent = #offset_parent}> +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, "ttng.two-ctas" = true, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @multicast_gather_two_cta_tx_count + tt.func public @multicast_gather_two_cta_tx_count(%desc: !tt.tensordesc<1x32xf32, #shared>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %x_offsets = arith.constant dense<0> : tensor<32xi32, #offsets> + %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + // CHECK: scf.for + scf.for %i = %c0 to %c2 step %c1 { + // CHECK: arith.constant 8192 : i64 + // CHECK: tt.call @__triton_consan_verify_barrier_arrive + // CHECK: ttng.barrier_expect + ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + // CHECK: arith.constant -8192 : i64 + // CHECK: tt.call @__triton_consan_verify_barrier_arrive + // CHECK: ttng.async_tma_gather + ttng.async_tma_gather %desc[%x_offsets, %c0_i32] %result, %bar, %true {multicast} : !tt.tensordesc<1x32xf32, #shared>, tensor<32xi32, #offsets>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, i1 + } + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory