diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index 1a2565150b43..a57eb0896cd1 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -142,51 +142,46 @@ class FunctionBuilder { void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, uint64_t threadMask, Value pred, MemType memType, - Operation *insertPoint, - Value recipientCTAs); + Operation *insertPoint, Value effectCTAs); // setReadVisibility: add the threads set in threadMask to the buffer's read // visibility bitmask. void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, uint64_t threadMask, Value pred, MemType memType, - Operation *insertPoint, Value recipientCTAs); + Operation *insertPoint, Value effectCTAs); // clearWriteTracking: clear all the information about threads writing to a // buffer. void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs); + Value effectCTAs); // clearReadVisibility: clear the read visibility for a buffer. void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs); + Value effectCTAs); // clearReadTracking: clear the read tracking for a buffer. void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, - Operation *insertPoint, Value recipientCTAs); + Operation *insertPoint, Value effectCTAs); // trackVisibleWrites: snapshot buffers currently visible to the thread into // the tracking table for a barrier. void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar, int thread, Value pred, MemType memType, - Operation *insertPoint, - Value recipientCTAs); + Operation *insertPoint, Value barrierCTAs); // trackVisibleReads: snapshot buffers currently visible to the thread into // the read tracking table for a barrier. void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, int thread, Value pred, MemType memType, - Operation *insertPoint, Value recipientCTAs); + Operation *insertPoint, Value barrierCTAs); // trackBarrierWriteForBuffer: mark a specific buffer as tracked by a - // barrier in the write-tracking table. When diagonalEffectRecipientCTAs is - // false, every signaled barrier row publishes the full effectRecipientCTAs - // mask. When it is true, barrier row i publishes only bit i of that mask. + // barrier in the write-tracking table. void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value barrierRecipientCTAs, - Value effectRecipientCTAs, - bool diagonalEffectRecipientCTAs); + Value barrierCTAs, + Value effectCTAs); // clearBarrierWriteTracking: clear all write tracking associated with the // given barrier row. void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar, @@ -213,14 +208,14 @@ class FunctionBuilder { uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs, bool allowNoWrite); + Value effectCTAs, bool allowNoWrite); // verifyReadVisibility: ensure all reads from the buffer are visible to the // thread. void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs); + Value effectCTAs); // copyWriteVisibility: replicate the write visibility bit of sourceThread to // every destination thread in destMask. void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread, @@ -231,6 +226,11 @@ class FunctionBuilder { void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread, uint64_t destMask, Value pred, MemType memType, Operation *insertPoint); + // publishClusterVisibility: after a non-relaxed cluster barrier, make + // synchronous facts visible to every CTA in the cluster. + void createPublishClusterVisibilityCall(ImplicitLocOpBuilder &b, Value pred, + MemType memType, + Operation *insertPoint); // stageAccessForCommit: mark the buffer as staged (value -1) in the // outstanding commit table for this thread. void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf, @@ -273,7 +273,7 @@ class FunctionBuilder { void createCheckOutstandingCommitsCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef pendingAccessType, Value pred, MemType memType, - CommitKind::Kind commitKind, Operation *insertPoint, Value recipientCTAs, + CommitKind::Kind commitKind, Operation *insertPoint, Value effectCTAs, bool excludeSelf = false); private: diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md index 8e67561df8f2..4c06e4c6d41c 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -17,7 +17,7 @@ implementation: ConSan currently supports one public entry point in the module. It uses BufferRegion analysis to collect shared-memory buffers, tensor-memory buffers, and barrier allocations, then creates auxiliary state in distributed tensors and -shared-cluster global scratch memory. Most state has a leading CTA dimension so +shared-cluster global scratch memory. Most scratch state is CTA-qualified so cluster and multicast effects can be modeled explicitly. ## Visibility Model @@ -58,6 +58,24 @@ At a `ttg.warp_specialize`, the pass copies the default thread's read and write visibility to the destination partition peer masks so partition-local execution starts with the visibility frontier that existed before specialization. +## CTA Model + +The single-CTA case is the degenerate form of the multiCTA model: every +CTA-qualified axis has one row, so the usual per-buffer and per-barrier rules +apply unchanged. + +For multiCTA kernels, each CTA is modeled as its own set of logical threads. +Buffer and barrier descriptors stay CTA-agnostic, while shadow state records the +CTA whose buffer row, barrier row, logical thread, or visibility mask a fact +belongs to. This keeps the single-CTA visibility rules intact and adds only the +question of which CTA rows an operation reads, writes, or synchronizes. + +A multicast-layout barrier has one live barrier row per multicast group, owned +by the group's lead CTA. Every CTA in the group may `arrive` or `expect` on that +row, but only the lead CTA initializes, waits on, and invalidates it. This is +the same model as several independent logical threads arriving on one barrier +while only one logical thread waits on it; non-leader barrier rows are not live. + ## Runtime State ConSan keeps enough runtime state to answer two questions at each instrumented @@ -74,7 +92,7 @@ At a high level, the pass maintains: - Optional alias metadata when tracked buffer regions overlap. - A shared-cluster lock that serializes instrumentation updates. -Most runtime state is CTA-indexed so cluster, multicast, and cross-CTA effects +Most runtime state is CTA-qualified so cluster, multicast, and cross-CTA effects can be represented directly. Scratch state is zero-initialized once before the instrumented body runs, and the initialization is followed by a CTA or cluster barrier before any instrumented operation can use it. @@ -94,7 +112,8 @@ selected buffer. The runtime checks account for aliasing and CTA recipients. A check against one buffer is expanded through the alias metadata when BufferRegion analysis found overlapping tracked regions, and multi-CTA operations only inspect the CTA rows -that the operation can affect. +that the operation can affect. Aliasing remains intra-CTA: overlapping +descriptors may alias within one CTA row, but not across different CTA rows. After a barrier-tracked read, ConSan records that the current peer thread mask can see that read. After a barrier-tracked write, ConSan records the current @@ -105,28 +124,28 @@ All normal instrumentation emitted around one IR operation is wrapped in the ConSan lock. Barrier waits are split into a locked pre-wait section and a locked post-wait section. -## CTA Recipients and Multicast +## CTA Issuers, Effects, and Recipients -Most multi-CTA instrumentation computes a CTA recipient bitset. That bitset is -converted to a mask over the CTA dimension so only relevant CTA rows are checked -or updated. +Single-CTA operations implicitly use the current CTA for all three roles. In a +multiCTA kernel those roles can differ: -The target hooks compute recipients from the operation: +- The issuer predicate selects which CTA actually executes the instrumented op. +- The memory-effect CTA bitset selects the buffer rows that the op reads or + writes. +- The barrier-recipient CTA bitset selects the live barrier rows that arrivals, + expectations, and completion signals update. -- Non-multicast operations usually target the current CTA. -- Multicast TMA loads update all result-recipient CTAs, while barrier arrivals - route to the leader barrier CTA. -- NVIDIA two-CTA Tensor Core operations are predicated to the issuing CTA pair - leader. -- TMA load effects that write one set of CTAs and signal a different leader - barrier remember the effect-recipient CTA rows so a later wait can transfer - write visibility to both the waiting CTA and the written CTA rows. -- CLC try-cancel effects use all CTA rows for the barrier and memory effect, - with diagonal recipient handling for the written result rows. +For example, a multicast TMA load is issued by the multicast-group leader, +writes every result-recipient CTA row, and signals the leader barrier row. A +two-CTA Tensor Core operation is issued by the even CTA in the pair, but its +memory effects cover both CTA rows in that pair. CLC try-cancel is issued once +for the cluster and touches all CTA rows. ## Barrier Synchronization -ConSan separates barrier tracking from visibility transfer. +ConSan separates barrier tracking from visibility transfer. Ordinary mbarrier +operations follow the live-row rule from the CTA model above: all participating +CTAs address the lead barrier row, and only the lead CTA performs the wait. For frontier-tracked barriers, an arrive or commit snapshots the current thread's visible writes and reads into the barrier's tracking state. A later @@ -154,6 +173,10 @@ Write transfers also consult the recorded effect-recipient CTA rows, which lets TMA-style and CLC cross-CTA writes become visible in the CTA rows reached by the memory effect. Read transfers update the current CTA row. +A non-relaxed cluster barrier is different from an mbarrier wait: it publishes +synchronous work from base threads to all CTA rows directly (i.e., just the generic +proxy). + ## Barrier Lifecycle and Deadlock Checks The barrier state table models initialized, invalidated, phase, arrival-count, diff --git a/include/triton/Dialect/TritonInstrument/IR/Utility.h b/include/triton/Dialect/TritonInstrument/IR/Utility.h index 19dc922ddbfc..af1bfcd2dbd0 100644 --- a/include/triton/Dialect/TritonInstrument/IR/Utility.h +++ b/include/triton/Dialect/TritonInstrument/IR/Utility.h @@ -21,14 +21,7 @@ class FunctionBuilder; constexpr int numMemTypes = getMaxEnumValForMemType() + 1; -constexpr int NUM_THREADS = 16; -constexpr int TMA_THREAD_OFFSET = NUM_THREADS; -constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS; -constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS; -constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS; -static_assert(TOTAL_NUM_THREADS <= 64, - "ConSan thread bitsets are stored in i64 masks"); -const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS); +constexpr int MAX_NUM_BASE_THREADS = 16; namespace CommitKind { enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds }; @@ -42,9 +35,8 @@ enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds }; // writeVisibility + readVisibility per active memory type. constexpr int kCapturesPerMemType = 2; -// barrierStates + waiting + activeMasks + barrierWriteRecipients (only when -// barriers exist). -constexpr int kBarrierBaseCaptures = 4; +// barrierStates + waiting + activeMasks (only when barriers exist). +constexpr int kBarrierBaseCaptures = 3; // writeTracking + readTracking per active memory type (only when barriers // exist and the memory type has buffers). @@ -122,6 +114,20 @@ struct ValueType { // that pointer. For tensor descriptors and constants, ValueType::value is the // tensor itself and ValueType::type is its type. struct AuxDataMap { + struct ThreadLayout { + int numBaseThreads = 1; + int numBaseThreadSlots = 1; + int tmaThreadOffset = -1; + int tcThreadOffset = -1; + int clcThreadOffset = -1; + int totalNumThreads = 1; + int numThreadSlots = 1; + + bool hasTMAThreads() const { return tmaThreadOffset >= 0; } + bool hasTCThreads() const { return tcThreadOffset >= 0; } + bool hasCLCThreads() const { return clcThreadOffset >= 0; } + }; + struct RegionToValueMap { DenseMap values; ValueType at(Region *region) { @@ -142,52 +148,49 @@ struct AuxDataMap { // Shape notation: // C = CTAs in the cluster. + // Cbar, Cbuf, Cthr, Cmask = CTA dimensions qualifying barriers, buffers, + // threads, and thread masks respectively. Each has extent C. // B = tracked buffers for one memory type, power-of-two padded. // K = tracked mbarriers, power-of-two padded. - // T = logical ConSan thread bit slots, padded to 64. - // P = base-thread commit columns, currently 16. + // T = logical ConSan thread bit slots used by this module, power-of-two + // padded for the distributed layout. + // P = base-thread commit columns used by this module, power-of-two padded. // // Storage notation: // tensor = distributed tensor value. // scratch = pointer to shared-cluster global scratch memory. - // tensor, + // tensor, // Per-memory-type packed buffer descriptors. Each i64 stores the 32-bit base // offset and 32-bit length of one shared-memory or tensor-memory region. RegionToValueMap buffers[numMemTypes]; - // tensor, + // tensor, // Packed descriptors for tracked mbarrier allocations. Barriers are shared // memory descriptors. RegionToValueMap barriers; - // scratch, + // scratch, // Packed barrier lifecycle state. Zero means invalid/uninitialized. Bit 0 is // phase, bits [1..20] are the initial arrival count, bits [21..40] are the // current arrival count, and bits [41..61] hold a signed tx-count. RegionToValueMap barrierStates; - // scratch, - // Per-barrier CTA bitsets of write-recipient rows reached by outstanding - // EffectWrites operations such as TMA and CLC. Used when a later wait - // transfers tracked writes. - RegionToValueMap barrierWriteRecipients; - - // scratch, + // scratch, // Per-memory-type write frontier. Bit i means logical ConSan thread i can see // the latest write to the buffer row. RegionToValueMap writeVisibility[numMemTypes]; - // scratch, + // scratch, // Per-memory-type buffer/barrier map for writes that a barrier tracks. RegionToValueMap writeTracking[numMemTypes]; - // scratch, + // scratch, // Per-memory-type read frontier. For each buffer and logical thread lane, the // i64 value is a bitmask of reads visible to that lane's thread. RegionToValueMap readVisibility[numMemTypes]; - // scratch, + // scratch, // Per-memory-type buffer/barrier map for read visibility masks that a barrier // tracks. RegionToValueMap readTracking[numMemTypes]; @@ -196,9 +199,11 @@ struct AuxDataMap { // Per-commit-kind outstanding commit counters for shared-memory buffers. // Entries are 0 for none, -1 for staged but uncommitted, and positive for a // committed access with an outstanding-group distance. + // Just one C dimension as ampere async_copy, WGMMA and TMA store are + // intra-CTA. RegionToValueMap commits[CommitKind::NumCommitKinds]; - // tensor, + // tensor, // Optional per-memory-type alias matrix. Created only when BufferRegion // analysis finds cross-buffer aliasing; checks expand selected buffer rows // through this matrix. @@ -208,7 +213,7 @@ struct AuxDataMap { // Shared-cluster lock used to serialize ConSan instrumentation updates. RegionToValueMap lock; - // scratch, + // scratch, // Deadlock-detection bitfield. Each base thread uses two bits: waiting flag // and stored phase. RegionToValueMap waiting; @@ -223,6 +228,10 @@ struct AuxDataMap { // aliasMatrices to make visibility and commit checks conservative. std::array hasNonTrivialAliasing{}; + // Dense logical-thread numbering for this module. Base threads are always + // present; TMA/TC/CLC peer ranges are added only when the module uses them. + ThreadLayout threadLayout; + LogicalResult populateAndPassToWarpSpecialize(ModuleOp module, FunctionBuilder &funcBuilder, const ConSanTargetHooks *hooks); diff --git a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h index 61bdf8a93fde..49d4a7c5cf29 100644 --- a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h +++ b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h @@ -45,18 +45,6 @@ struct MemEffectsOpInfo { int count; BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier; int txCount = 0; - // For EffectWrites, effectRecipientCTAs identifies the CTA rows where the - // op wrote its explicit result. By default, for - // diagonalEffectRecipientCTAs=false, waiting on a barrier publishes the CTA - // rows in effectRecipientCTAs, which is the full mask. This is the - // behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true, - // waiting on a barrier publishes only the CTA rows in effectRecipientCTAs, - // which is the diagonal mask. e.g. effectRecipientCTAs = 0b1101 If - // DiagonalEffectRecipientCTAs is false, waiting on the barrier publishes - // the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101 - // If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes - // the following CTA rows: CTA0 0b1000 CTA1 0b0100 CTA2 0b0000 CTA3 0b0001 - bool diagonalEffectRecipientCTAs = false; }; enum class TrackingKind { None, diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index ee461bd1e628..08d3eae7a7cb 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -86,20 +86,14 @@ struct BufferDescriptorsOpConversion auto lengths = adaptor.getLengths(); assert(offsets.size() == lengths.size() && "Mismatched descriptor arrays"); - auto totalTensorType = cast(op.getResult().getType()); - // The totalEncoding is of shape [CTAs, Descriptors] - auto totalEncoding = - cast(totalTensorType.getEncoding()); - assert(totalTensorType.getRank() == 2 && - "descriptor tables must have shape [cta, descriptor]"); + auto tensorType = cast(op.getResult().getType()); + auto encoding = + cast(tensorType.getEncoding()); + assert(tensorType.getRank() == 1 && + "descriptor tables must have shape [descriptor]"); assert(static_cast(offsets.size()) == - totalTensorType.getShape().back() && + tensorType.getShape().back() && "Descriptor data must match the descriptor dimension"); - // Get a slice of shape [Descriptors] that will be broadcasted at the end - auto encoding = tti::getSingleDimSliceEncoding(totalEncoding, /*dim=*/1); - auto tensorType = - RankedTensorType::get({totalTensorType.getShape().back()}, - totalTensorType.getElementType(), encoding); SmallVector offsetVals; offsetVals.reserve(offsets.size()); @@ -146,9 +140,6 @@ struct BufferDescriptorsOpConversion Value bufDescriptors = arith::OrIOp::create(rewriter, loc, trimmedPointers.getType(), trimmedPointers, lengthTensor); - bufDescriptors = tti::expandOuterSlicedDim(rewriter, loc, bufDescriptors); - bufDescriptors = triton::BroadcastOp::create(rewriter, loc, totalTensorType, - bufDescriptors); rewriter.replaceOp(op, bufDescriptors); return success(); } diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index d1b44c3b92b1..161224b80277 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -62,15 +62,12 @@ constexpr unsigned bitsPerThread = 2; constexpr unsigned flagBit = 0; constexpr unsigned phaseBit = 1; -constexpr uint32_t makeInterleavedMask(unsigned bit) { +uint32_t makeInterleavedMask(unsigned bit, unsigned numBaseThreads) { uint32_t mask = 0; - for (unsigned i = 0; i < tti::NUM_THREADS; ++i) + for (unsigned i = 0; i < numBaseThreads; ++i) mask |= 1u << (bitsPerThread * i + bit); return mask; } - -constexpr uint32_t flagMask = makeInterleavedMask(flagBit); -constexpr uint32_t phaseMask = makeInterleavedMask(phaseBit); } // namespace WaitingBits // Information about the optional assert message and tensor type to check. @@ -79,9 +76,9 @@ struct AssertInfo { Type type; }; -static uint64_t expandActiveMask(uint64_t activeMask) { +static uint64_t expandActiveMask(uint64_t activeMask, unsigned numBaseThreads) { uint64_t expanded = 0; - for (unsigned i = 0; i < tti::NUM_THREADS; ++i) { + for (unsigned i = 0; i < numBaseThreads; ++i) { if (activeMask & (1ull << i)) expanded |= 1ull << (WaitingBits::bitsPerThread * i + WaitingBits::flagBit); @@ -269,13 +266,13 @@ Value convertAndBroadcast(ImplicitLocOpBuilder &b, Value tensor, Value expandAliases(ImplicitLocOpBuilder &b, Value bufferMask, Value aliasMatrix, RankedTensorType aliasMatrixType) { - assert(aliasMatrixType.getRank() == 3 && - "Alias matrix expected to be rank-3"); + assert(aliasMatrixType.getRank() == 2 && + "Alias matrix expected to be rank-2"); auto bufferMaskType = cast(bufferMask.getType()); Value bufMaskMatrix = - convertAndBroadcast(b, bufferMask, {0, 1}, aliasMatrixType); + convertAndBroadcast(b, bufferMask, {0}, aliasMatrixType); Value aliasingMask = arith::AndIOp::create(b, aliasMatrix, bufMaskMatrix); - Value aliasVector = reduce(b, aliasingMask, /*axis=*/1); + Value aliasVector = reduce(b, aliasingMask, /*axis=*/0); return createConvertLayout(b, aliasVector, bufferMaskType.getEncoding()); } @@ -290,8 +287,7 @@ Value adjustIntegerWidth(ImplicitLocOpBuilder &b, Value value, } Value createThreadColumnMask(ImplicitLocOpBuilder &b, Value threadMask, - RankedTensorType tensorType) { - int columnDim = tensorType.getRank() - 1; + RankedTensorType tensorType, int columnDim) { auto loc = b.getLoc(); auto encoding = cast(tensorType.getEncoding()); auto sliceEncoding = tti::getSingleDimSliceEncoding(encoding, columnDim); @@ -337,9 +333,18 @@ Value createDimMask(ImplicitLocOpBuilder &b, Value index, return convertAndBroadcast(b, mask1D, {dim}, maskType); } -Value createColumnMask(ImplicitLocOpBuilder &b, Value column, - RankedTensorType tensorType) { - return createDimMask(b, column, tensorType, tensorType.getRank() - 1); +Value createDimIndices(ImplicitLocOpBuilder &b, RankedTensorType tensorType, + int dim) { + assert(dim >= 0 && dim < tensorType.getRank() && "invalid tensor dimension"); + auto encoding = cast(tensorType.getEncoding()); + auto sliceEncoding = tti::getSingleDimSliceEncoding(encoding, dim); + auto indexType = RankedTensorType::get({tensorType.getShape()[dim]}, + b.getI32Type(), sliceEncoding); + Value range = triton::MakeRangeOp::create(b, indexType, /*start=*/0, + /*end=*/tensorType.getShape()[dim]); + auto fullIndexType = cast( + tensorType.cloneWith(std::nullopt, b.getI32Type())); + return convertAndBroadcast(b, range, {dim}, fullIndexType); } Value createCurrentCTAMask(ImplicitLocOpBuilder &b) { @@ -348,26 +353,25 @@ Value createCurrentCTAMask(ImplicitLocOpBuilder &b) { ctaId); } -Value createRecipientCTAMask(ImplicitLocOpBuilder &b, - RankedTensorType tensorType, Value recipientCTAs) { +Value createCTASetMask(ImplicitLocOpBuilder &b, RankedTensorType tensorType, + int dim, Value ctas) { int numCTAs = ttg::lookupNumCTAs(b); // Turn the scalar recipient bitset into a tensor mask over logical CTA rows: - // build a [0, numCTAs) row-index vector on dim 0, broadcast it to the state + // build a [0, numCTAs) row-index vector on `dim`, broadcast it to the state // tensor shape, and test one bit of `recipientCTAs` per row. auto loc = b.getLoc(); auto encoding = cast(tensorType.getEncoding()); - auto rowSliceEncoding = tti::getSingleDimSliceEncoding(encoding, /*dim=*/0); + auto rowSliceEncoding = tti::getSingleDimSliceEncoding(encoding, dim); auto rowType = RankedTensorType::get({numCTAs}, b.getI32Type(), rowSliceEncoding); Value rowIdx = triton::MakeRangeOp::create(b, rowType, /*start=*/0, /*end=*/numCTAs); auto indexType = cast( tensorType.cloneWith(std::nullopt, b.getI32Type())); - rowIdx = convertAndBroadcast(b, rowIdx, {0}, indexType); + rowIdx = convertAndBroadcast(b, rowIdx, {dim}, indexType); - Value recipientBitsTensor = - triton::SplatOp::create(b, indexType, recipientCTAs); + Value recipientBitsTensor = triton::SplatOp::create(b, indexType, ctas); Value shifted = arith::ShRUIOp::create(b, recipientBitsTensor, rowIdx); Value one = tti::createConstIntTensor(b, loc, 1, indexType); Value selectedBit = arith::AndIOp::create(b, shifted, one); @@ -375,24 +379,40 @@ Value createRecipientCTAMask(ImplicitLocOpBuilder &b, return arith::CmpIOp::create(b, arith::CmpIPredicate::ne, selectedBit, zero); } -Operation *createCTAScopedStoreScratchMemory(ImplicitLocOpBuilder &b, - Location loc, Value alloc, - Value tensor, - RankedTensorType tensorType, - Value recipientCTAs) { +Value createLeadCTAEffectMask(ImplicitLocOpBuilder &b, + RankedTensorType tensorType, Value effectCTAs) { + Value lhsMask = createCTASetMask(b, tensorType, /*dim=*/0, effectCTAs); + Value leadCTAMask = + createCTASetMask(b, tensorType, /*dim=*/2, createCurrentCTAMask(b)); + return arith::AndIOp::create(b, lhsMask, leadCTAMask); +} + +Operation *createMaskedStoreScratchMemory(ImplicitLocOpBuilder &b, Location loc, + Value alloc, Value tensor, + RankedTensorType tensorType, + Value mask) { int64_t numCTAs = ttg::lookupNumCTAs(b); if (numCTAs > 1) { // This should hopefully be folded with the previous load in the caller // function Value oldTensor = tti::createLoadScratchMemory(b, loc, alloc, tensorType); - Value ctaMask = createRecipientCTAMask(b, tensorType, recipientCTAs); // and this with the previous selectOp, if there is any - tensor = arith::SelectOp::create(b, loc, ctaMask, tensor, oldTensor); + tensor = arith::SelectOp::create(b, loc, mask, tensor, oldTensor); } return tti::createStoreScratchMemory(b, loc, alloc, tensor, tensorType, /*currentCTAOnly=*/false); } +Operation *createCTAScopedStoreScratchMemory(ImplicitLocOpBuilder &b, + Location loc, Value alloc, + Value tensor, + RankedTensorType tensorType, + Value recipientCTAs) { + return createMaskedStoreScratchMemory( + b, loc, alloc, tensor, tensorType, + createCTASetMask(b, tensorType, /*dim=*/0, recipientCTAs)); +} + } // namespace void FunctionBuilder::createFillGlobalTensorCall(ImplicitLocOpBuilder &b, @@ -455,7 +475,10 @@ void FunctionBuilder::createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 1}, waitingType); + convertAndBroadcast(fb, barriersEqBar, {1}, waitingType); + Value ctaMask = + createLeadCTAEffectMask(fb, waitingType, createCurrentCTAMask(fb)); + barriersEqBar = arith::AndIOp::create(fb, barriersEqBar, ctaMask); Value bitsPerThread = arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); @@ -503,8 +526,8 @@ void FunctionBuilder::createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, Value newWaiting = arith::SelectOp::create(fb, cond, pendingWaiting, waiting); - tti::createStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, - waitingType, /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, + waitingType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -555,7 +578,10 @@ void FunctionBuilder::createClearWaitingCall(ImplicitLocOpBuilder &b, Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 1}, waitingType); + convertAndBroadcast(fb, barriersEqBar, {1}, waitingType); + Value ctaMask = + createLeadCTAEffectMask(fb, waitingType, createCurrentCTAMask(fb)); + barriersEqBar = arith::AndIOp::create(fb, barriersEqBar, ctaMask); Value bitsPerThread = arith::ConstantIntOp::create(fb, WaitingBits::bitsPerThread, 32); @@ -586,8 +612,8 @@ void FunctionBuilder::createClearWaitingCall(ImplicitLocOpBuilder &b, Value newWaiting = arith::SelectOp::create(fb, barriersEqBar, clearedWaiting, waiting); - tti::createStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, - waitingType, /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), waitingPtr, newWaiting, + waitingType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); }); @@ -598,7 +624,8 @@ void FunctionBuilder::createSetActiveMaskCall(ImplicitLocOpBuilder &b, Operation *insertPoint) { if (auxData.activeMasks.empty()) return; - int64_t expandedActiveMask = expandActiveMask(activeMask); + int64_t expandedActiveMask = + expandActiveMask(activeMask, auxData.threadLayout.numBaseThreads); Value expandedActiveMaskVal = arith::ConstantIntOp::create(b, expandedActiveMask, 32); Value activeMasksVal = auxData.activeMasks.at(insertPoint).value; @@ -626,7 +653,8 @@ void FunctionBuilder::createRetireActiveThreadCall(ImplicitLocOpBuilder &b, Operation *insertPoint) { if (auxData.activeMasks.empty()) return; - int64_t threadMask = expandActiveMask(1u << thread); + int64_t threadMask = + expandActiveMask(1u << thread, auxData.threadLayout.numBaseThreads); Value clearMaskVal = arith::ConstantIntOp::create(b, ~threadMask, 32); Value activeMasksVal = auxData.activeMasks.at(insertPoint).value; auto activeMasksType = @@ -664,6 +692,10 @@ void FunctionBuilder::createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, if (!pred) { pred = arith::ConstantIntOp::create(b, 1, 1); } + uint32_t flagMask = WaitingBits::makeInterleavedMask( + WaitingBits::flagBit, auxData.threadLayout.numBaseThreads); + uint32_t phaseMask = WaitingBits::makeInterleavedMask( + WaitingBits::phaseBit, auxData.threadLayout.numBaseThreads); Value waitingVal = auxData.waiting.at(insertPoint).value; auto waitingType = cast(auxData.waiting.at(insertPoint).type); @@ -691,8 +723,8 @@ void FunctionBuilder::createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, createCallToCachedFunction( b, "check_all_active_waiting", args, assertInfo, {waitingGlobalType, barrierStatesGlobalType, activeMasksGlobalType}, - [waitingGlobalType, barrierStatesGlobalType, - activeMasksGlobalType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [waitingGlobalType, barrierStatesGlobalType, activeMasksGlobalType, + flagMask, phaseMask](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value pred = entryBlock->getArgument(0); Value waitingPtr = entryBlock->getArgument(1); @@ -705,9 +737,9 @@ void FunctionBuilder::createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, fb, fb.getLoc(), barrierStatesPtr, barrierStatesGlobalType); Value flagMaskTensor = tti::createConstIntTensor( - fb, fb.getLoc(), WaitingBits::flagMask, waitingGlobalType); + fb, fb.getLoc(), flagMask, waitingGlobalType); Value phaseMaskTensor = tti::createConstIntTensor( - fb, fb.getLoc(), WaitingBits::phaseMask, waitingGlobalType); + fb, fb.getLoc(), phaseMask, waitingGlobalType); Value flags = arith::AndIOp::create(fb, waiting, flagMaskTensor); Value phases = arith::AndIOp::create(fb, waiting, phaseMaskTensor); @@ -728,9 +760,12 @@ void FunctionBuilder::createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, Value phaseIsOne = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, barrierPhase, oneState); + phaseIsOne = + convertAndBroadcast(fb, phaseIsOne, {0, 1}, waitingGlobalType); Value effectiveWaiting = arith::SelectOp::create( fb, phaseIsOne, waitingPhase1, waitingPhase0); Value waitingOr = reduce(fb, effectiveWaiting, 1); + waitingOr = reduce(fb, waitingOr, 0); auto waitingOrType = cast(waitingOr.getType()); Value activeMasks = tti::createLoadScratchMemory( fb, fb.getLoc(), activeMasksPtr, activeMasksGlobalType); @@ -779,9 +814,8 @@ void FunctionBuilder::createVerifyBarrierCanInitCall(ImplicitLocOpBuilder &b, Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {mbarOffset, lengthVal, pred, barriersVal, barrierStatesVal, recipientCTAs}; - AssertInfo assertInfo{ - "Barrier re-initialized without prior invalidation", - barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; + AssertInfo assertInfo{"Barrier re-initialized without prior invalidation", + b.getI1Type()}; createCallToCachedFunction( b, "verify_barrier_can_init", args, assertInfo, {barriersType, barrierStatesType}, @@ -797,7 +831,7 @@ void FunctionBuilder::createVerifyBarrierCanInitCall(ImplicitLocOpBuilder &b, barrierStatesType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); Value canInit = @@ -805,11 +839,12 @@ void FunctionBuilder::createVerifyBarrierCanInitCall(ImplicitLocOpBuilder &b, auto condType = cast(canInit.getType()); Value vTrue = tti::createConstIntTensor(fb, fb.getLoc(), 1, condType); canInit = arith::SelectOp::create(fb, mask, canInit, vTrue); - Value ctaMask = createRecipientCTAMask(fb, condType, recipientCTAs); + Value ctaMask = + createCTASetMask(fb, condType, /*dim=*/0, recipientCTAs); canInit = arith::SelectOp::create(fb, ctaMask, canInit, vTrue); Value predTensor = triton::SplatOp::create(fb, condType, pred); canInit = arith::SelectOp::create(fb, predTensor, canInit, vTrue); - triton::ReturnOp::create(fb, canInit); + triton::ReturnOp::create(fb, reduceAll(fb, canInit)); }); } @@ -836,7 +871,7 @@ void FunctionBuilder::createVerifyBarrierInitializedCall( barriersVal, barrierStatesVal, recipientCTAs}; AssertInfo assertInfo{ "Barrier used before initialization or after invalidation", - barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; + b.getI1Type()}; createCallToCachedFunction( b, "verify_barrier_initialized", args, assertInfo, {barriersType, barrierStatesType}, @@ -852,7 +887,7 @@ void FunctionBuilder::createVerifyBarrierInitializedCall( barrierStatesType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); Value initialized = @@ -860,12 +895,14 @@ void FunctionBuilder::createVerifyBarrierInitializedCall( auto condType = cast(initialized.getType()); Value vTrue = tti::createConstIntTensor(fb, fb.getLoc(), 1, condType); initialized = arith::SelectOp::create(fb, mask, initialized, vTrue); - Value ctaMask = createRecipientCTAMask(fb, condType, recipientCTAs); + Value ctaMask = + createCTASetMask(fb, condType, /*dim=*/0, recipientCTAs); initialized = arith::SelectOp::create(fb, ctaMask, initialized, vTrue); Value predTensor = triton::SplatOp::create(fb, condType, pred); Value predicatedInitialized = arith::SelectOp::create(fb, predTensor, initialized, vTrue); - triton::ReturnOp::create(fb, predicatedInitialized); + triton::ReturnOp::create( + fb, reduceAll(fb, predicatedInitialized)); }); } @@ -910,7 +947,7 @@ void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, barrierStatesType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value countWide = adjustIntegerWidth( fb, count, cast(barrierStatesType.getElementType())); @@ -987,7 +1024,7 @@ void FunctionBuilder::createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, waitingPtr, waitingType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value zeroState = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); @@ -1000,8 +1037,10 @@ void FunctionBuilder::createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, triton::SplatOp::create(fb, stateCondType, pred); updatedStates = arith::SelectOp::create(fb, statePredTensor, updatedStates, states); - Value waitingMask = - createConvertLayout(fb, mask, waitingType.getEncoding()); + Value waitingMask = convertAndBroadcast(fb, mask, {0, 1}, waitingType); + Value waitingCTAMask = + createLeadCTAEffectMask(fb, waitingType, createCurrentCTAMask(fb)); + waitingMask = arith::AndIOp::create(fb, waitingMask, waitingCTAMask); Value updatedWaiting = arith::SelectOp::create(fb, waitingMask, zeroWaiting, waiting); auto waitingCondType = cast(waitingMask.getType()); @@ -1012,9 +1051,9 @@ void FunctionBuilder::createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, createCTAScopedStoreScratchMemory(fb, fb.getLoc(), statesPtr, updatedStates, barrierStatesType, createCurrentCTAMask(fb)); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), waitingPtr, - updatedWaiting, waitingType, - createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), waitingPtr, + updatedWaiting, waitingType, + waitingCTAMask); triton::ReturnOp::create(fb); }); } @@ -1051,7 +1090,7 @@ void FunctionBuilder::createVerifyBarrierArriveCall( AssertInfo assertInfo{ "Barrier arrive underflow: current count or tx-count would become " "invalid", - barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; + b.getI1Type()}; createCallToCachedFunction( b, "verify_barrier_arrive", args, assertInfo, {barriersType, barrierStatesType}, @@ -1070,7 +1109,7 @@ void FunctionBuilder::createVerifyBarrierArriveCall( barrierStatesType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value zero32 = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); @@ -1129,14 +1168,16 @@ void FunctionBuilder::createVerifyBarrierArriveCall( Value vTrue = tti::createConstIntTensor( fb, fb.getLoc(), 1, cast(valid.getType())); auto condType = cast(valid.getType()); - Value ctaMask = createRecipientCTAMask(fb, condType, recipientCTAs); + Value ctaMask = + createCTASetMask(fb, condType, /*dim=*/0, recipientCTAs); valid = arith::SelectOp::create(fb, ctaMask, valid, vTrue); Value predTensor = triton::SplatOp::create( fb, cast(valid.getType()), pred); Value predicatedValid = arith::SelectOp::create(fb, predTensor, valid, vTrue); - triton::ReturnOp::create(fb, predicatedValid); + triton::ReturnOp::create(fb, + reduceAll(fb, predicatedValid)); }); } @@ -1190,7 +1231,7 @@ void FunctionBuilder::createUpdateBarrierStateCall( barrierStatesType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); - mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + mask = convertAndBroadcast(fb, mask, {1}, barrierStatesType); Value zero32 = tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); @@ -1277,7 +1318,7 @@ void FunctionBuilder::createUpdateBarrierStateCall( void FunctionBuilder::createSetWriteVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, uint64_t threadMask, - Value pred, MemType memType, Operation *insertPoint, Value recipientCTAs) { + Value pred, MemType memType, Operation *insertPoint, Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.writeVisibility[(int)memType].empty()) { @@ -1297,7 +1338,7 @@ void FunctionBuilder::createSetWriteVisibilityCall( Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {bufOffset, lengthVal, pred, threadMaskVal, buffersVal, writeVisibilityVal, - recipientCTAs}; + effectCTAs}; createCallToCachedFunction( b, "set_write_visibility", args, /*assertInfo=*/std::nullopt, @@ -1309,7 +1350,7 @@ void FunctionBuilder::createSetWriteVisibilityCall( Value threadMaskVal = entryBlock->getArgument(3); Value buffers = entryBlock->getArgument(4); Value writeVisibilityPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value effectCTAs = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1319,16 +1360,19 @@ void FunctionBuilder::createSetWriteVisibilityCall( Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, writeVisibilityType); + convertAndBroadcast(fb, buffersEqBuf, {1}, writeVisibilityType); + Value relationMask = + createLeadCTAEffectMask(fb, writeVisibilityType, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, relationMask); auto elemType = cast(writeVisibilityType.getElementType()); Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); Value threadMaskTensor = triton::SplatOp::create(fb, writeVisibilityType, threadMaskElem); Value newVisibility = arith::SelectOp::create( fb, buffersEqBuf, threadMaskTensor, writeVisibility); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, - newVisibility, writeVisibilityType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + newVisibility, writeVisibilityType, + relationMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1337,7 +1381,7 @@ void FunctionBuilder::createSetWriteVisibilityCall( void FunctionBuilder::createSetReadVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, uint64_t threadMask, - Value pred, MemType memType, Operation *insertPoint, Value recipientCTAs) { + Value pred, MemType memType, Operation *insertPoint, Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.readVisibility[(int)memType].empty()) { @@ -1357,7 +1401,7 @@ void FunctionBuilder::createSetReadVisibilityCall( Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {bufOffset, lengthVal, pred, threadMaskVal, buffersVal, readVisibilityVal, - recipientCTAs}; + effectCTAs}; createCallToCachedFunction( b, "set_read_visibility", args, /*assertInfo=*/std::nullopt, @@ -1369,7 +1413,7 @@ void FunctionBuilder::createSetReadVisibilityCall( Value threadMaskVal = entryBlock->getArgument(3); Value buffers = entryBlock->getArgument(4); Value readVisibilityPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value effectCTAs = entryBlock->getArgument(6); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1379,22 +1423,37 @@ void FunctionBuilder::createSetReadVisibilityCall( Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, readVisibilityType); + convertAndBroadcast(fb, buffersEqBuf, {1}, readVisibilityType); + Value relationMask = + createLeadCTAEffectMask(fb, readVisibilityType, effectCTAs); + Value threadCTAMask = + createCTASetMask(fb, readVisibilityType, /*dim=*/2, effectCTAs); + threadCTAMask = arith::AndIOp::create( + fb, threadCTAMask, + createCTASetMask(fb, readVisibilityType, /*dim=*/4, effectCTAs)); + Value sameCTA = arith::CmpIOp::create( + fb, arith::CmpIPredicate::eq, + createDimIndices(fb, readVisibilityType, /*dim=*/2), + createDimIndices(fb, readVisibilityType, /*dim=*/4)); + threadCTAMask = arith::AndIOp::create(fb, threadCTAMask, sameCTA); + relationMask = arith::AndIOp::create(fb, relationMask, threadCTAMask); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, relationMask); auto elemType = cast(readVisibilityType.getElementType()); Value threadMaskElem = adjustIntegerWidth(fb, threadMaskVal, elemType); Value threadBit = triton::SplatOp::create(fb, readVisibilityType, threadMaskElem); Value threadColumnMask = - createThreadColumnMask(fb, threadMaskVal, readVisibilityType); + createThreadColumnMask(fb, threadMaskVal, readVisibilityType, + /*columnDim=*/3); Value readVisibilityOrThreadBit = arith::OrIOp::create(fb, readVisibility, threadBit); Value bufAndThread = arith::AndIOp::create(fb, buffersEqBuf, threadColumnMask); Value newVisibility = arith::SelectOp::create( fb, bufAndThread, readVisibilityOrThreadBit, readVisibility); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - newVisibility, readVisibilityType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType, + relationMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1405,7 +1464,7 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { return; @@ -1422,7 +1481,7 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {bufOffset, lengthVal, pred, - buffersVal, writeTrackingVal, recipientCTAs}; + buffersVal, writeTrackingVal, effectCTAs}; createCallToCachedFunction( b, "clear_write_tracking", args, /*assertInfo=*/std::nullopt, @@ -1433,7 +1492,7 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value pred = entryBlock->getArgument(2); Value buffers = entryBlock->getArgument(3); Value writeTrackingPtr = entryBlock->getArgument(4); - Value recipientCTAs = entryBlock->getArgument(5); + Value effectCTAs = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1443,14 +1502,16 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, writeTrackingType); + convertAndBroadcast(fb, buffersEqBuf, {1}, writeTrackingType); + Value ctaMask = + createCTASetMask(fb, writeTrackingType, /*dim=*/0, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeTrackingType); Value newTracking = arith::SelectOp::create(fb, buffersEqBuf, zero, writeTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, - newTracking, writeTrackingType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + newTracking, writeTrackingType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1461,7 +1522,7 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.readVisibility[(int)memType].empty()) { return; @@ -1478,7 +1539,7 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {bufOffset, lengthVal, pred, - buffersVal, readVisibilityVal, recipientCTAs}; + buffersVal, readVisibilityVal, effectCTAs}; createCallToCachedFunction( b, "clear_read_visibility", args, /*assertInfo=*/std::nullopt, @@ -1489,7 +1550,7 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value pred = entryBlock->getArgument(2); Value buffers = entryBlock->getArgument(3); Value readVisibilityPtr = entryBlock->getArgument(4); - Value recipientCTAs = entryBlock->getArgument(5); + Value effectCTAs = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1499,14 +1560,17 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, readVisibilityType); + convertAndBroadcast(fb, buffersEqBuf, {1}, readVisibilityType); + Value ctaMask = + createCTASetMask(fb, readVisibilityType, /*dim=*/0, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); Value newVisibility = arith::SelectOp::create(fb, buffersEqBuf, zero, readVisibility); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - newVisibility, readVisibilityType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType, + ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1517,7 +1581,7 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.readTracking[(int)memType].empty()) { @@ -1535,7 +1599,7 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); SmallVector args = {bufOffset, lengthVal, pred, - buffersVal, readTrackingVal, recipientCTAs}; + buffersVal, readTrackingVal, effectCTAs}; createCallToCachedFunction( b, "clear_read_tracking", args, /*assertInfo=*/std::nullopt, @@ -1546,7 +1610,7 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value pred = entryBlock->getArgument(2); Value buffers = entryBlock->getArgument(3); Value readTrackingPtr = entryBlock->getArgument(4); - Value recipientCTAs = entryBlock->getArgument(5); + Value effectCTAs = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1556,14 +1620,16 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, readTrackingType); + convertAndBroadcast(fb, buffersEqBuf, {1}, readTrackingType); + Value ctaMask = + createCTASetMask(fb, readTrackingType, /*dim=*/0, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readTrackingType); Value newTracking = arith::SelectOp::create(fb, buffersEqBuf, zero, readTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, - newTracking, readTrackingType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, + newTracking, readTrackingType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1574,7 +1640,7 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar, int thread, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value barrierCTAs) { if (auxData.barriers.empty() || auxData.writeVisibility[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { @@ -1597,9 +1663,9 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, lengthVal, pred, - threadVal, barriersVal, writeVisibilityVal, - writeTrackingVal, recipientCTAs}; + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, writeVisibilityVal, + writeTrackingVal, barrierCTAs}; createCallToCachedFunction( b, "track_visible_writes", args, /*assertInfo=*/std::nullopt, @@ -1613,7 +1679,7 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value barriers = entryBlock->getArgument(4); Value writeVisibilityPtr = entryBlock->getArgument(5); Value writeTrackingPtr = entryBlock->getArgument(6); - Value recipientCTAs = entryBlock->getArgument(7); + Value barrierCTAs = entryBlock->getArgument(7); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1626,7 +1692,11 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, writeTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, writeTrackingType); + Value barrierCTAMask = + createCTASetMask(fb, writeTrackingType, /*dim=*/2, barrierCTAs); + barriersEqBar = + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value threadI64 = arith::ExtUIOp::create(fb, fb.getI64Type(), threadVal); Value one64 = arith::ConstantIntOp::create(fb, 1, 64); @@ -1637,6 +1707,10 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, arith::AndIOp::create(fb, writeVisibility, threadBit); visibleWrites = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, visibleWrites, threadBit); + Value sourceCTAMask = createCTASetMask( + fb, writeVisibilityType, /*dim=*/2, createCurrentCTAMask(fb)); + visibleWrites = arith::AndIOp::create(fb, visibleWrites, sourceCTAMask); + visibleWrites = reduceLastDim(fb, visibleWrites); visibleWrites = convertAndBroadcast(fb, visibleWrites, {0, 1}, writeTrackingType); Value barAndVisible = @@ -1645,9 +1719,9 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, tti::createConstIntTensor(fb, fb.getLoc(), 1, writeTrackingType); Value newTracking = arith::SelectOp::create( fb, barAndVisible, writeTrackingOne, writeTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, - newTracking, writeTrackingType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + newTracking, writeTrackingType, + barrierCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1658,7 +1732,7 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, int thread, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value barrierCTAs) { if (auxData.barriers.empty() || auxData.readVisibility[(int)memType].empty() || @@ -1682,9 +1756,9 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, lengthVal, pred, - threadVal, barriersVal, readVisibilityVal, - readTrackingVal, recipientCTAs}; + SmallVector args = {mbarOffset, lengthVal, pred, + threadVal, barriersVal, readVisibilityVal, + readTrackingVal, barrierCTAs}; createCallToCachedFunction( b, "track_visible_reads", args, /*assertInfo=*/std::nullopt, @@ -1698,7 +1772,7 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value barriers = entryBlock->getArgument(4); Value readVisibilityPtr = entryBlock->getArgument(5); Value readTrackingPtr = entryBlock->getArgument(6); - Value recipientCTAs = entryBlock->getArgument(7); + Value barrierCTAs = entryBlock->getArgument(7); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1711,23 +1785,32 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, readTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, readTrackingType); + Value barrierCTAMask = + createCTASetMask(fb, readTrackingType, /*dim=*/2, barrierCTAs); + barriersEqBar = + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value threadColumnMask = - createColumnMask(fb, threadVal, readVisibilityType); + createDimMask(fb, threadVal, readVisibilityType, /*dim=*/3); Value readVisibilityZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); Value visibleReads = arith::SelectOp::create( fb, threadColumnMask, readVisibility, readVisibilityZero); - visibleReads = reduceLastDim(fb, visibleReads); + Value sourceCTAMask = createCTASetMask( + fb, readVisibilityType, /*dim=*/2, createCurrentCTAMask(fb)); + visibleReads = arith::SelectOp::create(fb, sourceCTAMask, visibleReads, + readVisibilityZero); + visibleReads = reduce(fb, visibleReads, /*axis=*/3); + visibleReads = reduce(fb, visibleReads, /*axis=*/2); visibleReads = - convertAndBroadcast(fb, visibleReads, {0, 1}, readTrackingType); + convertAndBroadcast(fb, visibleReads, {0, 1, 4}, readTrackingType); Value readTrackingOrVisible = arith::OrIOp::create(fb, readTracking, visibleReads); Value newTracking = arith::SelectOp::create( fb, barriersEqBar, readTrackingOrVisible, readTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, - newTracking, readTrackingType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, + newTracking, readTrackingType, + barrierCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1736,14 +1819,12 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, void FunctionBuilder::createTrackBarrierWriteForBufferCall( ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred, - MemType memType, Operation *insertPoint, Value barrierRecipientCTAs, - Value effectRecipientCTAs, bool diagonalEffectRecipientCTAs) { + MemType memType, Operation *insertPoint, Value barrierCTAs, + Value effectCTAs) { if (auxData.barriers.empty() || auxData.buffers[(int)memType].empty() || auxData.writeTracking[(int)memType].empty()) { return; } - assert(!auxData.barrierWriteRecipients.empty() && - "barrier write recipients must exist when tracking EffectWrites"); if (!pred) pred = arith::ConstantIntOp::create(b, 1, 1); Value barriersVal = auxData.barriers.at(insertPoint).value; @@ -1756,34 +1837,20 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value barrierWriteRecipientsVal = - auxData.barrierWriteRecipients.at(insertPoint).value; - auto barrierWriteRecipientsType = cast( - auxData.barrierWriteRecipients.at(insertPoint).type); uint32_t mbarLength = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value mbarLengthVal = arith::ConstantIntOp::create(b, mbarLength, 32); Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); Value bufLengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, - mbarLengthVal, - pred, - bufOffset, - bufLengthVal, - barriersVal, - buffersVal, - writeTrackingVal, - barrierWriteRecipientsVal, - barrierRecipientCTAs, - effectRecipientCTAs}; + SmallVector args = {mbarOffset, mbarLengthVal, pred, + bufOffset, bufLengthVal, barriersVal, + buffersVal, writeTrackingVal, barrierCTAs, + effectCTAs}; createCallToCachedFunction( b, "track_barrier_write_for_buffer", args, /*assertInfo=*/std::nullopt, - {barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, - (uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs}, - [writeTrackingType, barrierWriteRecipientsType, - diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + {barriersType, buffersType, writeTrackingType, (uint64_t)memType}, + [writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value mbarLengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1792,74 +1859,41 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( Value barriers = entryBlock->getArgument(5); Value buffers = entryBlock->getArgument(6); Value writeTrackingPtr = entryBlock->getArgument(7); - Value barrierWriteRecipientsPtr = entryBlock->getArgument(8); - Value barrierRecipientCTAs = entryBlock->getArgument(9); - Value effectRecipientCTAs = entryBlock->getArgument(10); + Value barrierCTAs = entryBlock->getArgument(8); + Value effectCTAs = entryBlock->getArgument(9); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value barrierWriteRecipients = tti::createLoadScratchMemory( - fb, fb.getLoc(), barrierWriteRecipientsPtr, - barrierWriteRecipientsType); Value barrierDescriptor = createBufferDescriptor(fb, mbarOffset, mbarLengthVal); Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, barrierDescriptor); - Value effectRecipientCTAsTensor = triton::SplatOp::create( - fb, barrierWriteRecipientsType, effectRecipientCTAs); - if (diagonalEffectRecipientCTAs) { - // Expand the effect CTA mask diagonally: barrier row i publishes only - // bit i. This models per-CTA results, while the default replicated - // mask models TMA multicast where a barrier publishes all result - // rows. - auto encoding = cast( - barrierWriteRecipientsType.getEncoding()); - auto rowSliceEncoding = - tti::getSingleDimSliceEncoding(encoding, /*dim=*/0); - int numCTAs = barrierWriteRecipientsType.getShape()[0]; - auto rowType = RankedTensorType::get({numCTAs}, fb.getI32Type(), - rowSliceEncoding); - Value rowIdx = triton::MakeRangeOp::create(fb, rowType, - /*start=*/0, - /*end=*/numCTAs); - auto indexType = - cast(barrierWriteRecipientsType.cloneWith( - std::nullopt, fb.getI32Type())); - rowIdx = convertAndBroadcast(fb, rowIdx, {0}, indexType); - Value one = tti::createConstIntTensor(fb, fb.getLoc(), 1, indexType); - Value rowBit = arith::ShLIOp::create(fb, one, rowIdx); - effectRecipientCTAsTensor = - arith::AndIOp::create(fb, effectRecipientCTAsTensor, rowBit); - } - Value updatedBarrierWriteRecipients = arith::OrIOp::create( - fb, barrierWriteRecipients, effectRecipientCTAsTensor); - updatedBarrierWriteRecipients = arith::SelectOp::create( - fb, barriersEqBar, updatedBarrierWriteRecipients, - barrierWriteRecipients); - createCTAScopedStoreScratchMemory( - fb, fb.getLoc(), barrierWriteRecipientsPtr, - updatedBarrierWriteRecipients, barrierWriteRecipientsType, - barrierRecipientCTAs); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, writeTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, writeTrackingType); Value bufferDescriptor = createBufferDescriptor(fb, bufOffset, bufLengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, bufferDescriptor); buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, writeTrackingType); + convertAndBroadcast(fb, buffersEqBuf, {1}, writeTrackingType); + Value bufferCTAMask = + createCTASetMask(fb, writeTrackingType, /*dim=*/0, effectCTAs); + Value barrierCTAMask = + createCTASetMask(fb, writeTrackingType, /*dim=*/2, barrierCTAs); Value trackMask = arith::AndIOp::create(fb, barriersEqBar, buffersEqBuf); + trackMask = arith::AndIOp::create(fb, trackMask, bufferCTAMask); + trackMask = arith::AndIOp::create(fb, trackMask, barrierCTAMask); Value writeTrackingOne = tti::createConstIntTensor(fb, fb.getLoc(), 1, writeTrackingType); Value newTracking = arith::SelectOp::create( fb, trackMask, writeTrackingOne, writeTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, - newTracking, writeTrackingType, - effectRecipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + newTracking, writeTrackingType, + trackMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1884,68 +1918,43 @@ void FunctionBuilder::createClearBarrierWriteTrackingCall( auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value barrierWriteRecipientsVal = - auxData.barrierWriteRecipients.at(insertPoint).value; - auto barrierWriteRecipientsType = cast( - auxData.barrierWriteRecipients.at(insertPoint).type); uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = { - mbarOffset, lengthVal, pred, - barriersVal, writeTrackingVal, barrierWriteRecipientsVal}; + SmallVector args = {mbarOffset, lengthVal, pred, barriersVal, + writeTrackingVal}; createCallToCachedFunction( b, "clear_barrier_write_tracking", args, /*assertInfo=*/std::nullopt, - {barriersType, writeTrackingType, barrierWriteRecipientsType, - (uint64_t)memType}, - [writeTrackingType, barrierWriteRecipientsType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + {barriersType, writeTrackingType, (uint64_t)memType}, + [writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); Value barriers = entryBlock->getArgument(3); Value writeTrackingPtr = entryBlock->getArgument(4); - Value barrierWriteRecipientsPtr = entryBlock->getArgument(5); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value barrierWriteRecipients = tti::createLoadScratchMemory( - fb, fb.getLoc(), barrierWriteRecipientsPtr, - barrierWriteRecipientsType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); - Value barrierRecipientCTAs = createCurrentCTAMask(fb); - Value barrierCTAMask = createRecipientCTAMask( - fb, barrierWriteRecipientsType, barrierRecipientCTAs); - Value selectedBarrier = - arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); - Value zeroRecipients = tti::createConstIntTensor( - fb, fb.getLoc(), 0, barrierWriteRecipientsType); - Value selectedRecipients = arith::SelectOp::create( - fb, selectedBarrier, barrierWriteRecipients, zeroRecipients); - Value effectRecipientCTAs = - reduceAll(fb, selectedRecipients); - Value recipientCTAs = - arith::OrIOp::create(fb, barrierRecipientCTAs, effectRecipientCTAs); - Value updatedRecipients = arith::SelectOp::create( - fb, selectedBarrier, zeroRecipients, barrierWriteRecipients); - createCTAScopedStoreScratchMemory( - fb, fb.getLoc(), barrierWriteRecipientsPtr, updatedRecipients, - barrierWriteRecipientsType, barrierRecipientCTAs); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, writeTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, writeTrackingType); + Value barrierCTAMask = createCTASetMask( + fb, writeTrackingType, /*dim=*/2, createCurrentCTAMask(fb)); + barriersEqBar = + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeTrackingType); Value updated = arith::SelectOp::create(fb, barriersEqBar, zero, writeTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, - updated, writeTrackingType, - recipientCTAs); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeTrackingPtr, + updated, writeTrackingType, + barrierCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -1995,14 +2004,18 @@ void FunctionBuilder::createClearBarrierReadTrackingCall( Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, readTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, readTrackingType); + Value barrierCTAMask = createCTASetMask(fb, readTrackingType, /*dim=*/2, + createCurrentCTAMask(fb)); + barriersEqBar = + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value zero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readTrackingType); Value updated = arith::SelectOp::create(fb, barriersEqBar, zero, readTracking); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, - updated, readTrackingType, - createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readTrackingPtr, + updated, readTrackingType, + barrierCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2032,28 +2045,18 @@ void FunctionBuilder::createTransferVisibleWritesCall( auxData.writeTracking[(int)memType].at(insertPoint).value; auto writeTrackingType = cast( auxData.writeTracking[(int)memType].at(insertPoint).type); - Value barrierWriteRecipientsVal = - auxData.barrierWriteRecipients.at(insertPoint).value; - auto barrierWriteRecipientsType = cast( - auxData.barrierWriteRecipients.at(insertPoint).type); uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, - lengthVal, - pred, - threadMaskVal, - barriersVal, - writeVisibilityVal, - writeTrackingVal, - barrierWriteRecipientsVal}; + SmallVector args = {mbarOffset, lengthVal, pred, + threadMaskVal, barriersVal, writeVisibilityVal, + writeTrackingVal}; createCallToCachedFunction( b, "transfer_visible_writes", args, /*assertInfo=*/std::nullopt, - {barriersType, writeVisibilityType, writeTrackingType, - barrierWriteRecipientsType, (uint64_t)memType}, - [writeVisibilityType, writeTrackingType, barrierWriteRecipientsType]( - ImplicitLocOpBuilder &fb, Block *entryBlock) { + {barriersType, writeVisibilityType, writeTrackingType, (uint64_t)memType}, + [writeVisibilityType, writeTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -2061,7 +2064,6 @@ void FunctionBuilder::createTransferVisibleWritesCall( Value barriers = entryBlock->getArgument(4); Value writeVisibilityPtr = entryBlock->getArgument(5); Value writeTrackingPtr = entryBlock->getArgument(6); - Value barrierWriteRecipientsPtr = entryBlock->getArgument(7); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -2070,36 +2072,24 @@ void FunctionBuilder::createTransferVisibleWritesCall( fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); Value writeTracking = tti::createLoadScratchMemory( fb, fb.getLoc(), writeTrackingPtr, writeTrackingType); - Value barrierWriteRecipients = tti::createLoadScratchMemory( - fb, fb.getLoc(), barrierWriteRecipientsPtr, - barrierWriteRecipientsType); Value descriptor = createBufferDescriptor(fb, mbarOffset, lengthVal); Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); - Value ctaId = tti::ExperimentalClusterCTAIdOp::create(fb, fb.getLoc()); + barriersEqBar = + convertAndBroadcast(fb, barriersEqBar, {3}, writeTrackingType); + Value currentCTA = createCurrentCTAMask(fb); Value barrierCTAMask = - createDimMask(fb, ctaId, barrierWriteRecipientsType, /*dim=*/0); - Value selectedBarrier = - arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); - Value zeroRecipients = tti::createConstIntTensor( - fb, fb.getLoc(), 0, barrierWriteRecipientsType); - Value selectedRecipients = arith::SelectOp::create( - fb, selectedBarrier, barrierWriteRecipients, zeroRecipients); - Value trackedRecipientCTAs = - reduceAll(fb, selectedRecipients); - Value currentCTA = arith::ShLIOp::create( - fb, arith::ConstantIntOp::create(fb, 1, 32), ctaId); - Value recipientCTAs = - arith::OrIOp::create(fb, currentCTA, trackedRecipientCTAs); + createCTASetMask(fb, writeTrackingType, /*dim=*/2, currentCTA); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, writeTrackingType); + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value zeroTracking = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeTrackingType); Value trackingBuffers = arith::SelectOp::create( fb, barriersEqBar, writeTracking, zeroTracking); trackingBuffers = reduceLastDim(fb, trackingBuffers); - trackingBuffers = createConvertLayout( - fb, trackingBuffers, writeVisibilityType.getEncoding()); + trackingBuffers = reduce(fb, trackingBuffers, /*axis=*/2); + trackingBuffers = convertAndBroadcast(fb, trackingBuffers, {0, 1}, + writeVisibilityType); auto trackingBuffersType = cast(trackingBuffers.getType()); Value trackingBuffersOne = @@ -2116,9 +2106,11 @@ void FunctionBuilder::createTransferVisibleWritesCall( fb, trackingBuffers, threadMaskTensor, zeroVisibility); Value newVisibility = arith::OrIOp::create(fb, writeVisibility, trackingThreadBit); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, - newVisibility, writeVisibilityType, - recipientCTAs); + Value waitingCTAMask = + createCTASetMask(fb, writeVisibilityType, /*dim=*/2, currentCTA); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + newVisibility, writeVisibilityType, + waitingCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2179,23 +2171,34 @@ void FunctionBuilder::createTransferVisibleReadsCall( Value barriersEqBar = createCmpIntTensorScalar(fb, barriers, descriptor); barriersEqBar = - convertAndBroadcast(fb, barriersEqBar, {0, 2}, readTrackingType); + convertAndBroadcast(fb, barriersEqBar, {3}, readTrackingType); + Value currentCTA = createCurrentCTAMask(fb); + Value barrierCTAMask = + createCTASetMask(fb, readTrackingType, /*dim=*/2, currentCTA); + barriersEqBar = + arith::AndIOp::create(fb, barriersEqBar, barrierCTAMask); Value readTrackingZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readTrackingType); Value trackingBar = arith::SelectOp::create( fb, barriersEqBar, readTracking, readTrackingZero); - trackingBar = reduceLastDim(fb, trackingBar); + trackingBar = reduce(fb, trackingBar, /*axis=*/3); + trackingBar = reduce(fb, trackingBar, /*axis=*/2); trackingBar = - convertAndBroadcast(fb, trackingBar, {0, 1}, readVisibilityType); + convertAndBroadcast(fb, trackingBar, {0, 1, 4}, readVisibilityType); Value readVisibilityOrTracking = arith::OrIOp::create(fb, readVisibility, trackingBar); Value threadColumnMask = - createThreadColumnMask(fb, threadMaskVal, readVisibilityType); + createThreadColumnMask(fb, threadMaskVal, readVisibilityType, + /*columnDim=*/3); + Value waitingCTAMask = + createCTASetMask(fb, readVisibilityType, /*dim=*/2, currentCTA); + threadColumnMask = + arith::AndIOp::create(fb, threadColumnMask, waitingCTAMask); Value newVisibility = arith::SelectOp::create( fb, threadColumnMask, readVisibilityOrTracking, readVisibility); - createCTAScopedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - newVisibility, readVisibilityType, - createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + newVisibility, readVisibilityType, + waitingCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2205,7 +2208,7 @@ void FunctionBuilder::createTransferVisibleReadsCall( void FunctionBuilder::createVerifyWriteVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs, bool allowNoWrite) { + Value effectCTAs, bool allowNoWrite) { if (auxData.buffers[(int)memType].empty() || auxData.writeVisibility[(int)memType].empty() || (auxData.hasNonTrivialAliasing[(int)memType] && @@ -2230,13 +2233,10 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( std::string uninitializedMessage = "Buffer being read before any write."; if (!operandName.empty()) uninitializedMessage += " Operand: " + operandName.str(); - auto verifyWriteResultType = cast( - writeVisibilityType.cloneWith(std::nullopt, b.getI1Type())); - AssertInfo assertInfo{message, verifyWriteResultType}; + AssertInfo assertInfo{message, b.getI1Type()}; Type aliasMatrixTypeBase; - auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase, - verifyWriteResultType](bool useAlias, - bool allowNoWrite) { + auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase]( + bool useAlias, bool allowNoWrite) { return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); @@ -2244,7 +2244,7 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( Value threadVal = entryBlock->getArgument(3); Value buffers = entryBlock->getArgument(4); Value writeVisibilityPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value effectCTAs = entryBlock->getArgument(6); Value aliasMatrix = useAlias ? entryBlock->getArgument(7) : Value(); Value writeVisibility = tti::createLoadScratchMemory( @@ -2257,10 +2257,10 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( cast(aliasMatrixTypeBase)); } buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, writeVisibilityType); - Value ctaMask = - createRecipientCTAMask(fb, writeVisibilityType, recipientCTAs); - buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); + convertAndBroadcast(fb, buffersEqBuf, {1}, writeVisibilityType); + Value relationMask = + createLeadCTAEffectMask(fb, writeVisibilityType, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, relationMask); Value writeVisibilityZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); Value bufVisibility = arith::SelectOp::create( @@ -2289,8 +2289,8 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( // Alias rows are alternatives within a CTA, but every selected CTA must // have at least one initialized row. Value initializedCTAs = - reduceLastDim(fb, initializedRows); - Value selectedCTAs = reduceLastDim(fb, buffersEqBuf); + reduce(fb, initializedRows, /*axis=*/1); + Value selectedCTAs = reduce(fb, buffersEqBuf, /*axis=*/1); Value ctaOne = tti::createConstIntTensor( fb, fb.getLoc(), 1, cast(selectedCTAs.getType())); Value unmatchedCTAs = arith::XOrIOp::create(fb, selectedCTAs, ctaOne); @@ -2307,8 +2307,6 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( fb, result.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); Value predicatedWriteVisible = arith::SelectOp::create(fb, pred, result, vTrue); - predicatedWriteVisible = triton::SplatOp::create( - fb, verifyWriteResultType, predicatedWriteVisible); triton::ReturnOp::create(fb, predicatedWriteVisible); }; }; @@ -2318,12 +2316,11 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( aliasMatrixTypeBase = auxData.aliasMatrices[(int)memType].at(insertPoint).type; auto aliasMatrixType = cast(aliasMatrixTypeBase); - SmallVector args = {bufOffset, lengthVal, pred, - threadVal, buffersVal, writeVisibilityVal, - recipientCTAs, aliasMatrixVal}; + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, writeVisibilityVal, + effectCTAs, aliasMatrixVal}; if (!allowNoWrite) { - AssertInfo initializedAssertInfo{uninitializedMessage, - verifyWriteResultType}; + AssertInfo initializedAssertInfo{uninitializedMessage, b.getI1Type()}; createCallToCachedFunction( b, "verify_write_initialized", args, initializedAssertInfo, {buffersType, writeVisibilityType, aliasMatrixType, @@ -2335,12 +2332,11 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( {buffersType, writeVisibilityType, aliasMatrixType, (uint64_t)memType}, buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/true)); } else { - SmallVector args = {bufOffset, lengthVal, pred, - threadVal, buffersVal, writeVisibilityVal, - recipientCTAs}; + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, writeVisibilityVal, + effectCTAs}; if (!allowNoWrite) { - AssertInfo initializedAssertInfo{uninitializedMessage, - verifyWriteResultType}; + AssertInfo initializedAssertInfo{uninitializedMessage, b.getI1Type()}; createCallToCachedFunction( b, "verify_write_initialized_noalias", args, initializedAssertInfo, {buffersType, writeVisibilityType, (uint64_t)memType}, @@ -2356,7 +2352,7 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( void FunctionBuilder::createVerifyReadVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value effectCTAs) { if (auxData.buffers[(int)memType].empty() || auxData.readVisibility[(int)memType].empty() || (auxData.hasNonTrivialAliasing[(int)memType] && @@ -2378,12 +2374,10 @@ void FunctionBuilder::createVerifyReadVisibilityCall( std::string message = "Buffer being accessed has outstanding reads"; if (!operandName.empty()) message += ". Operand: " + operandName.str(); - auto verifyReadResultType = cast( - readVisibilityType.cloneWith(std::nullopt, b.getI1Type())); - AssertInfo assertInfo{message, verifyReadResultType}; + AssertInfo assertInfo{message, b.getI1Type()}; Type aliasMatrixTypeBase; - auto buildVerifyReadBody = [&readVisibilityType, &aliasMatrixTypeBase, - verifyReadResultType](bool useAlias) { + auto buildVerifyReadBody = [&readVisibilityType, + &aliasMatrixTypeBase](bool useAlias) { return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); @@ -2391,7 +2385,7 @@ void FunctionBuilder::createVerifyReadVisibilityCall( Value threadVal = entryBlock->getArgument(3); Value buffers = entryBlock->getArgument(4); Value readVisibilityPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value effectCTAs = entryBlock->getArgument(6); Value aliasMatrix = useAlias ? entryBlock->getArgument(7) : Value(); Value readVisibility = tti::createLoadScratchMemory( @@ -2404,32 +2398,56 @@ void FunctionBuilder::createVerifyReadVisibilityCall( cast(aliasMatrixTypeBase)); } buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, readVisibilityType); - Value ctaMask = - createRecipientCTAMask(fb, readVisibilityType, recipientCTAs); - buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); + convertAndBroadcast(fb, buffersEqBuf, {1}, readVisibilityType); + Value bufferCTAMask = + createCTASetMask(fb, readVisibilityType, /*dim=*/0, effectCTAs); + Value relationMask = + createLeadCTAEffectMask(fb, readVisibilityType, effectCTAs); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, bufferCTAMask); Value readVisibilityZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); Value bufVisibility = arith::SelectOp::create( fb, buffersEqBuf, readVisibility, readVisibilityZero); - Value totalVisibility = reduceAll(fb, bufVisibility); + Value totalVisibility = reduce(fb, bufVisibility, + /*axis=*/3); + totalVisibility = reduce(fb, totalVisibility, /*axis=*/2); Value threadColumnMask = - createColumnMask(fb, threadVal, readVisibilityType); - Value bufThreadVisibility = arith::SelectOp::create( - fb, threadColumnMask, bufVisibility, readVisibilityZero); - bufThreadVisibility = reduceAll(fb, bufThreadVisibility); + createDimMask(fb, threadVal, readVisibilityType, /*dim=*/3); + Value accessorVisibility = arith::SelectOp::create( + fb, relationMask, bufVisibility, readVisibilityZero); + accessorVisibility = arith::SelectOp::create( + fb, threadColumnMask, accessorVisibility, readVisibilityZero); + accessorVisibility = + reduce(fb, accessorVisibility, /*axis=*/3); + auto accessorVisibilityType = + cast(accessorVisibility.getType()); + totalVisibility = convertAndBroadcast(fb, totalVisibility, {0, 1, 3}, + accessorVisibilityType); Value threadAndTotalVisibility = - arith::AndIOp::create(fb, bufThreadVisibility, totalVisibility); + arith::AndIOp::create(fb, accessorVisibility, totalVisibility); Value hasVisibility = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, threadAndTotalVisibility, totalVisibility); + Value selectedAccessors = + arith::AndIOp::create(fb, buffersEqBuf, relationMask); + selectedAccessors = + reduce(fb, selectedAccessors, /*axis=*/4); + selectedAccessors = + reduce(fb, selectedAccessors, /*axis=*/3); + selectedAccessors = convertAndBroadcast(fb, selectedAccessors, {0, 1, 2}, + accessorVisibilityType); + Value one = tti::createConstIntTensor( + fb, fb.getLoc(), 1, + cast(selectedAccessors.getType())); + Value unmatchedAccessors = + arith::XOrIOp::create(fb, selectedAccessors, one); + hasVisibility = + arith::OrIOp::create(fb, hasVisibility, unmatchedAccessors); hasVisibility = reduceAll(fb, hasVisibility); Value vTrue = arith::ConstantOp::create( fb, hasVisibility.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); Value predicatedHasVisibility = arith::SelectOp::create(fb, pred, hasVisibility, vTrue); - predicatedHasVisibility = triton::SplatOp::create( - fb, verifyReadResultType, predicatedHasVisibility); triton::ReturnOp::create(fb, predicatedHasVisibility); }; }; @@ -2439,17 +2457,17 @@ void FunctionBuilder::createVerifyReadVisibilityCall( aliasMatrixTypeBase = auxData.aliasMatrices[(int)memType].at(insertPoint).type; auto aliasMatrixType = cast(aliasMatrixTypeBase); - SmallVector args = {bufOffset, lengthVal, pred, - threadVal, buffersVal, readVisibilityVal, - recipientCTAs, aliasMatrixVal}; + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, readVisibilityVal, + effectCTAs, aliasMatrixVal}; createCallToCachedFunction( b, "verify_read_visibility", args, assertInfo, {buffersType, readVisibilityType, aliasMatrixType, (uint64_t)memType}, buildVerifyReadBody(/*useAlias=*/true)); } else { - SmallVector args = {bufOffset, lengthVal, pred, - threadVal, buffersVal, readVisibilityVal, - recipientCTAs}; + SmallVector args = {bufOffset, lengthVal, pred, + threadVal, buffersVal, readVisibilityVal, + effectCTAs}; createCallToCachedFunction( b, "verify_read_visibility_noalias", args, assertInfo, {buffersType, readVisibilityType, (uint64_t)memType}, @@ -2477,7 +2495,9 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, createCallToCachedFunction( b, "copy_write_visibility", args, /*assertInfo=*/std::nullopt, {writeVisibilityType, (uint64_t)memType}, - [writeVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [writeVisibilityType, + totalNumThreads = auxData.threadLayout.totalNumThreads]( + ImplicitLocOpBuilder &fb, Block *entryBlock) { Value sourceThread = entryBlock->getArgument(0); Value destMaskVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -2492,10 +2512,10 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value zeroTensor = tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); - uint64_t fullMask = tti::THREADS_BITMASK_SIZE == 64 + uint64_t fullMask = totalNumThreads == 64 ? std::numeric_limits::max() : (std::numeric_limits::max() >> - (64 - tti::THREADS_BITMASK_SIZE)); + (64 - totalNumThreads)); Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64); Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType); Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType); @@ -2523,9 +2543,11 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, destMaskTensor, zeroTensor); Value updatedCurrent = arith::OrIOp::create(fb, cleared, replicated); - tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, - updatedCurrent, writeVisibilityType, - /*currentCTAOnly=*/true); + Value currentCTAMask = createCTASetMask( + fb, writeVisibilityType, /*dim=*/2, createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + updatedCurrent, writeVisibilityType, + currentCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2546,9 +2568,7 @@ void FunctionBuilder::createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, auto readVis = auxData.readVisibility[(int)memType].at(insertPoint); auto readVisibilityType = cast(readVis.type); Value sourceThreadVal = arith::ConstantIntOp::create(b, sourceThread, 32); - SmallVector args = {sourceThreadVal, - arith::ConstantIntOp::create(b, destMask, 64), - pred, readVis.value}; + SmallVector args = {sourceThreadVal, pred, readVis.value}; createCallToCachedFunction( b, "copy_read_visibility", args, /*assertInfo=*/std::nullopt, @@ -2556,9 +2576,8 @@ void FunctionBuilder::createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, [readVisibilityType, destMask](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value sourceThread = entryBlock->getArgument(0); - /*Value destMaskVal = entryBlock->getArgument(1);*/ - Value pred = entryBlock->getArgument(2); - Value readVisibilityPtr = entryBlock->getArgument(3); + Value pred = entryBlock->getArgument(1); + Value readVisibilityPtr = entryBlock->getArgument(2); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -2569,24 +2588,107 @@ void FunctionBuilder::createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); Value destMaskTensor = createThreadColumnMask( fb, arith::ConstantIntOp::create(fb, destMask, 64), - readVisibilityType); + readVisibilityType, /*columnDim=*/3); Value cleared = arith::SelectOp::create(fb, destMaskTensor, zeroTensor, readVisibility); Value sourceColumnMask = - createColumnMask(fb, sourceThread, readVisibilityType); + createDimMask(fb, sourceThread, readVisibilityType, /*dim=*/3); Value sourceColumn = arith::SelectOp::create( fb, sourceColumnMask, readVisibility, zeroTensor); - Value sourceVector = reduceLastDim(fb, sourceColumn); - Value broadcastRow = - convertAndBroadcast(fb, sourceVector, {0, 1}, readVisibilityType); + Value sourceVector = reduce(fb, sourceColumn, /*axis=*/3); + Value broadcastRow = convertAndBroadcast(fb, sourceVector, {0, 1, 2, 4}, + readVisibilityType); Value replicated = arith::SelectOp::create(fb, destMaskTensor, broadcastRow, zeroTensor); Value updated = arith::OrIOp::create(fb, cleared, replicated); + Value currentCTAMask = createCTASetMask( + fb, readVisibilityType, /*dim=*/2, createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + updated, readVisibilityType, + currentCTAMask); + + fb.setInsertionPointToEnd(thenBlock); + triton::ReturnOp::create(fb); + }); +} + +void FunctionBuilder::createPublishClusterVisibilityCall( + ImplicitLocOpBuilder &b, Value pred, MemType memType, + Operation *insertPoint) { + if (auxData.writeVisibility[(int)memType].empty() || + auxData.readVisibility[(int)memType].empty()) { + return; + } + if (!pred) + pred = arith::ConstantIntOp::create(b, 1, 1); + auto writeVis = auxData.writeVisibility[(int)memType].at(insertPoint); + auto readVis = auxData.readVisibility[(int)memType].at(insertPoint); + auto writeVisibilityType = cast(writeVis.type); + auto readVisibilityType = cast(readVis.type); + SmallVector args = {pred, writeVis.value, readVis.value}; + createCallToCachedFunction( + b, "publish_cluster_visibility", args, + /*assertInfo=*/std::nullopt, + {writeVisibilityType, readVisibilityType, (uint64_t)memType}, + [writeVisibilityType, readVisibilityType, + numBaseThreads = auxData.threadLayout.numBaseThreads]( + ImplicitLocOpBuilder &fb, Block *entryBlock) { + Value pred = entryBlock->getArgument(0); + Value writeVisibilityPtr = entryBlock->getArgument(1); + Value readVisibilityPtr = entryBlock->getArgument(2); + + auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); + fb.setInsertionPointToStart(ifBlock); + + Value writeVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityType); + Value readVisibility = tti::createLoadScratchMemory( + fb, fb.getLoc(), readVisibilityPtr, readVisibilityType); + + // Cluster barriers publish generic-proxy synchronous work. Base-thread + // visibility distinguishes those facts from async-only TMA/TC/CLC + // effects, which are published by their own completion path. + uint64_t baseThreadMask = (1ULL << numBaseThreads) - 1; + Value baseMask = tti::createConstIntTensor( + fb, fb.getLoc(), baseThreadMask, writeVisibilityType); + Value zeroWrites = + tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType); + Value hasBaseWrite = arith::CmpIOp::create( + fb, arith::CmpIPredicate::ne, + arith::AndIOp::create(fb, writeVisibility, baseMask), zeroWrites); + Value syncWrites = arith::SelectOp::create(fb, hasBaseWrite, + writeVisibility, zeroWrites); + Value writesForCluster = + reduce(fb, syncWrites, /*axis=*/2); + writesForCluster = convertAndBroadcast(fb, writesForCluster, {0, 1}, + writeVisibilityType); + Value newWriteVisibility = + arith::OrIOp::create(fb, writeVisibility, writesForCluster); + tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + newWriteVisibility, writeVisibilityType, + /*currentCTAOnly=*/false); + + Value readBaseMask = tti::createConstIntTensor( + fb, fb.getLoc(), baseThreadMask, readVisibilityType); + Value zeroReads = + tti::createConstIntTensor(fb, fb.getLoc(), 0, readVisibilityType); + Value hasBaseRead = arith::CmpIOp::create( + fb, arith::CmpIPredicate::ne, + arith::AndIOp::create(fb, readVisibility, readBaseMask), zeroReads); + Value syncReads = + arith::SelectOp::create(fb, hasBaseRead, readVisibility, zeroReads); + Value readsForCluster = reduce(fb, syncReads, /*axis=*/4); + readsForCluster = reduce(fb, readsForCluster, + /*axis=*/2); + readsForCluster = convertAndBroadcast(fb, readsForCluster, {0, 1, 3}, + readVisibilityType); + Value newReadVisibility = + arith::OrIOp::create(fb, readVisibility, readsForCluster); tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - updated, readVisibilityType, - /*currentCTAOnly=*/true); + newReadVisibility, readVisibilityType, + /*currentCTAOnly=*/false); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2623,8 +2725,6 @@ void FunctionBuilder::createStageAccessForCommitCall( Value buffers = entryBlock->getArgument(4); Value outstandingCommitsPtr = entryBlock->getArgument(5); - (void)threadVal; - auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -2632,18 +2732,20 @@ void FunctionBuilder::createStageAccessForCommitCall( fb, fb.getLoc(), outstandingCommitsPtr, commitsType); Value descriptor = createBufferDescriptor(fb, bufOffset, lengthVal); Value buffersEqBuf = createCmpIntTensorScalar(fb, buffers, descriptor); - buffersEqBuf = - convertAndBroadcast(fb, buffersEqBuf, {0, 1}, commitsType); - Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, {1}, commitsType); + Value ctaMask = createCTASetMask(fb, commitsType, /*dim=*/0, + createCurrentCTAMask(fb)); + buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); + Value threadColumnMask = + createDimMask(fb, threadVal, commitsType, /*dim=*/2); Value bufAndThread = arith::AndIOp::create(fb, buffersEqBuf, threadColumnMask); Value minusOne = tti::createConstIntTensor(fb, fb.getLoc(), -1, commitsType, true); Value updated = arith::SelectOp::create(fb, bufAndThread, minusOne, commits); - tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, - updated, commitsType, - /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + updated, commitsType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2683,7 +2785,10 @@ void FunctionBuilder::createCommitAccessesCall(ImplicitLocOpBuilder &b, fb, elementType, fb.getIntegerAttr(elementType, -1)); Value ones = tti::createConstIntTensor(fb, fb.getLoc(), 1, commitsType); - Value threadMask = createColumnMask(fb, threadVal, commitsType); + Value threadMask = createDimMask(fb, threadVal, commitsType, /*dim=*/2); + Value ctaMask = createCTASetMask(fb, commitsType, /*dim=*/0, + createCurrentCTAMask(fb)); + threadMask = arith::AndIOp::create(fb, threadMask, ctaMask); auto commitsGtZero = createCmpIntTensorScalar( fb, commits, zero, arith::CmpIPredicate::sgt); commitsGtZero = arith::AndIOp::create(fb, commitsGtZero, threadMask); @@ -2697,9 +2802,8 @@ void FunctionBuilder::createCommitAccessesCall(ImplicitLocOpBuilder &b, arith::AndIOp::create(fb, commitsEqMinusOne, threadMask); commits = arith::SelectOp::create(fb, commitsEqMinusOne, ones, commits); - tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, - commits, commitsType, - /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + commits, commitsType, ctaMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2751,7 +2855,12 @@ void FunctionBuilder::createClearOutstandingCommitsTransferWritesCall( auto elemIntType = cast(commitsType.getElementType()); Value outstandingNumElem = adjustIntegerWidth(fb, outstandingNumVal, elemIntType); - Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + Value threadColumnMask = + createDimMask(fb, threadVal, commitsType, /*dim=*/2); + Value commitCTAMask = createCTASetMask(fb, commitsType, /*dim=*/0, + createCurrentCTAMask(fb)); + threadColumnMask = + arith::AndIOp::create(fb, threadColumnMask, commitCTAMask); auto outstandingCommitsGtOutstandingNum = createCmpIntTensorScalar(fb, outstandingCommits, outstandingNumElem, arith::CmpIPredicate::sgt); @@ -2760,8 +2869,7 @@ void FunctionBuilder::createClearOutstandingCommitsTransferWritesCall( Value rowMask = reduceLastDim(fb, outstandingCommitsGtOutstandingNum); - rowMask = - createConvertLayout(fb, rowMask, writeVisibilityType.getEncoding()); + rowMask = convertAndBroadcast(fb, rowMask, {0, 1}, writeVisibilityType); Value transferMaskElem = adjustIntegerWidth( fb, transferMaskVal, cast(writeVisibilityType.getElementType())); @@ -2771,18 +2879,20 @@ void FunctionBuilder::createClearOutstandingCommitsTransferWritesCall( arith::OrIOp::create(fb, writeVisibility, transferMaskTensor); Value writeVisibilityUpdated = arith::SelectOp::create( fb, rowMask, writeVisibilityOrThreadBit, writeVisibility); - tti::createStoreScratchMemory( - fb, fb.getLoc(), writeVisibilityPtr, writeVisibilityUpdated, - writeVisibilityType, /*currentCTAOnly=*/true); + Value writeMask = createCTASetMask(fb, writeVisibilityType, /*dim=*/2, + createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + writeVisibilityUpdated, + writeVisibilityType, writeMask); Value outstandingCommitsZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); outstandingCommits = arith::SelectOp::create(fb, outstandingCommitsGtOutstandingNum, outstandingCommitsZero, outstandingCommits); - tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, - outstandingCommits, commitsType, - /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + outstandingCommits, commitsType, + commitCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2834,7 +2944,12 @@ void FunctionBuilder::createClearOutstandingCommitsTransferReadsCall( auto elemIntType = cast(commitsType.getElementType()); Value outstandingNumElem = adjustIntegerWidth(fb, outstandingNumVal, elemIntType); - Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + Value threadColumnMask = + createDimMask(fb, threadVal, commitsType, /*dim=*/2); + Value commitCTAMask = createCTASetMask(fb, commitsType, /*dim=*/0, + createCurrentCTAMask(fb)); + threadColumnMask = + arith::AndIOp::create(fb, threadColumnMask, commitCTAMask); auto outstandingCommitsGtOutstandingNum = createCmpIntTensorScalar(fb, outstandingCommits, outstandingNumElem, arith::CmpIPredicate::sgt); @@ -2844,6 +2959,9 @@ void FunctionBuilder::createClearOutstandingCommitsTransferReadsCall( Value rowMask = reduceLastDim(fb, outstandingCommitsGtOutstandingNum); rowMask = convertAndBroadcast(fb, rowMask, {0, 1}, readVisibilityType); + Value cMask = createCTASetMask(fb, readVisibilityType, /*dim=*/4, + createCurrentCTAMask(fb)); + rowMask = arith::AndIOp::create(fb, rowMask, cMask); Value transferMaskElem = adjustIntegerWidth( fb, transferMaskVal, cast(readVisibilityType.getElementType())); @@ -2853,18 +2971,20 @@ void FunctionBuilder::createClearOutstandingCommitsTransferReadsCall( arith::OrIOp::create(fb, readVisibility, transferMaskTensor); Value readVisibilityUpdated = arith::SelectOp::create( fb, rowMask, readVisibilityOrThreadBit, readVisibility); - tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - readVisibilityUpdated, readVisibilityType, - /*currentCTAOnly=*/true); + Value readMask = createCTASetMask(fb, readVisibilityType, /*dim=*/2, + createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + readVisibilityUpdated, + readVisibilityType, readMask); Value outstandingCommitsZero = tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); outstandingCommits = arith::SelectOp::create(fb, outstandingCommitsGtOutstandingNum, outstandingCommitsZero, outstandingCommits); - tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, - outstandingCommits, commitsType, - /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + outstandingCommits, commitsType, + commitCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -2941,7 +3061,12 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( auto elemIntType = cast(commitsType.getElementType()); Value outstandingNumElem = adjustIntegerWidth(fb, outstandingNumVal, elemIntType); - Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + Value threadColumnMask = + createDimMask(fb, threadVal, commitsType, /*dim=*/2); + Value commitCTAMask = createCTASetMask(fb, commitsType, /*dim=*/0, + createCurrentCTAMask(fb)); + threadColumnMask = + arith::AndIOp::create(fb, threadColumnMask, commitCTAMask); auto outstandingCommitsGtOutstandingNum = createCmpIntTensorScalar(fb, outstandingCommits, outstandingNumElem, arith::CmpIPredicate::sgt); @@ -2951,8 +3076,8 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( // Update write visibility Value writeRowMask = reduceLastDim(fb, outstandingCommitsGtOutstandingNum); - writeRowMask = createConvertLayout(fb, writeRowMask, - writeVisibilityType.getEncoding()); + writeRowMask = + convertAndBroadcast(fb, writeRowMask, {0, 1}, writeVisibilityType); Value writeTransferMaskElem = adjustIntegerWidth( fb, transferMaskVal, cast(writeVisibilityType.getElementType())); @@ -2962,16 +3087,20 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( arith::OrIOp::create(fb, writeVisibility, writeTransferMaskTensor); Value writeVisibilityUpdated = arith::SelectOp::create( fb, writeRowMask, writeVisibilityOrThreadBit, writeVisibility); - tti::createStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, - writeVisibilityUpdated, - writeVisibilityType, - /*currentCTAOnly=*/true); + Value writeMask = createCTASetMask(fb, writeVisibilityType, /*dim=*/2, + createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), writeVisibilityPtr, + writeVisibilityUpdated, + writeVisibilityType, writeMask); // Update read visibility Value readRowMask = reduceLastDim(fb, outstandingCommitsGtOutstandingNum); readRowMask = convertAndBroadcast(fb, readRowMask, {0, 1}, readVisibilityType); + Value cMask = createCTASetMask(fb, readVisibilityType, /*dim=*/4, + createCurrentCTAMask(fb)); + readRowMask = arith::AndIOp::create(fb, readRowMask, cMask); Value readTransferMaskElem = adjustIntegerWidth( fb, transferMaskVal, cast(readVisibilityType.getElementType())); @@ -2981,9 +3110,11 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( arith::OrIOp::create(fb, readVisibility, readTransferMaskTensor); Value readVisibilityUpdated = arith::SelectOp::create( fb, readRowMask, readVisibilityOrThreadBit, readVisibility); - tti::createStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, - readVisibilityUpdated, readVisibilityType, - /*currentCTAOnly=*/true); + Value readMask = createCTASetMask(fb, readVisibilityType, /*dim=*/2, + createCurrentCTAMask(fb)); + createMaskedStoreScratchMemory(fb, fb.getLoc(), readVisibilityPtr, + readVisibilityUpdated, + readVisibilityType, readMask); // Clear outstanding commits once Value outstandingCommitsZero = @@ -2991,9 +3122,9 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( outstandingCommits = arith::SelectOp::create(fb, outstandingCommitsGtOutstandingNum, outstandingCommitsZero, outstandingCommits); - tti::createStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, - outstandingCommits, commitsType, - /*currentCTAOnly=*/true); + createMaskedStoreScratchMemory(fb, fb.getLoc(), outstandingCommitsPtr, + outstandingCommits, commitsType, + commitCTAMask); fb.setInsertionPointToEnd(thenBlock); triton::ReturnOp::create(fb); @@ -3003,7 +3134,7 @@ void FunctionBuilder::createClearOutstandingCommitsTransferBothCall( void FunctionBuilder::createCheckOutstandingCommitsCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef pendingAccessType, Value pred, MemType memType, - CommitKind::Kind commitKind, Operation *insertPoint, Value recipientCTAs, + CommitKind::Kind commitKind, Operation *insertPoint, Value effectCTAs, bool excludeSelf) { if (auxData.buffers[(int)memType].empty() || auxData.commits[commitKind].empty() || @@ -3013,7 +3144,7 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( } ValueType buffers = auxData.buffers[(int)memType].at(insertPoint); ValueType outstandingCommits = auxData.commits[commitKind].at(insertPoint); - assert(thread < NUM_THREADS && + assert(thread < auxData.threadLayout.numBaseThreads && "Commit-count tracking must operate on base threads"); Value bufOffset = tti::ExperimentalMemDescToI32Op::create(b, buf); if (!pred) @@ -3025,13 +3156,10 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( std::string message = "Accessing buffer with pending access. Pending access type: " + pendingAccessType.str(); - auto checkCommitsResultType = cast( - commitsType.cloneWith(std::nullopt, b.getI1Type())); - AssertInfo assertInfo{message, checkCommitsResultType}; + AssertInfo assertInfo{message, b.getI1Type()}; Type aliasMatrixTypeBase; - auto buildCheckOutstandingCommitsBody = [&commitsType, &aliasMatrixTypeBase, - checkCommitsResultType]( + auto buildCheckOutstandingCommitsBody = [&commitsType, &aliasMatrixTypeBase]( bool useAlias, bool exclSelf) { return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); @@ -3040,7 +3168,7 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( Value threadVal = entryBlock->getArgument(3); Value buffers = entryBlock->getArgument(4); Value outstandingCommitsPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value effectCTAs = entryBlock->getArgument(6); Value aliasMatrix = useAlias ? entryBlock->getArgument(7) : Value(); Value outstandingCommits = tti::createLoadScratchMemory( @@ -3052,15 +3180,16 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( expandAliases(fb, buffersEqBuf, aliasMatrix, cast(aliasMatrixTypeBase)); } - buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, {0, 1}, commitsType); - Value ctaMask = createRecipientCTAMask(fb, commitsType, recipientCTAs); + buffersEqBuf = convertAndBroadcast(fb, buffersEqBuf, {1}, commitsType); + Value ctaMask = createCTASetMask(fb, commitsType, /*dim=*/0, effectCTAs); buffersEqBuf = arith::AndIOp::create(fb, buffersEqBuf, ctaMask); Value zeroTensor = tti::createConstIntTensor(fb, fb.getLoc(), 0, commitsType); Value selectedRows = arith::SelectOp::create( fb, buffersEqBuf, outstandingCommits, zeroTensor); if (exclSelf) { - Value threadColumnMask = createColumnMask(fb, threadVal, commitsType); + Value threadColumnMask = + createDimMask(fb, threadVal, commitsType, /*dim=*/2); selectedRows = arith::SelectOp::create(fb, threadColumnMask, zeroTensor, selectedRows); } @@ -3072,8 +3201,6 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( fb.getIntegerAttr(fb.getI1Type(), 1)); Value predicatedSelectedEqZero = arith::SelectOp::create(fb, pred, allSelectedEqZero, vTrue); - predicatedSelectedEqZero = triton::SplatOp::create( - fb, checkCommitsResultType, predicatedSelectedEqZero); triton::ReturnOp::create(fb, predicatedSelectedEqZero); }; @@ -3085,7 +3212,7 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffers.value, outstandingCommits.value, - recipientCTAs, aliasMatrix.value}; + effectCTAs, aliasMatrix.value}; std::string funcName = excludeSelf ? "check_outstanding_commits_excl_self" : "check_outstanding_commits"; createCallToCachedFunction( @@ -3096,7 +3223,7 @@ void FunctionBuilder::createCheckOutstandingCommitsCall( SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffers.value, outstandingCommits.value, - recipientCTAs}; + effectCTAs}; std::string funcName = excludeSelf ? "check_outstanding_commits_excl_self_noalias" : "check_outstanding_commits_noalias"; diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index c07c398bb181..b821a4bcbbf1 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -29,54 +29,38 @@ DistributedEncodingTrait getWarpLocalEncoding(MLIRContext *ctx, unsigned warps, unsigned numCTAs, unsigned bitwidth) { assert(!shape.empty() && "Expected non-empty shape"); - auto dims = standardOutDimNames(ctx, shape.size()); auto kBlock = StringAttr::get(ctx, "block"); auto kWarp = StringAttr::get(ctx, "warp"); auto kLane = StringAttr::get(ctx, "lane"); auto kRegister = StringAttr::get(ctx, "register"); + constexpr int kMaxVectorLengthBits = 128; // A warp-local layout ensures each warp has a copy of the whole tensor, so - // reductions, layout conversions, etc. don't require shared memory. TODO: - // attempt to pick a decent coalesced layout, assuming the inner dimension is - // contiguous and the tensor is 16-byte aligned. However, pick the widest - // vector length to reduce the number of instructions, speeding up - // compilation. - // unsigned vecLen = kMaxVectorLengthBits / bitwidth; - - // Broadcast along blocks and warps. Use the innermost dimension for the - // lane/register mapping and keep the outer dimensions replicated. - auto lastDim = dims.back(); - auto repOrder = llvm::to_vector(llvm::seq(0, shape.size())); - auto trivialShape = SmallVector(shape.size(), 1); - auto llReg = LinearLayout::identity1D(1, kRegister, lastDim); - auto llLane = LinearLayout::identity1D(32, kLane, lastDim); - auto llWarp = LinearLayout::zeros1D(warps, kWarp, lastDim); - // ConSan's multi-CTA state is replicated in every CTA. The leading logical - // CTA dimension is therefore broadcast across blocks instead of split. - auto llBlock = LinearLayout::zeros1D(numCTAs, kBlock, dims.front()); - LinearLayout ll = identityStandardND(kRegister, trivialShape, repOrder) * - llReg * llLane * llWarp * llBlock; - SmallVector layoutShape(shape.size(), 1); - layoutShape.back() = ll.getTotalOutDimSize(); - ll = ll.reshapeOuts(standardOutDimPairs(ctx, layoutShape)); - - llvm::SmallDenseMap bounds; - for (auto [dim, size] : llvm::zip_equal(dims, shape)) - bounds.try_emplace(dim, size); - ll = ensureLayoutNotLargerThan(ll, bounds); - ll = ensureLayoutNotSmallerThan(ll, bounds); - + // reductions, layout conversions, etc. don't require shared memory. + // We pick a layout that vectorises global loads and stores. + auto dim = StringAttr::get(ctx, "dim0"); + int numel = product(shape); + auto nlanes = std::min(numel, 32); + auto nregs = numel / nlanes; + auto vec = std::min(kMaxVectorLengthBits / bitwidth, nregs); + nregs /= vec; + auto ll = LinearLayout::identity1D(vec, kRegister, dim) * + LinearLayout::identity1D(nlanes, kLane, dim) * + LinearLayout::zeros1D(32 / nlanes, kLane, dim) * + LinearLayout::zeros1D(warps, kWarp, dim) * + LinearLayout::zeros1D(numCTAs, kBlock, dim) * + LinearLayout::identity1D(nregs, kRegister, dim); + ll = ll.reshapeOuts(standardOutDimPairs(ctx, shape)); return LinearEncodingAttr::get(ctx, ll); } -std::pair -createBufferDescriptorsTensor(ImplicitLocOpBuilder &builder, MemType memType, - ArrayRef regions) { +ValueType createBufferDescriptorsTensor(ImplicitLocOpBuilder &builder, + MemType memType, + ArrayRef regions) { Region *region = builder.getInsertionBlock()->getParent(); int64_t size = regions.size(); - int64_t numCTAs = lookupNumCTAs(region->getParentOp()); assert(llvm::isPowerOf2_64(size) && "Expected power of 2"); - auto tensorType = getIntTensorType(region, {numCTAs, size}, 64); + auto tensorType = getIntTensorType(region, {size}, 64); SmallVector offsets; SmallVector lengths; offsets.reserve(size); @@ -139,9 +123,10 @@ bool hasCrossBufferAliasing(ArrayRef regions) { return false; } -Value createInitStateTensor(ImplicitLocOpBuilder &b, ArrayRef shape, - int bitWidth, int64_t initialValue, - FunctionBuilder &funcBuilder) { +ValueType createInitStateTensor(ImplicitLocOpBuilder &b, + ArrayRef shape, int bitWidth, + int64_t initialValue, + FunctionBuilder &funcBuilder) { auto type = getIntTensorType(b.getInsertionBlock()->getParent(), shape, bitWidth); Type elType = type.getElementType(); @@ -171,12 +156,12 @@ Value createInitStateTensor(ImplicitLocOpBuilder &b, ArrayRef shape, b.setInsertionPointToStart(ifBlock); funcBuilder.createFillGlobalTensorCall(b, alloc, type, cstInit); b.setInsertionPointToStart(thenBlock); - return alloc; + return {alloc, type}; } -Value createZeroInitStateTensor(ImplicitLocOpBuilder &b, - ArrayRef shape, int bitWidth, - FunctionBuilder &funcBuilder) { +ValueType createZeroInitStateTensor(ImplicitLocOpBuilder &b, + ArrayRef shape, int bitWidth, + FunctionBuilder &funcBuilder) { return createInitStateTensor(b, shape, bitWidth, 0, funcBuilder); } @@ -190,26 +175,17 @@ createAliasMatrixTensor(ImplicitLocOpBuilder &b, for (const auto &row : matrix) assert(row.size() == cols && "Expected square alias matrix"); - int64_t numCTAs = lookupNumCTAs(region->getParentOp()); auto type = getIntTensorType( - region, {numCTAs, static_cast(rows), static_cast(cols)}, + region, {static_cast(rows), static_cast(cols)}, /*bitWidth=*/1); - auto sliceType = RankedTensorType::get( - {static_cast(rows), static_cast(cols)}, b.getI1Type(), - SliceEncodingAttr::get( - b.getContext(), /*dim=*/0, - cast(type.getEncoding()))); SmallVector values; values.reserve(rows * cols); for (const auto &row : matrix) for (uint8_t v : row) values.emplace_back(/*numBits=*/1, v); - auto denseAttr = DenseElementsAttr::get(sliceType, values); - Value constValue = - arith::ConstantOp::create(b, b.getLoc(), sliceType, denseAttr); - constValue = expandOuterSlicedDim(b, b.getLoc(), constValue); - constValue = BroadcastOp::create(b, b.getLoc(), type, constValue); + auto denseAttr = DenseElementsAttr::get(type, values); + Value constValue = arith::ConstantOp::create(b, b.getLoc(), type, denseAttr); return cast>(constValue); } @@ -284,9 +260,11 @@ gpu::GlobalScratchAllocOp createThirdPartyScratchAlloc(OpBuilder &b, Location loc, Type ptrType, int64_t sizeInBytes, int64_t alignment, bool sharedClusterState) { - return gpu::GlobalScratchAllocOp::create( + auto alloc = gpu::GlobalScratchAllocOp::create( b, loc, ptrType, sizeInBytes, alignment, b.getUnitAttr(), sharedClusterState ? b.getUnitAttr() : UnitAttr()); + alloc->setDiscardableAttr("tt.divisibility", b.getI64IntegerAttr(alignment)); + return alloc; } void createAssertInThread(ImplicitLocOpBuilder &b, Value condition, @@ -460,6 +438,48 @@ FuncOp getEntryPoint(ModuleOp module) { return publicFuncs.front(); } +AuxDataMap::ThreadLayout getThreadLayout(ModuleOp module, + const ConSanTargetHooks *hooks) { + AuxDataMap::ThreadLayout layout; + bool hasTMA = false; + bool hasTC = false; + bool hasCLC = false; + + module.walk([&](Operation *op) { + if (auto wsOp = dyn_cast(op)) + layout.numBaseThreads = std::max( + layout.numBaseThreads, wsOp.getPartitionRegions().size() + 1); + if (auto wsOp = dyn_cast(op)) + layout.numBaseThreads = std::max( + layout.numBaseThreads, wsOp.getPartitionRegions().size() + 1); + hasTMA |= hooks->isTMAOp(op); + hasTC |= isa(op); + hasCLC |= hooks->isCLCOp(op); + }); + + assert(layout.numBaseThreads <= MAX_NUM_BASE_THREADS && + "ConSan waiting bitsets assume at most 16 base threads"); + layout.numBaseThreadSlots = llvm::PowerOf2Ceil(layout.numBaseThreads); + int nextThread = layout.numBaseThreads; + if (hasTMA) { + layout.tmaThreadOffset = nextThread; + nextThread += layout.numBaseThreads; + } + if (hasTC) { + layout.tcThreadOffset = nextThread; + nextThread += layout.numBaseThreads; + } + if (hasCLC) { + layout.clcThreadOffset = nextThread; + nextThread += layout.numBaseThreads; + } + layout.totalNumThreads = nextThread; + layout.numThreadSlots = llvm::PowerOf2Ceil(layout.totalNumThreads); + assert(layout.totalNumThreads <= 64 && + "ConSan thread bitsets are stored in i64 masks"); + return layout; +} + Region *AuxDataMap::RegionToValueMap::getEnclosingParitionOrFunctionRegion( Operation *op) { Region *region = op->getParentRegion(); @@ -493,6 +513,7 @@ LogicalResult AuxDataMap::populateAndPassToWarpSpecialize( SmallVector barrierRegions; getBuffersAndBarriers(module, bufRegions, barrierRegions); int numCTAs = lookupNumCTAs(module); + threadLayout = getThreadLayout(module, hooks); int captureCounter = 0; int64_t captureBytes = 0; @@ -515,14 +536,14 @@ LogicalResult AuxDataMap::populateAndPassToWarpSpecialize( buffers[iMemType].insert( entryRegion, - {createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}); + createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])); // Buffer descriptors are rematerialized in the warp specialize region, // not passed as an argument. - createInWarpSpecialize( - entryPoint, buffers[iMemType], [&](ImplicitLocOpBuilder &b) { - return ValueType{ - createBufferDescriptorsTensor(b, memType, bufRegions[iMemType])}; - }); + createInWarpSpecialize(entryPoint, buffers[iMemType], + [&](ImplicitLocOpBuilder &b) { + return createBufferDescriptorsTensor( + b, memType, bufRegions[iMemType]); + }); int numBufs = bufRegions[iMemType].size(); hasNonTrivialAliasing[iMemType] = @@ -546,59 +567,46 @@ LogicalResult AuxDataMap::populateAndPassToWarpSpecialize( } writeVisibility[iMemType].insert( - entryRegion, {createZeroInitStateTensor(b, {numCTAs, numBufs}, 64, fb), - getIntTensorType(entryRegion, {numCTAs, numBufs}, 64)}); + entryRegion, + createZeroInitStateTensor(b, {numCTAs, numBufs, numCTAs}, 64, fb)); passValueToWarpSpecialize(writeVisibility[iMemType].at(entryRegion), writeVisibility[iMemType]); readVisibility[iMemType].insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBufs, THREADS_BITMASK_SIZE}, - 64, fb), - getIntTensorType(entryRegion, {numCTAs, numBufs, THREADS_BITMASK_SIZE}, - 64)}); + createZeroInitStateTensor( + b, + {numCTAs, numBufs, numCTAs, threadLayout.numThreadSlots, numCTAs}, + 64, fb)); passValueToWarpSpecialize(readVisibility[iMemType].at(entryRegion), readVisibility[iMemType]); } if (!barrierRegions.empty()) { // Barriers allocations are in shared memory - barriers.insert(entryRegion, {createBufferDescriptorsTensor( - b, MemType::SHARED_MEM, barrierRegions)}); + barriers.insert(entryRegion, createBufferDescriptorsTensor( + b, MemType::SHARED_MEM, barrierRegions)); // Barriers allocations are rematerialized in the warp specialize region, // not passed as an argument. createInWarpSpecialize(entryPoint, barriers, [&](ImplicitLocOpBuilder &b) { - return ValueType{createBufferDescriptorsTensor(b, MemType::SHARED_MEM, - barrierRegions)}; + return createBufferDescriptorsTensor(b, MemType::SHARED_MEM, + barrierRegions); }); int numBarriers = barrierRegions.size(); - barrierStates.insert( - entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBarriers}, 64, fb), - getIntTensorType(entryRegion, {numCTAs, numBarriers}, 64)}); + barrierStates.insert(entryRegion, createZeroInitStateTensor( + b, {numCTAs, numBarriers}, 64, fb)); passValueToWarpSpecialize(barrierStates.at(entryRegion), barrierStates); // Deadlock detection aux data over [cta, barrier]: waiting // stores waiting flag and phase bits per thread (two bits per thread). waiting.insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBarriers}, 32, fb), - getIntTensorType(entryRegion, {numCTAs, numBarriers}, 32)}); + createZeroInitStateTensor(b, {numCTAs, numBarriers, numCTAs}, 32, fb)); passValueToWarpSpecialize(waiting.at(entryRegion), waiting); activeMasks.insert(entryRegion, - {createInitStateTensor(b, {numCTAs}, 32, 1, fb), - getIntTensorType(entryRegion, {numCTAs}, 32)}); - passToWarpSpecialize(entryPoint, activeMasks.at(entryRegion), activeMasks, - captureCounter, captureBytes); - - barrierWriteRecipients.insert( - entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBarriers}, 32, fb), - getIntTensorType(entryRegion, {numCTAs, numBarriers}, 32)}); - passValueToWarpSpecialize(barrierWriteRecipients.at(entryRegion), - barrierWriteRecipients); - + createInitStateTensor(b, {numCTAs}, 32, 1, fb)); + passValueToWarpSpecialize(activeMasks.at(entryRegion), activeMasks); for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { int iMemType = (int)memType; // Create state tensors: @@ -606,18 +614,14 @@ LogicalResult AuxDataMap::populateAndPassToWarpSpecialize( if (numBufs > 0) { writeTracking[iMemType].insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBufs, numBarriers}, 8, - fb), - getIntTensorType(entryRegion, {numCTAs, numBufs, numBarriers}, - 8)}); + createZeroInitStateTensor( + b, {numCTAs, numBufs, numCTAs, numBarriers}, 8, fb)); passValueToWarpSpecialize(writeTracking[iMemType].at(entryRegion), writeTracking[iMemType]); readTracking[iMemType].insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBufs, numBarriers}, 64, - fb), - getIntTensorType(entryRegion, {numCTAs, numBufs, numBarriers}, - 64)}); + createZeroInitStateTensor( + b, {numCTAs, numBufs, numCTAs, numBarriers, numCTAs}, 64, fb)); passValueToWarpSpecialize(readTracking[iMemType].at(entryRegion), readTracking[iMemType]); } @@ -645,12 +649,11 @@ LogicalResult AuxDataMap::populateAndPassToWarpSpecialize( int numBufs = bufRegions[(int)MemType::SHARED_MEM].size(); if (numBufs == 0) return; - // NUM_THREADS instead of THREADS_BITMASK_SIZE as commit-count tracking - // operates on base threads. + // Commit-count tracking operates on base threads. commits[commitKind].insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBufs, NUM_THREADS}, 8, fb), - getIntTensorType(entryRegion, {numCTAs, numBufs, NUM_THREADS}, 8)}); + createZeroInitStateTensor( + b, {numCTAs, numBufs, threadLayout.numBaseThreadSlots}, 8, fb)); passValueToWarpSpecialize(commits[commitKind].at(entryRegion), commits[commitKind]); }; diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index af849c94d7d3..2934843c9caf 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -87,36 +87,49 @@ std::optional maybeGetPartitionIdx(Operation *op) { return std::nullopt; } -int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks) { +int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks, + const AuxDataMap::ThreadLayout &threadLayout) { // Default partition is 0, other partitions are idx + 1 int thread = maybeGetPartitionIdx(op).value_or(-1) + 1; if (hooks->isTMAOp(op)) { - thread += TMA_THREAD_OFFSET; + assert(threadLayout.hasTMAThreads() && + "TMA thread class must exist when instrumenting a TMA op"); + thread += threadLayout.tmaThreadOffset; return thread; } if (isTensorCoreOp(op)) { - thread += TC_THREAD_OFFSET; + assert(threadLayout.hasTCThreads() && + "TC thread class must exist when instrumenting a tensor-core op"); + thread += threadLayout.tcThreadOffset; return thread; } if (hooks->isCLCOp(op)) { - thread += CLC_THREAD_OFFSET; + assert(threadLayout.hasCLCThreads() && + "CLC thread class must exist when instrumenting a CLC op"); + thread += threadLayout.clcThreadOffset; return thread; } return thread; } -int getBaseThread(int thread) { return thread % NUM_THREADS; } +int getBaseThread(int thread, const AuxDataMap::ThreadLayout &threadLayout) { + return thread % threadLayout.numBaseThreads; +} // Peer threads are the equivalent threads in the TMA, TC, CLC and normal // thread classes. // If a thread is a base thread, return the mask with the peers, otherwise // return the mask with the thread itself. -uint64_t getThreadPeersMask(int thread) { +uint64_t getThreadPeersMask(int thread, + const AuxDataMap::ThreadLayout &threadLayout) { uint64_t mask = 1ULL << thread; - if (thread < NUM_THREADS) { - mask |= 1ULL << (thread + TMA_THREAD_OFFSET); - mask |= 1ULL << (thread + TC_THREAD_OFFSET); - mask |= 1ULL << (thread + CLC_THREAD_OFFSET); + if (thread < threadLayout.numBaseThreads) { + if (threadLayout.hasTMAThreads()) + mask |= 1ULL << (thread + threadLayout.tmaThreadOffset); + if (threadLayout.hasTCThreads()) + mask |= 1ULL << (thread + threadLayout.tcThreadOffset); + if (threadLayout.hasCLCThreads()) + mask |= 1ULL << (thread + threadLayout.clcThreadOffset); } return mask; } @@ -232,9 +245,7 @@ SmallVector getTensorCoreBarrierBroadcastMasks(Operation *op) { return ttng::getCTABroadcastMasks(twoCTAs, commitDescs); } -Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op); - -Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { +Value getMemEffectCTAs(ImplicitLocOpBuilder &b, Operation *op) { if (auto tmaLoad = dyn_cast(op)) { if (tmaLoad.getMulticast()) return getMulticastRecipientCTAs(b, tmaLoad.getResult()); @@ -242,7 +253,7 @@ Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { } if (isa(op)) return allCTAsMask(b); - if (isTensorCoreOp(op)) + if (isa(op)) return getRecipientCTAsForBroadcastMasks( b, ttng::getCTABroadcastMasks(ttng::getModuleTwoCTAs(op), {})); return currentCTAMask(b); @@ -296,8 +307,8 @@ class ConcurrencySanitizerImpl { CriticalSectionListener listener; b.setListener(&listener); - int thread = getCurrentThread(op, hooks); - int baseThread = getBaseThread(thread); + int thread = getCurrentThread(op, hooks, auxData.threadLayout); + int baseThread = getBaseThread(thread, auxData.threadLayout); b.setLoc(op->getLoc()); b.setInsertionPoint(op); if (isa(op)) { @@ -326,7 +337,8 @@ class ConcurrencySanitizerImpl { if (!partitionRegions.empty()) { uint64_t destMask = 0; for (Region *region : partitionRegions) - destMask |= getThreadPeersMask(region->getRegionNumber() + 1); + destMask |= getThreadPeersMask(region->getRegionNumber() + 1, + auxData.threadLayout); if (destMask) { for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { funcBuilder.createCopyWriteVisibilityCall(b, thread, destMask, @@ -364,28 +376,40 @@ class ConcurrencySanitizerImpl { } if (auto asyncWaitOp = dyn_cast(op)) { funcBuilder.createClearOutstandingCommitsTransferWritesCall( - b, baseThread, getThreadPeersMask(thread), asyncWaitOp.getNum(), - nullptr, CommitKind::AsyncCp, MemType::SHARED_MEM, op); + b, baseThread, getThreadPeersMask(thread, auxData.threadLayout), + asyncWaitOp.getNum(), nullptr, CommitKind::AsyncCp, + MemType::SHARED_MEM, op); } if (auto wgmmaWaitOp = dyn_cast(op)) { funcBuilder.createClearOutstandingCommitsTransferReadsCall( - b, baseThread, getThreadPeersMask(thread), + b, baseThread, getThreadPeersMask(thread, auxData.threadLayout), wgmmaWaitOp.getPendings(), nullptr, CommitKind::Wgmma, MemType::SHARED_MEM, op); } if (auto info = hooks->getWaitOpInfo(op)) { if (info->transferWrites && info->transferReads) { funcBuilder.createClearOutstandingCommitsTransferBothCall( - b, baseThread, getThreadPeersMask(thread), info->pendingCount, - nullptr, info->commitKind, MemType::SHARED_MEM, op); + b, baseThread, getThreadPeersMask(thread, auxData.threadLayout), + info->pendingCount, nullptr, info->commitKind, + MemType::SHARED_MEM, op); } else if (info->transferWrites) { funcBuilder.createClearOutstandingCommitsTransferWritesCall( - b, baseThread, getThreadPeersMask(thread), info->pendingCount, - nullptr, info->commitKind, MemType::SHARED_MEM, op); + b, baseThread, getThreadPeersMask(thread, auxData.threadLayout), + info->pendingCount, nullptr, info->commitKind, + MemType::SHARED_MEM, op); } else if (info->transferReads) { funcBuilder.createClearOutstandingCommitsTransferReadsCall( - b, baseThread, getThreadPeersMask(thread), info->pendingCount, - nullptr, info->commitKind, MemType::SHARED_MEM, op); + b, baseThread, getThreadPeersMask(thread, auxData.threadLayout), + info->pendingCount, nullptr, info->commitKind, + MemType::SHARED_MEM, op); + } + } + if (auto clusterBarrier = dyn_cast(op)) { + if (!clusterBarrier.getRelaxed()) { + b.setInsertionPointAfter(op); + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) + funcBuilder.createPublishClusterVisibilityCall(b, nullptr, memType, + op); } } @@ -434,9 +458,11 @@ class ConcurrencySanitizerImpl { tti::ExperimentalLockAcquireOp::create(wb, lock, pred); for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { funcBuilder.createTransferVisibleWritesCall( - wb, alloc, getThreadPeersMask(thread), pred, memType, op); + wb, alloc, getThreadPeersMask(thread, auxData.threadLayout), pred, + memType, op); funcBuilder.createTransferVisibleReadsCall( - wb, alloc, getThreadPeersMask(thread), pred, memType, op); + wb, alloc, getThreadPeersMask(thread, auxData.threadLayout), pred, + memType, op); } funcBuilder.createClearWaitingCall(wb, alloc, baseThread, pred, op); tti::ExperimentalLockReleaseOp::create(wb, lock, pred); @@ -444,7 +470,7 @@ class ConcurrencySanitizerImpl { void instrumentMemEffects(ImplicitLocOpBuilder &b, Operation *op, int thread, tti::FunctionBuilder &funcBuilder) { - int baseThread = getBaseThread(thread); + int baseThread = getBaseThread(thread, auxData.threadLayout); std::optional opInfo = hooks->getMemEffectsOpInfo(op); if (!opInfo) { return; @@ -452,7 +478,7 @@ class ConcurrencySanitizerImpl { Value pred = opInfo->pred; Value issuerCTAPred = hooks->getIssuerCTAPred(b, op); pred = tti::maybeAnd(b, pred, issuerCTAPred); - Value effectRecipientCTAs = getMemEffectRecipientCTAs(b, op); + Value effectCTAs = getMemEffectCTAs(b, op); for (auto effect : opInfo->operandEffects) { Value buf = effect.buf; auto bufType = cast(buf.getType()); @@ -464,12 +490,13 @@ class ConcurrencySanitizerImpl { // For op that is reading, we only need to check if anything else // is writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, - thread, effect.operandName, effectRecipientCTAs, + thread, effect.operandName, effectCTAs, /*allowNoWrite=*/false, opInfo->commitKind); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { funcBuilder.createSetReadVisibilityCall( - b, buf, effect.length, getThreadPeersMask(thread), pred, memType, - op, effectRecipientCTAs); + b, buf, effect.length, + getThreadPeersMask(thread, auxData.threadLayout), pred, memType, + op, effectCTAs); } if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::CommitCount) { @@ -483,21 +510,22 @@ class ConcurrencySanitizerImpl { // Op is writing to the buffer, we need to check if anything else // is reading or writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, - thread, effect.operandName, effectRecipientCTAs, + thread, effect.operandName, effectCTAs, /*allowNoWrite=*/true, opInfo->commitKind); addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType, - thread, effect.operandName, effectRecipientCTAs, + thread, effect.operandName, effectCTAs, opInfo->commitKind); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { funcBuilder.createSetWriteVisibilityCall( - b, buf, effect.length, getThreadPeersMask(thread), pred, memType, - op, effectRecipientCTAs); - funcBuilder.createClearWriteTrackingCall( - b, buf, effect.length, pred, memType, op, effectRecipientCTAs); - funcBuilder.createClearReadVisibilityCall( - b, buf, effect.length, pred, memType, op, effectRecipientCTAs); - funcBuilder.createClearReadTrackingCall( - b, buf, effect.length, pred, memType, op, effectRecipientCTAs); + b, buf, effect.length, + getThreadPeersMask(thread, auxData.threadLayout), pred, memType, + op, effectCTAs); + funcBuilder.createClearWriteTrackingCall(b, buf, effect.length, pred, + memType, op, effectCTAs); + funcBuilder.createClearReadVisibilityCall(b, buf, effect.length, pred, + memType, op, effectCTAs); + funcBuilder.createClearReadTrackingCall(b, buf, effect.length, pred, + memType, op, effectCTAs); } if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::CommitCount) { @@ -535,8 +563,7 @@ class ConcurrencySanitizerImpl { memType = MemType::SHARED_MEM; funcBuilder.createTrackBarrierWriteForBufferCall( b, barrier, effect.buf, effect.length, combinedPred, memType, op, - recipientCTAs, effectRecipientCTAs, - barrierInfo.diagonalEffectRecipientCTAs); + recipientCTAs, effectCTAs); } } if (barrierInfo.count > 0 || barrierInfo.txCount != 0) { @@ -560,11 +587,11 @@ class ConcurrencySanitizerImpl { tti::FunctionBuilder &funcBuilder, Operation *op, Value buf, uint32_t length, Value pred, MemType memType, int thread, const std::string &operandName, - Value recipientCTAs, bool allowNoWrite, + Value effectCTAs, bool allowNoWrite, CommitKind::Kind opCommitKind = CommitKind::None) { funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread, operandName, pred, memType, op, - recipientCTAs, allowNoWrite); + effectCTAs, allowNoWrite); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { for (const auto &commitKindDesc : @@ -572,8 +599,9 @@ class ConcurrencySanitizerImpl { bool excludeSelf = (opCommitKind == commitKindDesc.kind && hooks->isOrderedCommitKind(opCommitKind)); funcBuilder.createCheckOutstandingCommitsCall( - b, buf, length, getBaseThread(thread), commitKindDesc.operationDesc, - pred, memType, commitKindDesc.kind, op, recipientCTAs, excludeSelf); + b, buf, length, getBaseThread(thread, auxData.threadLayout), + commitKindDesc.operationDesc, pred, memType, commitKindDesc.kind, + op, effectCTAs, excludeSelf); } } } @@ -581,10 +609,10 @@ class ConcurrencySanitizerImpl { void addReadChecks(ImplicitLocOpBuilder &b, tti::FunctionBuilder &funcBuilder, Operation *op, Value buf, uint32_t length, Value pred, MemType memType, int thread, - const std::string &operandName, Value recipientCTAs, + const std::string &operandName, Value effectCTAs, CommitKind::Kind opCommitKind = CommitKind::None) { funcBuilder.createVerifyReadVisibilityCall( - b, buf, length, thread, operandName, pred, memType, op, recipientCTAs); + b, buf, length, thread, operandName, pred, memType, op, effectCTAs); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { for (const auto &commitKindDesc : @@ -592,8 +620,9 @@ class ConcurrencySanitizerImpl { bool excludeSelf = (opCommitKind == commitKindDesc.kind && hooks->isOrderedCommitKind(opCommitKind)); funcBuilder.createCheckOutstandingCommitsCall( - b, buf, length, getBaseThread(thread), commitKindDesc.operationDesc, - pred, memType, commitKindDesc.kind, op, recipientCTAs, excludeSelf); + b, buf, length, getBaseThread(thread, auxData.threadLayout), + commitKindDesc.operationDesc, pred, memType, commitKindDesc.kind, + op, effectCTAs, excludeSelf); } } } diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index dab399c463fa..c4df44267fcd 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -232,8 +232,6 @@ class TmemScratchManager { auto ptrTy = triton::getPointerType(storageElemTy); auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy, sizeInBytes, alignment); - allocOp->setDiscardableAttr("tt.divisibility", - rewriter.getI64IntegerAttr(alignment)); Value ptr = allocOp.getResult(); if (Value init = alloc.getSrc()) { @@ -383,8 +381,6 @@ Value createScratchAndStore(PatternRewriter &rewriter, Location loc, Value val, auto ptrTy = triton::getPointerType(storageTy.getElementType()); auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy, sizeInBytes, alignment); - allocOp->setDiscardableAttr("tt.divisibility", - rewriter.getI64IntegerAttr(alignment)); if (!storeFpSanScratchMemory(rewriter, loc, allocOp.getResult(), val, tensorTy)) return Value(); @@ -837,8 +833,6 @@ createOperandScratch(PatternRewriter &rewriter, Location loc, getScratchStorageElementType(memTy.getElementType())); auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy, sizeInBytes, alignment); - allocOp->setDiscardableAttr("tt.divisibility", - rewriter.getI64IntegerAttr(alignment)); Value ptr = allocOp.getResult(); if (!storeFpSanScratchMemory(rewriter, loc, ptr, fullVal, tensorTy)) return std::nullopt; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index eaf922354b38..e8c6080fd710 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -237,8 +237,7 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { {tryCancelOp.getMbarrier(), nullptr, /*count=*/0, MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, /*txCount=*/ - -static_cast(tti::getMemDescLength(tryCancelOp.getResult())), - /*diagonalEffectRecipientCTAs=*/true}); + -static_cast(tti::getMemDescLength(tryCancelOp.getResult()))}); info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, tryCancelOp.getResult()); } diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 79007665a434..4edd97a5563c 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -281,6 +281,48 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") +def test_collapsed_wait_does_not_publish_peer_cta(device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_collapsed_wait_does_not_publish_peer_cta, (device, False, monkeypatch)) + assert_expected_cuda_failure(result.exc) + assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(a_desc, b_desc): + blocked_a: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], ((1, 0), )) + smem_a = ttgl.allocate_shared_memory(ttgl.float16, [256, 128], a_desc.layout) + smem_b = ttgl.allocate_shared_memory( + ttgl.float16, [128, 128], + ttgl.NVMMASharedLayout.get_default_for([128, 128], ttgl.float16, cga_layout=((0, 1), ))) + tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) + mbarrier.init(tma_bar, count=1) + mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) + tma.async_load(a_desc, [0, 0], tma_bar, smem_a) + tma.async_load(b_desc, [0, 0], tma_bar, smem_b) + mbarrier.wait(tma_bar, 0, deps=[smem_a, smem_b]) + val = smem_a.load(blocked_a) + smem_a.store(val) + + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([128, 128], col_stride=1, cga_layout=((1, 0), ), + two_ctas=True) + acc = blackwell.allocate_tensor_memory(ttgl.float32, [256, 128], acc_layout) + blackwell.tcgen05_mma(smem_a, smem_b, acc, use_acc=False) + + a = torch.randn((256, 128), device=device, dtype=torch.float16) + b = torch.randn((128, 128), device=device, dtype=torch.float16) + a_layout = ttgl.NVMMASharedLayout.get_default_for([256, 128], ttgl.float16, cga_layout=((1, 0), )) + b_layout = ttgl.NVMMASharedLayout.get_default_for([128, 128], ttgl.float16, cga_layout=((0, 1), )) + a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(a, [256, 128], a_layout) + b_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(b, [128, 128], b_layout) + kernel[(1, )](a_desc, b_desc, num_warps=4, num_ctas=2) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") @pytest.mark.parametrize("FAILURE", [True, False]) def test_clc_result_visibility(FAILURE, device, run_wrapper, monkeypatch, num_ctas): @@ -288,7 +330,8 @@ def test_clc_result_visibility(FAILURE, device, run_wrapper, monkeypatch, num_ct result = run_in_process(test_clc_result_visibility, (FAILURE, device, False, monkeypatch, num_ctas)) if FAILURE: assert_expected_cuda_failure(result.exc) - assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + assert ("Buffer being accessed has outstanding writes" in result.driver_stderr_output + or "Buffer being read before any write" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" @@ -319,6 +362,74 @@ def kernel(out, FAILURE: ttgl.constexpr): kernel[(1, )](output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") +def test_clc_double_try_cancel_result_overwrite(device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_clc_double_try_cancel_result_overwrite, (device, False, monkeypatch)) + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(): + cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1) + layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) + result = ttgl.allocate_shared_memory(ttgl.int64, [2], layout) + bars = mbarrier.allocate_mbarrier(batch=2) + mbarrier.init(bars.index(0), count=1) + mbarrier.init(bars.index(1), count=1) + + mbarrier.expect(bars.index(0), 16) + clc.try_cancel(result, bars.index(0)) + mbarrier.expect(bars.index(1), 16) + clc.try_cancel(result, bars.index(1)) + + mbarrier.wait(bars.index(0), 0) + mbarrier.wait(bars.index(1), 0) + + kernel[(1, )](num_warps=4, num_ctas=2) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell") +def test_clc_result_reuse_after_cluster_barrier(device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_clc_result_reuse_after_cluster_barrier, (device, False, monkeypatch)) + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(out): + cga_layout: ttgl.constexpr = [[0]] + layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) + clc_result = ttgl.allocate_shared_memory(ttgl.int64, [2], layout) + clc_bar = mbarrier.allocate_mbarrier() + mbarrier.init(clc_bar, count=1) + + mbarrier.expect(clc_bar, 16) + clc.try_cancel(clc_result, clc_bar) + mbarrier.wait(clc_bar, 0) + first = clc.load_result(clc_result) + ttgl.barrier(cluster=True) + + mbarrier.expect(clc_bar, 16) + clc.try_cancel(clc_result, clc_bar) + mbarrier.wait(clc_bar, 1) + second = clc.load_result(clc_result) + ttgl.store(out + ttgl.program_id(0), first.is_canceled() | second.is_canceled()) + + output = torch.empty((1, ), device=device, dtype=torch.bool) + kernel[(1, )](output, num_warps=4, num_ctas=2) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") def test_async_tma_multicast_kernel_reuse(device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: @@ -620,6 +731,67 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") +@pytest.mark.parametrize("WAIT_LATEST", [True, False]) +def test_tma_wait_does_not_publish_overwritten_row(WAIT_LATEST, device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_tma_wait_does_not_publish_overwritten_row, + (WAIT_LATEST, device, False, monkeypatch, num_ctas)) + if WAIT_LATEST: + assert result.exc is None + assert result.driver_stderr_output == "" + else: + assert_expected_cuda_failure(result.exc) + assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(input_desc, out, WAIT_LATEST: ttgl.constexpr): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) + smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + bar = mbarrier.allocate_mbarrier(batch=2) + mbarrier.init(bar.index(0), count=1) + mbarrier.init(bar.index(1), count=1) + + mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta) + tma.async_load(input_desc, [0, 0], bar.index(0), smem) + mbarrier.expect(bar.index(1), input_desc.nbytes_per_cta) + tma.async_load(input_desc, [0, 0], bar.index(1), smem) + + if WAIT_LATEST: + mbarrier.wait(bar.index(1), 0) + else: + mbarrier.wait(bar.index(0), 0) + + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + val = smem.load(blocked_layout) + out_m = ttgl.arange(0, block_m, ttgl.SliceLayout(1, blocked_layout))[:, None] + out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :] + ttgl.store(out + out_m * XBLOCK + out_n, val) + + if WAIT_LATEST: + mbarrier.wait(bar.index(0), 0) + else: + mbarrier.wait(bar.index(1), 0) + + mbarrier.invalidate(bar.index(0)) + mbarrier.invalidate(bar.index(1)) + + block_m = XBLOCK.value * num_ctas + input = torch.randn((block_m, XBLOCK.value), device=device, dtype=torch.float16) + output = torch.empty_like(input) + shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=default_cga_layout(num_ctas, 2)) + input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout) + kernel[(1, )](input_desc, output, WAIT_LATEST=WAIT_LATEST, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_async_copy(FAILURE, device, run_wrapper, monkeypatch, num_ctas): @@ -727,7 +899,7 @@ def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, TWO_CTAS, device, run_wrapper, mon elif MEM_ACCESS_KIND == "tmem_load": # tmem is being written by the tcgen05_mma assert ("Buffer being accessed has outstanding writes" in result.driver_stderr_output - or "Buffer being read before any write. Operand: B" in result.driver_stderr_output) + or "Buffer being read before any write." in result.driver_stderr_output) elif MEM_ACCESS_KIND == "tmem_store": # tmem is being written by the tcgen05_mma assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output @@ -1201,7 +1373,8 @@ def test_tma_tcgen05_mma_multicast_loop(FAILURE, device, run_wrapper, monkeypatc result = run_in_process(test_tma_tcgen05_mma_multicast_loop, (FAILURE, device, False, monkeypatch, num_ctas)) if FAILURE: assert_expected_cuda_failure(result.exc) - assert "Buffer being accessed has outstanding" in result.driver_stderr_output + assert ("Buffer being accessed has outstanding" in result.driver_stderr_output + or "Buffer being read before any write" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" @@ -1270,6 +1443,133 @@ def kernel(a_desc, b_desc, FAILURE: ttgl.constexpr): kernel[(1, )](a_desc, b_desc, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer") +def test_tma_tcgen05_mma_missing_multicast(device, run_wrapper, monkeypatch, num_ctas): + if num_ctas != 4: + pytest.skip("Need 4 CTAs to exercise the missing tcgen05_mma multicast race") + if run_wrapper: + result = run_in_process(test_tma_tcgen05_mma_missing_multicast, (device, False, monkeypatch, num_ctas)) + assert_expected_cuda_failure(result.exc) + assert "Buffer being accessed has outstanding" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(a_desc, b_desc): + num_k_tiles: ttgl.constexpr = 4 + block_m: ttgl.constexpr = mma_block_m(ttgl.num_ctas()) + block_n: ttgl.constexpr = mma_block_n(ttgl.num_ctas()) + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout( + [XBLOCK, XBLOCK], + col_stride=1, + cga_layout=mma_cga_layout(ttgl.num_ctas(), 2, True), + two_ctas=True, + ) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], a_desc.layout) + smemB = ttgl.allocate_shared_memory( + ttgl.float16, + [XBLOCK, block_n], + ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, + cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, True)), + ) + acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) + tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) + mbarrier.init(tma_bar, count=1) + mma_bar = mbarrier.allocate_mbarrier() + mbarrier.init(mma_bar, count=blackwell.tcgen05_mma_barrier_count([smemA, smemB], False, + acc.type.layout.two_ctas)) + + phase_tma = 0 + phase_mma = 0 + for k in range(num_k_tiles): + offs_k = k * XBLOCK + mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) + tma.async_load(a_desc, [0, offs_k], tma_bar, smemA, multicast=True) + tma.async_load(b_desc, [offs_k, 0], tma_bar, smemB, multicast=True) + mbarrier.wait(tma_bar, phase_tma, deps=[smemA, smemB]) + + # Missing multicast=True is the bug under test. The next iteration + # reuses smemA/smemB after a local completion wait. + blackwell.tcgen05_mma(smemA, smemB, acc, use_acc=k != 0, mbarriers=[mma_bar]) + mbarrier.wait(mma_bar, phase_mma, deps=[smemA, smemB]) + phase_tma = (phase_tma + 1) % 2 + phase_mma = (phase_mma + 1) % 2 + + mbarrier.invalidate(tma_bar) + mbarrier.invalidate(mma_bar) + + block_m = mma_block_m(num_ctas) + block_n = mma_block_n(num_ctas) + num_k_tiles = 4 + a = torch.randn((block_m, XBLOCK.value * num_k_tiles), device=device, dtype=torch.float16) + b = torch.randn((XBLOCK.value * num_k_tiles, block_n), device=device, dtype=torch.float16) + a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor( + a, [block_m, XBLOCK.value], + ttgl.NVMMASharedLayout.get_default_for([block_m, XBLOCK.value], ttgl.float16, + cga_layout=mma_cga_layout(num_ctas, 0, True))) + b_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor( + b, [XBLOCK.value, block_n], + ttgl.NVMMASharedLayout.get_default_for([XBLOCK.value, block_n], ttgl.float16, + cga_layout=mma_cga_layout(num_ctas, 1, True))) + kernel[(1, )](a_desc, b_desc, num_warps=4, num_ctas=num_ctas) + + +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer") +@pytest.mark.parametrize("OVERCOUNTED", [False, True]) +def test_tcgen5_commit_multicast_barrier_count(OVERCOUNTED, device, run_wrapper, monkeypatch): + if run_wrapper: + result = run_in_process(test_tcgen5_commit_multicast_barrier_count, (OVERCOUNTED, device, False, monkeypatch)) + if OVERCOUNTED: + assert_expected_cuda_failure(result.exc) + assert "Deadlock detected" in result.driver_stderr_output + else: + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(a_desc, b_desc, OVERCOUNTED: ttgl.constexpr): + block_m: ttgl.constexpr = 256 + block_n: ttgl.constexpr = 128 + smem_a = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], a_desc.layout) + smem_b = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], b_desc.layout) + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1, cga_layout=((1, 0), ), + two_ctas=True) + acc = allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) + + tma_bar = mbarrier.allocate_mbarrier(two_ctas=True) + commit_bar = mbarrier.allocate_mbarrier() + count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smem_a, smem_b], True, acc.type.layout.two_ctas) + mbarrier.init(tma_bar, count=1) + mbarrier.init(commit_bar, count=count + OVERCOUNTED) + mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta) + tma.async_load(a_desc, [0, 0], tma_bar, smem_a, multicast=True) + tma.async_load(b_desc, [0, 0], tma_bar, smem_b, multicast=True) + mbarrier.wait(tma_bar, 0, deps=[smem_a, smem_b]) + mbarrier.invalidate(tma_bar) + + blackwell.tcgen05_mma(smem_a, smem_b, acc, use_acc=False, multicast=True) + blackwell.tcgen05_commit(commit_bar, descs=[smem_a, smem_b]) + mbarrier.wait(commit_bar, 0) + + a = torch.randn((256, XBLOCK.value), device=device, dtype=torch.float16) + b = torch.randn((XBLOCK.value, 128), device=device, dtype=torch.float16) + a_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor( + a, [256, XBLOCK.value], + ttgl.NVMMASharedLayout.get_default_for([256, XBLOCK.value], ttgl.float16, cga_layout=((1, 0), ))) + b_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor( + b, [XBLOCK.value, 128], + ttgl.NVMMASharedLayout.get_default_for([XBLOCK.value, 128], ttgl.float16, cga_layout=((0, 1), ))) + kernel[(1, )](a_desc, b_desc, OVERCOUNTED=OVERCOUNTED, num_warps=4, num_ctas=2) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper") @pytest.mark.parametrize("FAILURE", [True, False]) def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper, monkeypatch, num_ctas): @@ -2548,6 +2848,61 @@ def kernel(MISSING_BAR: ttgl.constexpr, OVERLAP: ttgl.constexpr): kernel[(1, )](MISSING_BAR=MISSING_BAR, OVERLAP=OVERLAP, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper") +def test_aliasing_tma_overwrite_clears_stale_write_visibility(device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_aliasing_tma_overwrite_clears_stale_write_visibility, + (device, False, monkeypatch, num_ctas)) + assert result.exc is None + assert result.driver_stderr_output == "" + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def writer(full: ttgl.constexpr, tail: ttgl.constexpr, input_desc, bar: ttgl.constexpr, + blocked_layout_wide: ttgl.constexpr): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + vals = ttgl.full([block_m, XBLOCK * 2], 42.0, ttgl.float16, blocked_layout_wide) + full.store(vals) + mbarrier.expect(bar.index(0), input_desc.nbytes_per_cta) + tma.async_load(input_desc, [0, 0], bar.index(0), tail) + + @gluon.jit + def reader(tail: ttgl.constexpr, dummy: ttgl.constexpr, bar: ttgl.constexpr, blocked_layout: ttgl.constexpr): + mbarrier.wait(bar.index(0), phase=0) + val = tail.load(blocked_layout) + dummy.store(val) + + @gluon.jit + def kernel(input_desc): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=cga_layout) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + blocked_layout_wide: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[2, XBLOCK], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, + 1], cga_layout=cga_layout) + full = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK * 2], smem_layout) + tail = full.slice(XBLOCK, XBLOCK, dim=1) + dummy = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout) + bar = mbarrier.allocate_mbarrier(batch=1) + mbarrier.init(bar.index(0), count=1) + ttgl.warp_specialize([(writer, (full, tail, input_desc, bar, blocked_layout_wide)), + (reader, (tail, dummy, bar, blocked_layout))], [4], [32]) + + block_m = XBLOCK.value * num_ctas + input = torch.randn((block_m, XBLOCK.value), device=device, dtype=torch.float16) + shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=default_cga_layout(num_ctas, 2)) + input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout) + kernel[(1, )](input_desc, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_aliasing_tensor_visibility_outstanding_read(FAILURE, device, run_wrapper, monkeypatch, num_ctas): diff --git a/test/Conversion/tritoninstrument_to_llvm.mlir b/test/Conversion/tritoninstrument_to_llvm.mlir index 50d5ba34aaf2..b8114d52d126 100644 --- a/test/Conversion/tritoninstrument_to_llvm.mlir +++ b/test/Conversion/tritoninstrument_to_llvm.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20 -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { // CHECK: global internal constant @tensor_constant_1([34359738368, 68719476736]) {addr_space = 0 : i32} : !llvm.array<2 x i64> @@ -9,14 +9,14 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { // CHECK-LABEL: @experimental_buffer_descriptors_tmem // CHECK: llvm.mlir.constant(4294967295 : i64) : i64 tt.func private @experimental_buffer_descriptors_tmem() { - tti.experimental_buffer_descriptors [0, 42], [8, 16], tensor_mem : tensor<1x2xi64, #blocked> + tti.experimental_buffer_descriptors [0, 42], [8, 16], tensor_mem : tensor<2xi64, #blocked> tt.return } } // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { // CHECK: global internal constant @tensor_constant_1([17179869184, 51539607552]) @@ -25,7 +25,7 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} { // CHECK-LABEL: @experimental_buffer_descriptors_shared // CHECK: llvm.mlir.constant(16777215 : i64) : i64 tt.func private @experimental_buffer_descriptors_shared() { - tti.experimental_buffer_descriptors [0, 42], [4, 12], shared_mem : tensor<1x2xi64, #blocked> + tti.experimental_buffer_descriptors [0, 42], [4, 12], shared_mem : tensor<2xi64, #blocked> tt.return } } diff --git a/test/TritonGPU/amd/amd-consan.mlir b/test/TritonGPU/amd/amd-consan.mlir index d1c90c566dee..a85b1709a149 100644 --- a/test/TritonGPU/amd/amd-consan.mlir +++ b/test/TritonGPU/amd/amd-consan.mlir @@ -7,19 +7,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @single_local_alloc tt.func public @single_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -37,19 +37,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @two_local_alloc tt.func public @two_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096], [{{.*}}], shared_mem : tensor<1x2xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096], [{{.*}}], shared_mem : tensor<2xi64 - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -69,19 +69,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_local_alloc tt.func public @three_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<1x4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2048 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %2 = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -103,19 +103,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_sub_bufs tt.func public @three_sub_bufs() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<1x4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2048 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -136,8 +136,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: #[[READ_BARS_L:.*]] = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [0, 1]}> // CHECK: @read_bars_alloc tt.func public @read_bars_alloc() { - // CHECK: %[[READ_BARS_G:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x4xI8(%[[READ_BARS_G]], %c0_i8 + // CHECK: %[[READ_BARS_G:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_BARS_G]], %c0_i8 %c0 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<4x1xi64, #shared1, #smem, mutable> @@ -157,9 +157,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local tt.func public @async_copy_global_to_local(%ptr: tensor<32x32x!tt.ptr, #blocked>) { - // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x16xI8(%[[WRT_COMMITS_GLOB]], %c0_i8 + // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRT_COMMITS_GLOB]], %c0_i8 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility_noalias_nw1{{.*}}(%[[A_I64]] @@ -187,13 +187,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local_with_barriers tt.func public @async_copy_global_to_local_with_barriers(%ptr: tensor<32x32x!tt.ptr, #blocked>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr - // CHECK-DAG: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr + // CHECK-DAG: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr // CHECK: tt.call @__triton_consan_init_barrier_state @@ -242,7 +242,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 1 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -280,7 +280,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_shared tt.func public @alias_matrix_shared() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<2xi64 // CHECK-DAG: arith.constant dense : tensor<2x2xi1 %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> %buf1 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -301,7 +301,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @alias_matrix_shared_indexed() { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [128, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [128, 128], shared_mem : tensor<2xi64 // CHECK-NOT: arith.constant dense<{{\[\[true, false\], \[false, true\]\]}}> : tensor<2x2xi1 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32xf32, #shared, #smem, mutable> %buf0 = ttg.memdesc_index %smem[%c0_i32] : !ttg.memdesc<2x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -321,7 +321,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_shared_subslice tt.func public @alias_matrix_shared_subslice() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [256, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [256, 128], shared_mem : tensor<2xi64 // CHECK-DAG: arith.constant dense : tensor<2x2xi1 %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64xf32, #shared, #smem, mutable> %buf1 = ttg.memdesc_subslice %buf0 [32] : !ttg.memdesc<64xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -361,21 +361,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @amdg_wait_barrier tt.func public @amdg_wait_barrier() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -405,8 +405,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @amdg_arrive_barrier tt.func public @amdg_arrive_barrier() { - // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI32(%[[BSTATE_GLOB]], %c0_i32 + // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[BSTATE_GLOB]], %c0_i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tt.call @__triton_consan_verify_barrier_can_init @@ -465,20 +465,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local tt.func public @async_tdm_copy_global_to_local(%desc: !tt.tensordesc<32x32xf32>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -699,7 +699,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @amdg_async_wait() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 1 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -823,18 +823,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_allocation tt.func public @ws_allocation() { - // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, - // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 2 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 - // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] + // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { // CHECK: tti.experimental_lock_acquire @@ -846,8 +845,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar } partition0(%arg1: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) num_warps(4) { // CHECK: partition0 - // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, - // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK: tti.experimental_lock_acquire // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: tt.call @__triton_consan_set_read_visibility @@ -905,11 +904,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar amdg.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 2 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 - // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] + // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { ttg.warp_yield diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 433966623e1d..76d4eafeceaa 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -10,19 +10,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK-DAG: #[[BUFS_BARS_L:.*]] = #ttg.linear<{register = [], lane = {{\[}}[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], warp = [], block = []}> // CHECK: @single_local_alloc tt.func public @single_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64, #{{.*}}> + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64, #{{.*}}> - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -41,15 +41,15 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tti.experimental_cluster_cta_id : i32 // CHECK-LABEL: @single_local_alloc_multi_cta tt.func public @single_local_alloc_multi_cta() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<2x1xi64 - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T2x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T2x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T2x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T2x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 64 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 64 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -67,19 +67,19 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @two_local_alloc tt.func public @two_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096], [{{.*}}], shared_mem : tensor<1x2xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096], [{{.*}}], shared_mem : tensor<2xi64, - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -99,19 +99,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_local_alloc tt.func public @three_local_alloc() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<1x4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2048 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %1 = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %2 = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -133,19 +133,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @three_sub_bufs tt.func public @three_sub_bufs() { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<1x4xi64, + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0, 4096, 8192, 0], [{{.*}}], shared_mem : tensor<4xi64, - // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2048 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x4x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<3x32x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -166,8 +166,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: #[[READ_BARS_L:.*]] = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [0, 1]}> // CHECK: @read_bars_alloc tt.func public @read_bars_alloc() { - // CHECK: %[[READ_BARS_G:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x4xI8(%[[READ_BARS_G]], %c0_i8 + // CHECK: %[[READ_BARS_G:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_BARS_G]], %c0_i8 %c0 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 8192 : i32} : () -> !ttg.memdesc<4x1xi64, #shared1, #smem, mutable> @@ -189,8 +189,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: #[[BUFS_L:.*]] = #ttg.linear<{register = [], lane = {{\[}}[0], [0], [0], [0], [0]], warp = [], block = []}> // CHECK: @tmem_alloc tt.func public @tmem_alloc() { - // CHECK-DAG: %[[TMEM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1x1xi64, #{{.*}}> - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [4096], [{{.*}}], shared_mem : tensor<1x1xi64, #{{.*}}> + // CHECK-DAG: %[[TMEM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1xi64, #{{.*}}> + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [4096], [{{.*}}], shared_mem : tensor<1xi64, #{{.*}}> %0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -207,20 +207,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_global_to_local tt.func public @async_tma_copy_global_to_local(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -265,9 +265,9 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #barrier, #smem, mutable> // CHECK: tti.experimental_cluster_cta_id // CHECK: arith.cmpi eq - // CHECK: %[[CLC_THREAD:.*]] = arith.constant 48 : i32 + // CHECK: %[[CLC_THREAD:.*]] = arith.constant 1 : i32 // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[CLC_THREAD]] - // CHECK: %[[CLC_MASK:.*]] = arith.constant 281474976710656 : i64 + // CHECK: %[[CLC_MASK:.*]] = arith.constant 2 : i64 // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[CLC_MASK]] // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer{{.*}}_I1 ttng.clc_try_cancel %result, %bar : !ttg.memdesc<2xi64, #shared_clc, #smem, mutable>, !ttg.memdesc<1xi64, #barrier, #smem, mutable> @@ -287,10 +287,10 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #smem = #ttg.shared_memory #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CGALayout = [[0, 0]]}> module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { - // CHECK-LABEL: tt.func private @__triton_consan_check_outstanding_commits_noalias{{.*}}T2x1x16xI8 + // CHECK-LABEL: tt.func private @__triton_consan_check_outstanding_commits_noalias{{.*}}T2x1x1xI8 // CHECK-SAME: %arg6: i32 // CHECK-NOT: tti.experimental_cluster_cta_id - // CHECK: tt.splat %arg6 : i32 -> tensor<2x1x16xi32 + // CHECK: tt.splat %arg6 : i32 -> tensor<2x1x1xi32 // CHECK: arith.shrui // CHECK-LABEL: @outstanding_commits_multicast_tma_recipients tt.func public @outstanding_commits_multicast_tma_recipients( @@ -610,21 +610,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @wait_barrier tt.func public @wait_barrier(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64, #linear{{[0-9]*}}> + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64, #linear{{[0-9]*}}> - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x64xI64(%[[READ_VISIBILITY_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_VISIBILITY_GLOB]], %c0_i64 - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, #linear{{[0-9]*}}> + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, #linear{{[0-9]*}}> - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI8(%[[WRITE_TRACKING_GLOB]], %c0_i8 + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRITE_TRACKING_GLOB]], %c0_i8 - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x1xI64(%[[READ_TRACKING_GLOB]], %c0_i64 + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[READ_TRACKING_GLOB]], %c0_i64 %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -655,8 +655,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @arrive_barrier tt.func public @arrive_barrier(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[BSTATE_GLOB]], %c0_i64 + // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[BSTATE_GLOB]], %c0_i64 %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -720,38 +720,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma tt.func public @tcgen5_mma(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<1x2xi64 - // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1x1xi64 - // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64 - - // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 + // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], tensor_mem : tensor<1xi64 + // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 + + // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + + // CHECK: ttng.init_barrier + // CHECK: arith.shli + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B]] : // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : @@ -760,16 +762,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{.*}}, %[[TM_BUFS]], %{{[^,]+}} - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: tt.call @__triton_consan_verify_barrier_initialized + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], %[[SM_WRITE_TRACKING_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %{{[^,]+}} - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][], {{.*}}, {{.*}}, %[[BAR]] @@ -795,38 +798,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @tcgen5_mma_lhs_in_tmem tt.func public @tcgen5_mma_lhs_in_tmem(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [32768], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 128], [{{.*}}], tensor_mem : tensor<1x2xi64 - // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1024 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64 - - // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [32768], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[SM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[SM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 128], [{{.*}}], tensor_mem : tensor<2xi64 + // CHECK-DAG: %[[TM_WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[BARRIERS:.*]] = tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64 + + // CHECK-DAG: %[[SM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[SM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[TM_READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + + // CHECK: ttng.init_barrier + // CHECK: arith.shli + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A]] : // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[A_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[SM_BUFS]], %[[SM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[B_I64:.*]] = tti.experimental_memdesc_to_i32 %[[B]] : // CHECK: tt.call @__triton_consan_set_read_visibility{{.*}}%[[B_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[SM_BUFS]], %[[SM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_verify_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_BIT]], %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] - // CHECK: %[[TC_MASK:.*]] = arith.constant 4294967296 : i64 + // CHECK: %[[TC_MASK:.*]] = arith.constant 2 : i64 // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_set_write_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TC_MASK]], %[[TM_BUFS]], %[[TM_WRITE_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : @@ -835,16 +840,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{.*}}, %[[TM_BUFS]], %{{[^,]+}} - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: tt.call @__triton_consan_verify_barrier_initialized + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], %[[SM_WRITE_TRACKING_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %{{[^,]+}} - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] - // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 + // CHECK: %[[TC_BIT:.*]] = arith.constant 1 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: tt.call @__triton_consan_verify_barrier_arrive @@ -919,9 +925,9 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local tt.func public @async_copy_global_to_local(%ptr: tensor<128x128x!tt.ptr, #blocked>) { - // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1x16xI8(%[[WRT_COMMITS_GLOB]], %c0_i8 + // CHECK: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[WRT_COMMITS_GLOB]], %c0_i8 // CHECK: %[[A_I64:.*]] = tti.experimental_memdesc_to_i32 %[[A:.*]] : // CHECK: tt.call @__triton_consan_verify_write_visibility_noalias_nw1{{.*}}(%[[A_I64]] @@ -950,13 +956,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_copy_global_to_local_with_barriers tt.func public @async_copy_global_to_local_with_barriers(%ptr: tensor<128x128x!tt.ptr, #blocked>) { - // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 - // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 512 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr + // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 + // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[READ_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[WRITE_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK-DAG: %[[READ_TRACKING_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr - // CHECK-DAG: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 16 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr + // CHECK-DAG: %[[WRT_COMMITS_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 1 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr // CHECK: tt.call @__triton_consan_init_barrier_state @@ -1030,7 +1036,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @async_commit_group() { // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 281479271743489 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 1 : i64 // CHECK: %[[OUTSTANDING_NUM:.*]] = arith.constant 42 : i32 // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_writes{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]], %[[OUTSTANDING_NUM]] %shmem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -1067,10 +1073,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @warp_group_dot tt.func public @warp_group_dot(%acc: tensor<128x128xf16, #mma>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<1x2xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 - // CHECK-DAG: %[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x16xI8(%[[SM_WGMMA_WRITES_GLOB]], %c0_i8 + // CHECK-DAG: %[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[SM_WGMMA_WRITES_GLOB]], %c0_i8 // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 @@ -1098,10 +1104,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @warp_group_dot_sync tt.func public @warp_group_dot_sync(%acc: tensor<128x128xf16, #mma>) { - // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<1x2xi64 + // CHECK-DAG: %[[SM_BUFS:.*]] = tti.experimental_buffer_descriptors [0, 32768], [{{.*}}], shared_mem : tensor<2xi64 - // CHECK-DAG: %[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 32 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x2x16xI8(%[[SM_WGMMA_WRITES_GLOB]], %c0_i8 + // CHECK-DAG: %[[SM_WGMMA_WRITES_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 2 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}(%[[SM_WGMMA_WRITES_GLOB]], %c0_i8 // CHECK: "before_dot" // CHECK-NOT: tt.call @__triton_consan_stage_access_for_commit @@ -1274,18 +1280,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @ws_allocation tt.func public @ws_allocation(%arg0: !tt.tensordesc<32x32xf32, #shared>) { - // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, - // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 2 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 - // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] + // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { // CHECK: tti.experimental_lock_acquire @@ -1297,8 +1302,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar } partition0(%arg1: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<1xi64, #shared1, #smem, mutable>) num_warps(4) { // CHECK: partition0 - // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1x1xi64, - // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1x1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [65536], [{{.*}}], shared_mem : tensor<1xi64, + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK: tti.experimental_lock_acquire // CHECK: tt.call @__triton_consan_verify_write_visibility // CHECK: tt.call @__triton_consan_set_read_visibility @@ -1359,11 +1364,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tti.experimental_lock_acquire // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 + // CHECK: %[[THREAD_MASK:.*]] = arith.constant 2 : i64 // CHECK: tt.call @__triton_consan_copy_write_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] // CHECK: %[[THREAD_BIT:.*]] = arith.constant 0 : i32 - // CHECK: %[[THREAD_MASK:.*]] = arith.constant 562958543486978 : i64 - // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]], %[[THREAD_MASK]] + // CHECK: tt.call @__triton_consan_copy_read_visibility{{.*}}(%[[THREAD_BIT]] ttg.warp_specialize(%smem, %bar) attributes {actualRegisters = array, allocation.offset = 512 : i32, requestedRegisters = array, warpGroupStartIds = array} default { ttg.warp_yield @@ -1429,7 +1433,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_shared tt.func public @alias_matrix_shared() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<2xi64 // CHECK-DAG: arith.constant dense : tensor<2x2xi1 %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> %buf1 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -1452,7 +1456,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar tt.func public @alias_matrix_shared_indexed() { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [128, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [128, 128], shared_mem : tensor<2xi64 // CHECK-NOT: arith.constant dense<{{\[\[true, false\], \[false, true\]\]}}> : tensor<2x2xi1 %smem = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<2x32xf32, #shared, #smem, mutable> %buf0 = ttg.memdesc_index %smem[%c0_i32] : !ttg.memdesc<2x32xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -1474,7 +1478,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_shared_subslice tt.func public @alias_matrix_shared_subslice() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [256, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 128], [256, 128], shared_mem : tensor<2xi64 // CHECK-DAG: arith.constant dense : tensor<2x2xi1 %buf0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64xf32, #shared, #smem, mutable> %buf1 = ttg.memdesc_subslice %buf0 [32] : !ttg.memdesc<64xf32, #shared, #smem, mutable> -> !ttg.memdesc<32xf32, #shared, #smem, mutable> @@ -1494,7 +1498,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_tensor tt.func public @alias_matrix_tensor() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32, 64, 0], [64, 32, 64, 0], tensor_mem : tensor<1x4xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 32, 64, 0], [64, 32, 64, 0], tensor_mem : tensor<4xi64 // CHECK-DAG: arith.constant dense<{{\[\[true, true, false, false\], \[true, true, false, false\], \[false, false, true, false\], \[false, false, false, false\]\]}}> : tensor<4x4xi1 %buf0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> %buf1 = ttng.tmem_alloc {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> @@ -1515,9 +1519,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @alias_matrix_mixed tt.func public @alias_matrix_mixed() { - // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<1x2xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0, 16], [128, 128], shared_mem : tensor<2xi64 // CHECK-DAG: arith.constant dense : tensor<2x2xi1 - // CHECK-DAG: tti.experimental_buffer_descriptors [0], [64], tensor_mem : tensor<1x1xi64 + // CHECK-DAG: tti.experimental_buffer_descriptors [0], [64], tensor_mem : tensor<1xi64 %smem0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> %smem1 = ttg.local_alloc {allocation.offset = 16 : i32} : () -> !ttg.memdesc<32xf32, #shared, #smem, mutable> %tmem0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable> diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 2f6531784942..2694bff43645 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -514,9 +514,11 @@ def make_cubin(self, src, metadata, opt, capability): # Accept more ptxas options if provided ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else [] - # Use -Ofc mid to compile ConSan code, if nothing else is specified. - if any(mode in knobs.compilation.instrumentation_mode for mode in ["consan", "fpsan"]): - ptx_extra_options += ["-Ofc", "mid"] + # -Ofc mid miscompiles some large ConSan kernels into invalid global + # accesses; -O1 keeps compile time reasonable without that ptxas bug. + if (not knobs.nvidia.disable_ptxas_opt + and any(mode in knobs.compilation.instrumentation_mode for mode in ["consan", "fpsan"])): + ptx_extra_options += ["--opt-level", "1"] # Add --regAllocOptLevel=2 to work around ptxas 13.x bug reg_alloc = ['--regAllocOptLevel=2']