From d087507234a5667526c8ec85d74eac661a18e01f Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 16 Apr 2026 12:38:21 +0200 Subject: [PATCH 1/7] [Consan] Support CLC Smelly bits: We execute CLC in the TMA partition to avoid having to create a new partition for CLC. I think we should create a different partition for CLC but I wanted to have @pawelszczerbuk's approval before doing it. We model CLC as we model TMA writes, via a Barrier::EffectWrites. The idea of this mode is that we link all the writes on the op to the barrier. We also annotate in the table `barrierWriteRecipients` which CTAs will become visible once we wait on the associated barrier. We note something interesting and document it. `BarrierTrackingMode::Frontier` should be used when we have a commit/arrive/expect op that affects anything in flight before it. Instead, we use `BarrierTrackingMode::EffectWrites` when the PTX op accepts a barrier so the barrier just signals the completion of the op's particular write. The other point we add is a flag `bool diagonalEffectRecipientCTAs`. This differentiates the behaviour between TMA, where after waiting on the barrier you see all the writes from all the CTAs in the multicas group, vs. the diagonal version, as in CLC, where waiting on CTAi just makes the thread see the CTAi memory. --- .../TritonInstrument/IR/FunctionBuilder.h | 7 +++- .../Transforms/ConSanTargetHooks.h | 27 ++++++++++++- lib/Analysis/BufferRegion.cpp | 2 +- .../TritonInstrument/IR/FunctionBuilder.cpp | 36 ++++++++++++++--- .../Transforms/ConcurrencySanitizer.cpp | 15 ++++++- .../Transforms/ConSanNVIDIA.cpp | 27 ++++++++++++- python/test/gluon/test_consan.py | 40 ++++++++++++++++++- test/TritonGPU/consan.mlir | 26 ++++++++++++ 8 files changed, 165 insertions(+), 15 deletions(-) diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index 9eb6756af0e0..e383c185fe1b 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -169,13 +169,16 @@ class FunctionBuilder { int thread, Value pred, MemType memType, Operation *insertPoint, Value recipientCTAs); // trackBarrierWriteForBuffer: mark a specific buffer as tracked by a - // barrier in the write-tracking table. + // barrier in the write-tracking table. When diagonalEffectRecipientCTAs is + // false, every signaled barrier row publishes the full effectRecipientCTAs + // mask. When it is true, barrier row i publishes only bit i of that mask. void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, Value barrierRecipientCTAs, - Value effectRecipientCTAs); + Value effectRecipientCTAs, + bool diagonalEffectRecipientCTAs); // clearBarrierWriteTracking: clear all write tracking associated with the // given barrier row. void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar, diff --git a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h index 89f55733aaa4..0d97338e4c7d 100644 --- a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h +++ b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h @@ -12,8 +12,19 @@ namespace mlir::triton::instrument { struct MemEffectsOpInfo { - // Frontier: snapshot thread-visible frontier into barrier tracking. - // EffectWrites: track only buffers written by op effects. + // Controls which memory effects become visible to a CTA after it waits on + // this barrier. + // + // Frontier snapshots the issuing thread's current visibility frontier into + // the barrier. A later wait publishes whatever shared/tensor memory writes + // and reads were visible to that logical thread before the arrive/commit. Use + // this for ordering operations whose semantics are a release of prior work. + // + // EffectWrites does not snapshot the whole thread frontier. Instead, it + // attaches only the explicit write effects of this op to the barrier. A later + // wait publishes those op-local writes and nothing else. Use this for PTX ops + // that perform the write and also signal the barrier via + // `mbarrier::complete_tx`. enum class BarrierTrackingMode { Frontier, EffectWrites, @@ -34,6 +45,18 @@ struct MemEffectsOpInfo { int count; BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier; int txCount = 0; + // For EffectWrites, effectRecipientCTAs identifies the CTA rows where the + // op wrote its explicit result. By default, for + // diagonalEffectRecipientCTAs=false, waiting on a barrier publishes the CTA + // rows in effectRecipientCTAs, which is the full mask. This is the + // behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true, + // waiting on a barrier publishes only the CTA rows in effectRecipientCTAs, + // which is the diagonal mask. e.g. effectRecipientCTAs = 0b1101 If + // DiagonalEffectRecipientCTAs is false, waiting on the barrier publishes + // the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101 + // If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes + // the following CTA rows: CTA0 0b1000 CTA1 0b0100 CTA2 0b0000 CTA3 0b0001 + bool diagonalEffectRecipientCTAs = false; }; enum class TrackingKind { None, diff --git a/lib/Analysis/BufferRegion.cpp b/lib/Analysis/BufferRegion.cpp index 4f0c43e43fa9..76ad89f7e5ce 100644 --- a/lib/Analysis/BufferRegion.cpp +++ b/lib/Analysis/BufferRegion.cpp @@ -323,7 +323,7 @@ void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) { bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) { if (isa(op)) { + ttng::TMAOpInterface, ttng::CLCLoadResultOp>(op)) { return true; } if (isa(op)) { diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index 9ac031d1fd1c..1d617a5667d9 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -1674,13 +1674,13 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, void FunctionBuilder::createTrackBarrierWriteForBufferCall( ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, Value barrierRecipientCTAs, - Value effectRecipientCTAs) { + Value effectRecipientCTAs, bool diagonalEffectRecipientCTAs) { if (auxData.barriers.empty() || auxData.buffers[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { return; } assert(!auxData.barrierWriteRecipients.empty() && - "barrier write recipients must exist when tracking TMA writes"); + "barrier write recipients must exist when tracking EffectWrites"); if (!pred) pred = arith::ConstantIntOp::create(b, 1, 1); Value barriersVal = auxData.barriers.at(insertPoint).value; @@ -1717,10 +1717,10 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( b, "track_barrier_write_for_buffer", args, /*assertInfo=*/std::nullopt, {barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, - (uint64_t)memType}, - [barriersType, buffersType, writeTrackingType, - barrierWriteRecipientsType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + (uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs}, + [barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, + diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value mbarLengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1747,6 +1747,30 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( createCmpIntTensorScalar(fb, barriers, barrierDescriptor); Value effectRecipientCTAsTensor = triton::SplatOp::create( fb, barrierWriteRecipientsType, effectRecipientCTAs); + if (diagonalEffectRecipientCTAs) { + // Expand the effect CTA mask diagonally: barrier row i publishes only + // bit i. This models per-CTA results, while the default replicated + // mask models TMA multicast where a barrier publishes all result + // rows. + auto encoding = cast( + barrierWriteRecipientsType.getEncoding()); + auto rowSliceEncoding = + tti::getSingleDimSliceEncoding(encoding, /*dim=*/0); + int numCTAs = barrierWriteRecipientsType.getShape()[0]; + auto rowType = RankedTensorType::get({numCTAs}, fb.getI32Type(), + rowSliceEncoding); + Value rowIdx = triton::MakeRangeOp::create(fb, rowType, + /*start=*/0, + /*end=*/numCTAs); + auto indexType = + cast(barrierWriteRecipientsType.cloneWith( + std::nullopt, fb.getI32Type())); + rowIdx = convertAndBroadcast(fb, rowIdx, {0}, indexType); + Value one = tti::createConstIntTensor(fb, fb.getLoc(), 1, indexType); + Value rowBit = arith::ShLIOp::create(fb, one, rowIdx); + effectRecipientCTAsTensor = + arith::AndIOp::create(fb, effectRecipientCTAsTensor, rowBit); + } Value updatedBarrierWriteRecipients = arith::OrIOp::create( fb, barrierWriteRecipients, effectRecipientCTAsTensor); updatedBarrierWriteRecipients = arith::SelectOp::create( diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index d1498580274a..12053ce0cab1 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -23,7 +23,7 @@ // buffers | tensor | | Base pointers of all (sub)buffers // barriers | tensor | | Pointers to all individual mbarriers // barrierStates | scratch | | Packed barrier phase (bit 0), arrival counts (bits[1..20] init, [21..40] current), and signed tx-count (bits[41..61]); zero means invalid/uninitialized -// barrierWriteRecipients | scratch | | CTA bitsets of write-tracking rows reached by outstanding TMA effects on each barrier +// barrierWriteRecipients | scratch | | CTA bitsets of EffectWrites rows published by each barrier // waiting | scratch | | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1) // writeVisibility | scratch | | Per-buffer thread-visibility bitmask (bit i => thread i visible) // readVisibility | scratch | | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks) @@ -159,6 +159,12 @@ Value currentCTAMask(ImplicitLocOpBuilder &b) { ctaId); } +Value allCTAsMask(ImplicitLocOpBuilder &b) { + int numCTAs = ttg::lookupNumCTAs(b); + assert(numCTAs <= 16 && "ConSan CTA bitsets assume at most 16 CTAs"); + return arith::ConstantIntOp::create(b, (1u << numCTAs) - 1, 32); +} + uint16_t getBlockBroadcastMask(Value alloc) { auto allocTy = cast(alloc.getType()); auto kBlock = StringAttr::get(alloc.getContext(), "block"); @@ -259,6 +265,8 @@ Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { return getMulticastRecipientCTAs(b, tmaLoad.getResult()); return currentCTAMask(b); } + if (isa(op)) + return allCTAsMask(b); if (isTensorCoreOp(op)) return getRecipientCTAsForBroadcastMasks( b, ttng::getCTABroadcastMasks(ttng::getModuleTwoCTAs(op), {})); @@ -278,6 +286,8 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { tmaLoad.getBarrier()); return getLeaderCTA(b, tmaLoad.getBarrier()); } + if (isa(op)) + return allCTAsMask(b); if (isTensorCoreOp(op)) return getRecipientCTAsForBroadcastMasks( @@ -560,7 +570,8 @@ class ConcurrencySanitizerImpl { memType = MemType::SHARED_MEM; funcBuilder.createTrackBarrierWriteForBufferCall( b, barrier, effect.buf, effect.length, combinedPred, memType, op, - recipientCTAs, effectRecipientCTAs); + recipientCTAs, effectRecipientCTAs, + barrierInfo.diagonalEffectRecipientCTAs); } } if (barrierInfo.count > 0 || barrierInfo.txCount != 0) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index 395d4e90f4bd..5b8ebf8b004e 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -39,7 +39,9 @@ uint32_t getBlockBroadcastMask(Type type) { class NVIDIAConSanHooks : public tti::ConSanTargetHooks { public: bool isTMAOp(Operation *op) const override { - return isa(op); + // CLC writes its result asynchronously through the same helper-thread + // visibility class that ConSan uses for TMA effects. + return isa(op); } std::optional @@ -101,6 +103,11 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { } if (auto storeOp = dyn_cast(op)) mask = getBlockBroadcastMask(storeOp.getSrc().getType()); + if (isa(op) && ttg::lookupNumCTAs(op) > 1) { + Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc()); + return arith::CmpIOp::create(b, arith::CmpIPredicate::eq, ctaId, + arith::ConstantIntOp::create(b, 0, 32)); + } // In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair // issues the op. @@ -226,6 +233,24 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, loadOp.getResult()); } + if (auto tryCancelOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->barriers.push_back( + {tryCancelOp.getMbarrier(), nullptr, /*count=*/0, + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/ + -static_cast(tti::getMemDescLength(tryCancelOp.getResult())), + /*diagonalEffectRecipientCTAs=*/true}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + tryCancelOp.getResult()); + } + if (auto loadResultOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadResultOp.getSrc()); + } if (auto storeOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 48a609200f2a..bf3064ad8f2c 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -8,7 +8,7 @@ from triton.experimental.gluon.language.nvidia import blackwell from triton.experimental.gluon.language.nvidia import hopper from triton.experimental.gluon.language.nvidia import ampere -from triton.experimental.gluon.language.nvidia.blackwell import allocate_tensor_memory, mbarrier, tma +from triton.experimental.gluon.language.nvidia.blackwell import allocate_tensor_memory, clc, mbarrier, tma from triton._internal_testing import is_cuda, run_in_process @@ -250,6 +250,44 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") +@pytest.mark.parametrize("FAILURE", [True, False]) +def test_clc_result_visibility(FAILURE, device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_clc_result_visibility, (FAILURE, device, False, monkeypatch, num_ctas)) + if FAILURE: + assert_expected_cuda_failure(result.exc) + assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(out, FAILURE: ttgl.constexpr): + cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1) + layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) + clc_result = ttgl.allocate_shared_memory(ttgl.int64, [2], layout) + clc_bar = mbarrier.allocate_mbarrier() + mbarrier.init(clc_bar, count=1) + + clc.try_cancel(clc_result, clc_bar) + mbarrier.expect(clc_bar, 16) + mbarrier.wait(clc_bar, 0, pred=(not FAILURE)) + response = clc.load_result(clc_result) + mbarrier.wait(clc_bar, 0, pred=FAILURE) + mbarrier.invalidate(clc_bar) + + ttgl.store(out + ttgl.program_id(0), response.is_canceled()) + + output = torch.empty((1, ), device=device, dtype=torch.bool) + kernel[(1, )](output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") def test_async_tma_multicast_kernel_reuse(device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index e5db7d7ba9fe..007b7425a3db 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -252,6 +252,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // ----- +#shared_clc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @clc_try_cancel_diagonal_effect_recipients + tt.func public @clc_try_cancel_diagonal_effect_recipients() { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2xi64, #shared_clc, #smem, mutable> + %bar = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + // CHECK: tti.experimental_cluster_cta_id + // CHECK: arith.cmpi eq + // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer{{.*}}_I1 + ttng.clc_try_cancel %result, %bar : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable>, !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.barrier_expect %bar, 16, %true : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.wait_barrier %bar, %c0_i32, %true : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + // CHECK: tt.call @__triton_consan_verify_write_visibility + // CHECK: ttng.clc_load_result + %clc_res = ttng.clc_load_result %result : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable> -> i128 + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #smem = #ttg.shared_memory From eadf967777d961dc0c3bd044bc7cf19aef5b7c7f Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 16 Apr 2026 15:10:24 +0200 Subject: [PATCH 2/7] CLC partition --- .../TritonInstrument/IR/TritonInstrument.md | 10 +++++----- .../Dialect/TritonInstrument/IR/Utility.h | 8 ++++++-- .../Transforms/ConSanTargetHooks.h | 2 ++ .../TritonInstrument/IR/FunctionBuilder.cpp | 9 ++++----- .../Transforms/ConcurrencySanitizer.cpp | 7 ++++++- .../Transforms/ConSanNVIDIA.cpp | 8 +++++--- test/TritonGPU/consan.mlir | 18 +++++++++++------- 7 files changed, 39 insertions(+), 23 deletions(-) diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md index 4fd74bfce0fb..0d8c3254880a 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -9,10 +9,10 @@ Auxiliary state is kept in distributed tensors and global scratch memory, with t ### Thread model - Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions). -- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads. -- Total logical threads: 48. Bitmasks are sized to the next power of two: 64. +- Peer classes: +16 TMA threads, +16 Tensor Core (TC) threads, and +16 CLC threads to model lack of ordering with base threads. +- Total logical threads: 64. Bitmasks are sized to the next power of two: 64. -Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience. +Indexing uses a logical thread id in [0, 64), with column vectors sized to 64 for layout convenience. ## Auxiliary data structures @@ -21,7 +21,7 @@ All types are generated on-demand (per partition) based on: - B: number of tracked buffers (power-of-two padded) - K: number of mbarriers (power-of-two padded) - T_bits: 64 (bitmask width) -- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers) +- T_commits: 16 (base threads; commit counters do not apply to TC/TMA/CLC helpers) “tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts. @@ -53,7 +53,7 @@ ConSan separates “tracking” from “visibility transfer”: - experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer. - experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier. - At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes. -- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent. +- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC, CLC) to keep the classes consistent. ### Barrier phase/count tracking diff --git a/include/triton/Dialect/TritonInstrument/IR/Utility.h b/include/triton/Dialect/TritonInstrument/IR/Utility.h index b8fb63e574a4..09230273c659 100644 --- a/include/triton/Dialect/TritonInstrument/IR/Utility.h +++ b/include/triton/Dialect/TritonInstrument/IR/Utility.h @@ -6,6 +6,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" #include @@ -22,8 +23,11 @@ constexpr int numMemTypes = getMaxEnumValForMemType() + 1; constexpr int NUM_THREADS = 16; constexpr int TMA_THREAD_OFFSET = NUM_THREADS; constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS; -constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS; -constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS); +constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS; +constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS; +static_assert(TOTAL_NUM_THREADS <= 64, + "ConSan thread bitsets are stored in i64 masks"); +const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS); namespace CommitKind { enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds }; diff --git a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h index 0d97338e4c7d..e47e4bb752f5 100644 --- a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h +++ b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h @@ -106,6 +106,8 @@ class ConSanTargetHooks { virtual bool isTMAOp(Operation *op) const = 0; + virtual bool isCLCOp(Operation *op) const { return false; } + virtual std::optional getBarrierInitInfo(Operation *op) const = 0; diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index 1d617a5667d9..db3ace2f1df6 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -2389,11 +2389,10 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value zeroTensor = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); - constexpr uint64_t fullMask = - tti::THREADS_BITMASK_SIZE == 64 - ? std::numeric_limits::max() - : (std::numeric_limits::max() >> - (64 - tti::THREADS_BITMASK_SIZE)); + uint64_t fullMask = tti::THREADS_BITMASK_SIZE == 64 + ? std::numeric_limits::max() + : (std::numeric_limits::max() >> + (64 - tti::THREADS_BITMASK_SIZE)); Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64); Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType); Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType); diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index 12053ce0cab1..c4221a0a0467 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -120,12 +120,16 @@ int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks) { thread += TC_THREAD_OFFSET; return thread; } + if (hooks->isCLCOp(op)) { + thread += CLC_THREAD_OFFSET; + return thread; + } return thread; } int getBaseThread(int thread) { return thread % NUM_THREADS; } -// Peer threads are the equivalent threads in the TMA, TC and normal +// Peer threads are the equivalent threads in the TMA, TC, CLC and normal // thread classes. // If a thread is a base thread, return the mask with the peers, otherwise // return the mask with the thread itself. @@ -134,6 +138,7 @@ uint64_t getThreadPeersMask(int thread) { if (thread < NUM_THREADS) { mask |= 1ULL << (thread + TMA_THREAD_OFFSET); mask |= 1ULL << (thread + TC_THREAD_OFFSET); + mask |= 1ULL << (thread + CLC_THREAD_OFFSET); } return mask; } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index 5b8ebf8b004e..f3bc65039879 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -39,9 +39,11 @@ uint32_t getBlockBroadcastMask(Type type) { class NVIDIAConSanHooks : public tti::ConSanTargetHooks { public: bool isTMAOp(Operation *op) const override { - // CLC writes its result asynchronously through the same helper-thread - // visibility class that ConSan uses for TMA effects. - return isa(op); + return isa(op); + } + + bool isCLCOp(Operation *op) const override { + return isa(op); } std::optional diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 007b7425a3db..5ef6070da05b 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -265,6 +265,10 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #barrier, #smem, mutable> // CHECK: tti.experimental_cluster_cta_id // CHECK: arith.cmpi eq + // CHECK: %[[CLC_THREAD:.*]] = arith.constant 48 : i32 + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[CLC_THREAD]] + // CHECK: %[[CLC_MASK:.*]] = arith.constant 281474976710656 : i64 + // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[CLC_MASK]] // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer{{.*}}_I1 ttng.clc_try_cancel %result, %bar : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable>, !ttg.memdesc<1xi64, #barrier, #smem, mutable> ttng.barrier_expect %bar, 16, %true : !ttg.memdesc<1xi64, #barrier, #smem, mutable> @@ -1026,7 +1030,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_commit_group() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -1277,10 +1281,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -1322,10 +1326,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -1358,10 +1362,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { From 04dfc7b94e54b794c2708d03f6d9afe9e59a6c0a Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 16 Apr 2026 23:16:34 +0200 Subject: [PATCH 3/7] address review --- include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md index 0d8c3254880a..f7d074675860 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -10,7 +10,7 @@ Auxiliary state is kept in distributed tensors and global scratch memory, with t - Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions). - Peer classes: +16 TMA threads, +16 Tensor Core (TC) threads, and +16 CLC threads to model lack of ordering with base threads. -- Total logical threads: 64. Bitmasks are sized to the next power of two: 64. +- Total logical threads: 64. Indexing uses a logical thread id in [0, 64), with column vectors sized to 64 for layout convenience. From f1158f61e683ae1072b041c5d34a6c805064b2de Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 16 Apr 2026 23:28:41 +0200 Subject: [PATCH 4/7] lit tests --- test/TritonGPU/amd/amd-consan.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/TritonGPU/amd/amd-consan.mlir b/test/TritonGPU/amd/amd-consan.mlir index d475a14b6ea7..fdb7c4ac3ca7 100644 --- a/test/TritonGPU/amd/amd-consan.mlir +++ b/test/TritonGPU/amd/amd-consan.mlir @@ -242,7 +242,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -699,7 +699,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @amdg_async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -830,10 +830,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -874,10 +874,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -908,10 +908,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { From 4aa57e01736c378562741dee38f2d6fafd21b15d Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 17 Apr 2026 17:36:06 +0200 Subject: [PATCH 5/7] kill unnecessary flags --- lib/Dialect/TritonInstrument/IR/Utility.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index 29c689765ec6..63ab53813189 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -591,8 +591,7 @@ void AuxDataMap::populateAndPassToWarpSpecialize( if (numCTAs > 1) { ClusterBarrierOp::create(b, b.getLoc()); } else { - BarrierOp::create(b, b.getLoc(), - AddrSpace::GlobalRead | AddrSpace::GlobalWrite); + BarrierOp::create(b, b.getLoc(), AddrSpace::Local); } lock.insert(entryRegion, {lockVal, lockVal.getType()}); passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock, captureCounter); From e8b24107ed17245e6561ee922e33f5f608c94555 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 17 Apr 2026 17:50:41 +0200 Subject: [PATCH 6/7] add missing barrier fence if using CLC --- .../Transforms/ClusterBarrierInsertion.cpp | 3 ++ test/TritonNvidiaGPU/membar-cluster.mlir | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp index afabbf693d2d..ea43ce133f6f 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp @@ -114,6 +114,9 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op, if (auto tma = dyn_cast(op)) { return tma.getMulticast() && aliasesTracked(tma.getBarrier()); } + if (auto clc = dyn_cast(op)) { + return aliasesTracked(clc.getMbarrier()); + } return false; } diff --git a/test/TritonNvidiaGPU/membar-cluster.mlir b/test/TritonNvidiaGPU/membar-cluster.mlir index becb904b91ee..05925e949748 100644 --- a/test/TritonNvidiaGPU/membar-cluster.mlir +++ b/test/TritonNvidiaGPU/membar-cluster.mlir @@ -473,6 +473,35 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- +#sharedCLC = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrierCLC = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CLC multicasts completion through the cluster, so it needs init sync even + // if the barrier allocation shape looks per-CTA. + // CHECK-LABEL: @cluster_clc_with_per_cta_barrier + // CHECK: ttng.init_barrier + // CHECK-NEXT: ttng.fence_mbarrier_init_release_cluster + // CHECK-NEXT: ttng.cluster_barrier {relaxed = true} + // CHECK-NEXT: ttng.clc_try_cancel + // CHECK: tt.return + tt.func @cluster_clc_with_per_cta_barrier() { + %true = arith.constant true + %result = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #sharedCLC, #smem, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.clc_try_cancel %result, %barrier : + !ttg.memdesc<2xi64, #sharedCLC, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.barrier_expect %barrier, 16, %true : + !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + tt.return + } +} + +// ----- + #nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> #barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> From 61ebb887a64aa994a13d85469c4e95a0dd92ef6a Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 17 Apr 2026 17:53:14 +0200 Subject: [PATCH 7/7] minor --- .../TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp index ea43ce133f6f..59f2d6cc3443 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp @@ -108,10 +108,7 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op, if (auto commit = dyn_cast(op)) { return ttng::getModuleTwoCTAs(op) && aliasesTracked(commit.getBarrier()); } - if (auto tma = dyn_cast(op)) { - return tma.getMulticast() && aliasesTracked(tma.getBarrier()); - } - if (auto tma = dyn_cast(op)) { + if (auto tma = dyn_cast(op)) { return tma.getMulticast() && aliasesTracked(tma.getBarrier()); } if (auto clc = dyn_cast(op)) {