diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index 9eb6756af0e0..e383c185fe1b 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -169,13 +169,16 @@ class FunctionBuilder { int thread, Value pred, MemType memType, Operation *insertPoint, Value recipientCTAs); // trackBarrierWriteForBuffer: mark a specific buffer as tracked by a - // barrier in the write-tracking table. + // barrier in the write-tracking table. When diagonalEffectRecipientCTAs is + // false, every signaled barrier row publishes the full effectRecipientCTAs + // mask. When it is true, barrier row i publishes only bit i of that mask. void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, Value barrierRecipientCTAs, - Value effectRecipientCTAs); + Value effectRecipientCTAs, + bool diagonalEffectRecipientCTAs); // clearBarrierWriteTracking: clear all write tracking associated with the // given barrier row. void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar, diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md index 4fd74bfce0fb..f7d074675860 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -9,10 +9,10 @@ Auxiliary state is kept in distributed tensors and global scratch memory, with t ### Thread model - Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions). -- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads. -- Total logical threads: 48. Bitmasks are sized to the next power of two: 64. 
+- Peer classes: +16 TMA threads, +16 Tensor Core (TC) threads, and +16 CLC threads to model lack of ordering with base threads. +- Total logical threads: 64. -Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience. +Indexing uses a logical thread id in [0, 64), which exactly fills the 64-entry column vectors and 64-bit masks (no padding entries remain). ## Auxiliary data structures @@ -21,7 +21,7 @@ All types are generated on-demand (per partition) based on: - B: number of tracked buffers (power-of-two padded) - K: number of mbarriers (power-of-two padded) - T_bits: 64 (bitmask width) -- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers) +- T_commits: 16 (base threads; commit counters do not apply to TC/TMA/CLC helpers) “tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts. @@ -53,7 +53,7 @@ ConSan separates “tracking” from “visibility transfer”: - experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer. - experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier. - At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes. -- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent. 
+- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC, CLC) to keep the classes consistent. ### Barrier phase/count tracking diff --git a/include/triton/Dialect/TritonInstrument/IR/Utility.h b/include/triton/Dialect/TritonInstrument/IR/Utility.h index b8fb63e574a4..09230273c659 100644 --- a/include/triton/Dialect/TritonInstrument/IR/Utility.h +++ b/include/triton/Dialect/TritonInstrument/IR/Utility.h @@ -6,6 +6,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonInstrument/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" #include @@ -22,8 +23,11 @@ constexpr int numMemTypes = getMaxEnumValForMemType() + 1; constexpr int NUM_THREADS = 16; constexpr int TMA_THREAD_OFFSET = NUM_THREADS; constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS; -constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS; -constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS); +constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS; +constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS; +static_assert(TOTAL_NUM_THREADS <= 64, + "ConSan thread bitsets are stored in i64 masks"); +const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS); namespace CommitKind { enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds }; diff --git a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h index 89f55733aaa4..e47e4bb752f5 100644 --- a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h +++ b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h @@ -12,8 +12,19 @@ namespace mlir::triton::instrument { struct MemEffectsOpInfo { 
- // Frontier: snapshot thread-visible frontier into barrier tracking. - // EffectWrites: track only buffers written by op effects. + // Controls which memory effects become visible to a CTA after it waits on + // this barrier. + // + // Frontier snapshots the issuing thread's current visibility frontier into + // the barrier. A later wait publishes whatever shared/tensor memory writes + // and reads were visible to that logical thread before the arrive/commit. Use + // this for ordering operations whose semantics are a release of prior work. + // + // EffectWrites does not snapshot the whole thread frontier. Instead, it + // attaches only the explicit write effects of this op to the barrier. A later + // wait publishes those op-local writes and nothing else. Use this for PTX ops + // that perform the write and also signal the barrier via + // `mbarrier::complete_tx`. enum class BarrierTrackingMode { Frontier, EffectWrites, @@ -34,6 +45,18 @@ struct MemEffectsOpInfo { int count; BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier; int txCount = 0; + // For EffectWrites, effectRecipientCTAs identifies the CTA rows where the + // op wrote its explicit result. By default + // (diagonalEffectRecipientCTAs == false), waiting on any signaled barrier + // row publishes the full effectRecipientCTAs mask. This is the + // behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true, + // waiting on barrier row i publishes only bit i of effectRecipientCTAs + // (the diagonal of the replicated mask). e.g. 
effectRecipientCTAs = 0b1101 If + // diagonalEffectRecipientCTAs is false, waiting on the barrier publishes + // the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101 + // If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes + // the following CTA rows: CTA0 0b0001 CTA1 0b0000 CTA2 0b0100 CTA3 0b1000 + bool diagonalEffectRecipientCTAs = false; }; enum class TrackingKind { None, @@ -83,6 +106,8 @@ class ConSanTargetHooks { virtual bool isTMAOp(Operation *op) const = 0; + virtual bool isCLCOp(Operation *op) const { return false; } + virtual std::optional getBarrierInitInfo(Operation *op) const = 0; diff --git a/lib/Analysis/BufferRegion.cpp b/lib/Analysis/BufferRegion.cpp index 4f0c43e43fa9..76ad89f7e5ce 100644 --- a/lib/Analysis/BufferRegion.cpp +++ b/lib/Analysis/BufferRegion.cpp @@ -323,7 +323,7 @@ void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) { bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) { if (isa(op)) { + ttng::TMAOpInterface, ttng::CLCLoadResultOp>(op)) { return true; } if (isa(op)) { diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index 9ac031d1fd1c..db3ace2f1df6 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -1674,13 +1674,13 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, void FunctionBuilder::createTrackBarrierWriteForBufferCall( ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, Value barrierRecipientCTAs, - Value effectRecipientCTAs) { + Value effectRecipientCTAs, bool diagonalEffectRecipientCTAs) { if (auxData.barriers.empty() || auxData.buffers[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { return; } assert(!auxData.barrierWriteRecipients.empty() && - "barrier write recipients must exist when tracking TMA 
writes"); + "barrier write recipients must exist when tracking EffectWrites"); if (!pred) pred = arith::ConstantIntOp::create(b, 1, 1); Value barriersVal = auxData.barriers.at(insertPoint).value; @@ -1717,10 +1717,10 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( b, "track_barrier_write_for_buffer", args, /*assertInfo=*/std::nullopt, {barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, - (uint64_t)memType}, - [barriersType, buffersType, writeTrackingType, - barrierWriteRecipientsType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + (uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs}, + [barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, + diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value mbarLengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1747,6 +1747,30 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( createCmpIntTensorScalar(fb, barriers, barrierDescriptor); Value effectRecipientCTAsTensor = triton::SplatOp::create( fb, barrierWriteRecipientsType, effectRecipientCTAs); + if (diagonalEffectRecipientCTAs) { + // Expand the effect CTA mask diagonally: barrier row i publishes only + // bit i. This models per-CTA results, while the default replicated + // mask models TMA multicast where a barrier publishes all result + // rows. 
+ auto encoding = cast( + barrierWriteRecipientsType.getEncoding()); + auto rowSliceEncoding = + tti::getSingleDimSliceEncoding(encoding, /*dim=*/0); + int numCTAs = barrierWriteRecipientsType.getShape()[0]; + auto rowType = RankedTensorType::get({numCTAs}, fb.getI32Type(), + rowSliceEncoding); + Value rowIdx = triton::MakeRangeOp::create(fb, rowType, + /*start=*/0, + /*end=*/numCTAs); + auto indexType = + cast(barrierWriteRecipientsType.cloneWith( + std::nullopt, fb.getI32Type())); + rowIdx = convertAndBroadcast(fb, rowIdx, {0}, indexType); + Value one = tti::createConstIntTensor(fb, fb.getLoc(), 1, indexType); + Value rowBit = arith::ShLIOp::create(fb, one, rowIdx); + effectRecipientCTAsTensor = + arith::AndIOp::create(fb, effectRecipientCTAsTensor, rowBit); + } Value updatedBarrierWriteRecipients = arith::OrIOp::create( fb, barrierWriteRecipients, effectRecipientCTAsTensor); updatedBarrierWriteRecipients = arith::SelectOp::create( @@ -2365,11 +2389,10 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value zeroTensor = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); - constexpr uint64_t fullMask = - tti::THREADS_BITMASK_SIZE == 64 - ? std::numeric_limits::max() - : (std::numeric_limits::max() >> - (64 - tti::THREADS_BITMASK_SIZE)); + uint64_t fullMask = tti::THREADS_BITMASK_SIZE == 64 + ? 
std::numeric_limits::max() + : (std::numeric_limits::max() >> + (64 - tti::THREADS_BITMASK_SIZE)); Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64); Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType); Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType); diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index 29c689765ec6..63ab53813189 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -591,8 +591,7 @@ void AuxDataMap::populateAndPassToWarpSpecialize( if (numCTAs > 1) { ClusterBarrierOp::create(b, b.getLoc()); } else { - BarrierOp::create(b, b.getLoc(), - AddrSpace::GlobalRead | AddrSpace::GlobalWrite); + BarrierOp::create(b, b.getLoc(), AddrSpace::Local); } lock.insert(entryRegion, {lockVal, lockVal.getType()}); passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock, captureCounter); diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index d1498580274a..c4221a0a0467 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -23,7 +23,7 @@ // buffers | tensor | | Base pointers of all (sub)buffers // barriers | tensor | | Pointers to all individual mbarriers // barrierStates | scratch | | Packed barrier phase (bit 0), arrival counts (bits[1..20] init, [21..40] current), and signed tx-count (bits[41..61]); zero means invalid/uninitialized -// barrierWriteRecipients | scratch | | CTA bitsets of write-tracking rows reached by outstanding TMA effects on each barrier +// barrierWriteRecipients | scratch | | CTA bitsets of EffectWrites rows published by each barrier // waiting | scratch | | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1) // writeVisibility | scratch | | Per-buffer thread-visibility 
bitmask (bit i => thread i visible) // readVisibility | scratch | | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks) @@ -120,12 +120,16 @@ int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks) { thread += TC_THREAD_OFFSET; return thread; } + if (hooks->isCLCOp(op)) { + thread += CLC_THREAD_OFFSET; + return thread; + } return thread; } int getBaseThread(int thread) { return thread % NUM_THREADS; } -// Peer threads are the equivalent threads in the TMA, TC and normal +// Peer threads are the equivalent threads in the TMA, TC, CLC and normal // thread classes. // If a thread is a base thread, return the mask with the peers, otherwise // return the mask with the thread itself. @@ -134,6 +138,7 @@ uint64_t getThreadPeersMask(int thread) { if (thread < NUM_THREADS) { mask |= 1ULL << (thread + TMA_THREAD_OFFSET); mask |= 1ULL << (thread + TC_THREAD_OFFSET); + mask |= 1ULL << (thread + CLC_THREAD_OFFSET); } return mask; } @@ -159,6 +164,12 @@ Value currentCTAMask(ImplicitLocOpBuilder &b) { ctaId); } +Value allCTAsMask(ImplicitLocOpBuilder &b) { + int numCTAs = ttg::lookupNumCTAs(b); + assert(numCTAs <= 16 && "ConSan CTA bitsets assume at most 16 CTAs"); + return arith::ConstantIntOp::create(b, (1u << numCTAs) - 1, 32); +} + uint16_t getBlockBroadcastMask(Value alloc) { auto allocTy = cast(alloc.getType()); auto kBlock = StringAttr::get(alloc.getContext(), "block"); @@ -259,6 +270,8 @@ Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { return getMulticastRecipientCTAs(b, tmaLoad.getResult()); return currentCTAMask(b); } + if (isa(op)) + return allCTAsMask(b); if (isTensorCoreOp(op)) return getRecipientCTAsForBroadcastMasks( b, ttng::getCTABroadcastMasks(ttng::getModuleTwoCTAs(op), {})); @@ -278,6 +291,8 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { tmaLoad.getBarrier()); return getLeaderCTA(b, tmaLoad.getBarrier()); } + if (isa(op)) + return allCTAsMask(b); if (isTensorCoreOp(op)) 
return getRecipientCTAsForBroadcastMasks( @@ -560,7 +575,8 @@ class ConcurrencySanitizerImpl { memType = MemType::SHARED_MEM; funcBuilder.createTrackBarrierWriteForBufferCall( b, barrier, effect.buf, effect.length, combinedPred, memType, op, - recipientCTAs, effectRecipientCTAs); + recipientCTAs, effectRecipientCTAs, + barrierInfo.diagonalEffectRecipientCTAs); } } if (barrierInfo.count > 0 || barrierInfo.txCount != 0) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp index afabbf693d2d..59f2d6cc3443 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp @@ -108,11 +108,11 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op, if (auto commit = dyn_cast(op)) { return ttng::getModuleTwoCTAs(op) && aliasesTracked(commit.getBarrier()); } - if (auto tma = dyn_cast(op)) { + if (auto tma = dyn_cast(op)) { return tma.getMulticast() && aliasesTracked(tma.getBarrier()); } - if (auto tma = dyn_cast(op)) { - return tma.getMulticast() && aliasesTracked(tma.getBarrier()); + if (auto clc = dyn_cast(op)) { + return aliasesTracked(clc.getMbarrier()); } return false; } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index 395d4e90f4bd..f3bc65039879 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -42,6 +42,10 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { return isa(op); } + bool isCLCOp(Operation *op) const override { + return isa(op); + } + std::optional getBarrierInitInfo(Operation *op) const override { if (auto initOp = dyn_cast(op)) { @@ -101,6 +105,11 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { } if (auto storeOp = dyn_cast(op)) mask = getBlockBroadcastMask(storeOp.getSrc().getType()); + 
if (isa(op) && ttg::lookupNumCTAs(op) > 1) { + Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc()); + return arith::CmpIOp::create(b, arith::CmpIPredicate::eq, ctaId, + arith::ConstantIntOp::create(b, 0, 32)); + } // In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair // issues the op. @@ -226,6 +235,24 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, loadOp.getResult()); } + if (auto tryCancelOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->barriers.push_back( + {tryCancelOp.getMbarrier(), nullptr, /*count=*/0, + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/ + -static_cast(tti::getMemDescLength(tryCancelOp.getResult())), + /*diagonalEffectRecipientCTAs=*/true}); + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, + tryCancelOp.getResult()); + } + if (auto loadResultOp = dyn_cast(op)) { + info.emplace(); + info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; + info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read, + loadResultOp.getSrc()); + } if (auto storeOp = dyn_cast(op)) { info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount; diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 48a609200f2a..bf3064ad8f2c 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -8,7 +8,7 @@ from triton.experimental.gluon.language.nvidia import blackwell from triton.experimental.gluon.language.nvidia import hopper from triton.experimental.gluon.language.nvidia import ampere -from triton.experimental.gluon.language.nvidia.blackwell import allocate_tensor_memory, mbarrier, tma +from triton.experimental.gluon.language.nvidia.blackwell import allocate_tensor_memory, clc, mbarrier, tma from triton._internal_testing import is_cuda, run_in_process @@ -250,6 
+250,44 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") +@pytest.mark.parametrize("FAILURE", [True, False]) +def test_clc_result_visibility(FAILURE, device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_clc_result_visibility, (FAILURE, device, False, monkeypatch, num_ctas)) + if FAILURE: + assert_expected_cuda_failure(result.exc) + assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(out, FAILURE: ttgl.constexpr): + cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1) + layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) + clc_result = ttgl.allocate_shared_memory(ttgl.int64, [2], layout) + clc_bar = mbarrier.allocate_mbarrier() + mbarrier.init(clc_bar, count=1) + + clc.try_cancel(clc_result, clc_bar) + mbarrier.expect(clc_bar, 16) + mbarrier.wait(clc_bar, 0, pred=(not FAILURE)) + response = clc.load_result(clc_result) + mbarrier.wait(clc_bar, 0, pred=FAILURE) + mbarrier.invalidate(clc_bar) + + ttgl.store(out + ttgl.program_id(0), response.is_canceled()) + + output = torch.empty((1, ), device=device, dtype=torch.bool) + kernel[(1, )](output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") def test_async_tma_multicast_kernel_reuse(device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: diff --git a/test/TritonGPU/amd/amd-consan.mlir 
b/test/TritonGPU/amd/amd-consan.mlir index d475a14b6ea7..fdb7c4ac3ca7 100644 --- a/test/TritonGPU/amd/amd-consan.mlir +++ b/test/TritonGPU/amd/amd-consan.mlir @@ -242,7 +242,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -699,7 +699,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @amdg_async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -830,10 +830,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], 
%[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -874,10 +874,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -908,10 +908,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: 
%[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index e5db7d7ba9fe..5ef6070da05b 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -252,6 +252,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // ----- +#shared_clc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrier = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { + // CHECK-LABEL: @clc_try_cancel_diagonal_effect_recipients + tt.func public @clc_try_cancel_diagonal_effect_recipients() { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %result = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2xi64, #shared_clc, #smem, mutable> + %bar = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + // CHECK: tti.experimental_cluster_cta_id + // CHECK: arith.cmpi eq + // CHECK: %[[CLC_THREAD:.*]] = arith.constant 48 : i32 + // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[CLC_THREAD]] + // CHECK: %[[CLC_MASK:.*]] = arith.constant 281474976710656 : i64 + // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[CLC_MASK]] + // CHECK: 
tt.call @__triton_consan_track_barrier_write_for_buffer{{.*}}_I1 + ttng.clc_try_cancel %result, %bar : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable>, !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.barrier_expect %bar, 16, %true : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + ttng.wait_barrier %bar, %c0_i32, %true : !ttg.memdesc<1xi64, #barrier, #smem, mutable> + // CHECK: tt.call @__triton_consan_verify_write_visibility + // CHECK: ttng.clc_load_result + %clc_res = ttng.clc_load_result %result : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable> -> i128 + tt.return + } +} + +// ----- + #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[0, 0]]}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #smem = #ttg.shared_memory @@ -1000,7 +1030,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_commit_group() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 4295032833 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -1251,10 +1281,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call 
@__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -1296,10 +1326,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { @@ -1332,10 +1362,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] 
// CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 8590065666 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { diff --git a/test/TritonNvidiaGPU/membar-cluster.mlir b/test/TritonNvidiaGPU/membar-cluster.mlir index becb904b91ee..05925e949748 100644 --- a/test/TritonNvidiaGPU/membar-cluster.mlir +++ b/test/TritonNvidiaGPU/membar-cluster.mlir @@ -473,6 +473,35 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // ----- +#sharedCLC = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> +#barrierCLC = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CLC multicasts completion through the cluster, so it needs init sync even + // if the barrier allocation shape looks per-CTA. 
+ // CHECK-LABEL: @cluster_clc_with_per_cta_barrier + // CHECK: ttng.init_barrier + // CHECK-NEXT: ttng.fence_mbarrier_init_release_cluster + // CHECK-NEXT: ttng.cluster_barrier {relaxed = true} + // CHECK-NEXT: ttng.clc_try_cancel + // CHECK: tt.return + tt.func @cluster_clc_with_per_cta_barrier() { + %true = arith.constant true + %result = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #sharedCLC, #smem, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.clc_try_cancel %result, %barrier : + !ttg.memdesc<2xi64, #sharedCLC, #smem, mutable>, + !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + ttng.barrier_expect %barrier, 16, %true : + !ttg.memdesc<2xi64, #barrierCLC, #smem, mutable> + tt.return + } +} + +// ----- + #nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> #barrierEnc = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}>