diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index 8a1d18de750b..2ff4e150074f 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -114,18 +114,19 @@ class FunctionBuilder { void createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar, Value pred, Operation *insertPoint); // verifyBarrierArrive: Check that applying the arrive count would not drive - // the tracked current count negative. Triggers an assertion on failure. + // the tracked current count negative, and that applying the tx-count delta + // would keep the tx-count in range. Triggers an assertion on failure. void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar, int count, Value pred, Operation *insertPoint, - Value recipientCTAs); + Value recipientCTAs, int txCount = 0); // updateBarrierState: Apply an arrive count to the tracked barrier state, - // toggling the phase when the count reaches zero and reloading the current - // count from the initial count. + // and apply a tx-count delta, toggling the phase when both counts reach zero + // and reloading the current count from the initial count. void createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar, int count, Value pred, - Operation *insertPoint, - Value recipientCTAs); + Operation *insertPoint, Value recipientCTAs, + int txCount = 0); // setWriteVisibility: Set the write visibility for a buffer. Marks the buffer // as visible to the threads set in threadMask. Clears out any other threads // from the visibility bitmask. 
We know this is safe because there cannot be diff --git a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md index 2e563b108bdb..4fd74bfce0fb 100644 --- a/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md +++ b/include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md @@ -31,7 +31,7 @@ All types are generated on-demand (per partition) based on: - readVisibility (scratch, ): Per-buffer, per-thread lanes. Each lane stores a 64-bit mask of other threads whose reads are visible to that lane’s thread - writeTracking (scratch, ): Map buffers → barriers tracking writes (boolean stored in i8) - readTracking (scratch, ): Map buffers → barriers tracking reads (bitmask of threads) -- barrierStates (scratch, ): Packed barrier metadata. Bit 0 stores the current phase, bits [1..10] the initial arrival count, bits [11..20] the current arrival count. The verifier checks underflow before updating, and flips the phase when the current count reaches zero. +- barrierStates (scratch, ): Packed barrier metadata. Bit 0 stores the current phase, bits [1..20] the initial arrival count, bits [21..40] the current arrival count, and bits [41..61] the signed tx-count. The verifier checks arrival-count underflow and tx-count range before updating; the update flips the phase when both the current count and tx-count reach zero. - waiting (scratch, ): Per-barrier bitfield describing waiting threads. Each base thread gets two bits: bit (2 * thread + 0) is the waiting flag, bit (2 * thread + 1) stores the phase the thread is waiting on. - outstandingCommits (scratch, ): Per-buffer, per-base-thread commit counters for cp.async and wgmma @@ -58,8 +58,8 @@ ConSan separates “tracking” from “visibility transfer”: ### Barrier phase/count tracking - experimental_init_barrier_state(barrier, count, barrierStates) initializes the per-barrier state with phase = 0 and both initial/current arrival counts = `count`. 
-- experimental_verify_barrier_arrive(barrier, count, barrierStates) checks that subtracting `count` from the current arrival count would not underflow. The codegen emits an assert if it would. -- experimental_update_barrier_state(barrier, count, barrierStates) applies the arrive: subtracts `count`, flips the phase when the count reaches zero, and reloads the current count from the initial count. +- experimental_verify_barrier_arrive(barrier, count, txCount, barrierStates) checks that subtracting `count` from the current arrival count would not underflow and that applying `txCount` keeps the tx-count in range. The codegen emits an assert that fires if either check fails. +- experimental_update_barrier_state(barrier, count, txCount, barrierStates) applies the arrive: subtracts `count` from the current arrival count and adds `txCount` to the tx-count, flips the phase when both reach zero, and reloads the current arrival count from the initial count. ### Deadlock detection diff --git a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h index fa23790fc534..89f55733aaa4 100644 --- a/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h +++ b/include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h @@ -14,11 +14,9 @@ namespace mlir::triton::instrument { struct MemEffectsOpInfo { // Frontier: snapshot thread-visible frontier into barrier tracking. // EffectWrites: track only buffers written by op effects. - // None: perform no visibility tracking for the barrier. 
enum class BarrierTrackingMode { Frontier, EffectWrites, - None, }; struct Effects { enum RW { Read, Write } rw; @@ -35,6 +33,7 @@ struct MemEffectsOpInfo { Value pred; int count; BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier; + int txCount = 0; }; enum class TrackingKind { None, diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index 15cc2777beb3..c2af6dbef485 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -48,9 +48,14 @@ namespace { namespace BarrierBits { constexpr unsigned phaseBit = 0; constexpr unsigned initCountLsb = 1; -constexpr unsigned currentCountLsb = 11; -constexpr unsigned countBitWidth = 10; -constexpr unsigned countMask = (1u << countBitWidth) - 1; +constexpr unsigned currentCountLsb = 21; +constexpr unsigned txCountLsb = 41; +constexpr unsigned countBitWidth = 20; +constexpr unsigned txCountBitWidth = 21; +constexpr uint64_t countMask = (1ull << countBitWidth) - 1; +constexpr uint64_t txCountMask = (1ull << txCountBitWidth) - 1; +constexpr int64_t txCountMin = -(int64_t)countMask; +constexpr int64_t txCountMax = (int64_t)countMask; } // namespace BarrierBits namespace WaitingBits { @@ -798,7 +803,7 @@ void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar, int count, Value pred, Operation *insertPoint) { - assert((unsigned)count <= BarrierBits::countMask && + assert(count >= 0 && (uint64_t)count <= BarrierBits::countMask && "barrier init count exceeds barrier state capacity"); if (auxData.barriers.empty() || auxData.barrierStates.empty()) { @@ -838,21 +843,23 @@ void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mask = createCmpIntTensorScalar(fb, barriers, descriptor); mask = convertAndBroadcast(fb, mask, {0, 1}, barrierStatesType); + Value countWide = adjustIntegerWidth( + fb, count, 
cast(barrierStatesType.getElementType())); Value countMask = - arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); - Value maskedCount = arith::AndIOp::create(fb, count, countMask); + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 64); + Value maskedCount = arith::AndIOp::create(fb, countWide, countMask); Value countTensor = triton::SplatOp::create(fb, barrierStatesType, maskedCount); - Value shiftOneTensor = tti::createConstIntTensor( + Value shiftInitTensor = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::initCountLsb, barrierStatesType); - Value shiftNineTensor = tti::createConstIntTensor( + Value shiftCurrentTensor = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); Value initField = - arith::ShLIOp::create(fb, countTensor, shiftOneTensor); + arith::ShLIOp::create(fb, countTensor, shiftInitTensor); Value currentField = - arith::ShLIOp::create(fb, countTensor, shiftNineTensor); + arith::ShLIOp::create(fb, countTensor, shiftCurrentTensor); Value newState = arith::OrIOp::create(fb, initField, currentField); Value updated = arith::SelectOp::create(fb, mask, newState, states); @@ -943,13 +950,14 @@ void FunctionBuilder::createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, }); } -void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, - Value mbar, int count, - Value pred, - Operation *insertPoint, - Value recipientCTAs) { - assert((unsigned)count <= BarrierBits::countMask && +void FunctionBuilder::createVerifyBarrierArriveCall( + ImplicitLocOpBuilder &b, Value mbar, int count, Value pred, + Operation *insertPoint, Value recipientCTAs, int txCount) { + assert(count >= 0 && (uint64_t)count <= BarrierBits::countMask && "barrier arrive count exceeds barrier state capacity"); + assert(txCount >= BarrierBits::txCountMin && + txCount <= BarrierBits::txCountMax && + "barrier tx-count delta exceeds barrier state capacity"); if (auxData.barriers.empty() || 
auxData.barrierStates.empty()) { return; @@ -958,6 +966,7 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, pred = arith::ConstantIntOp::create(b, 1, 1); } Value countVal = arith::ConstantIntOp::create(b, count, 32); + Value txCountVal = arith::ConstantIntOp::create(b, txCount, 64); Value barriersVal = auxData.barriers.at(insertPoint).value; auto barriersType = cast(auxData.barriers.at(insertPoint).type); @@ -967,10 +976,12 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, lengthVal, countVal, pred, - barriersVal, barrierStatesVal, recipientCTAs}; + SmallVector args = {mbarOffset, lengthVal, countVal, + txCountVal, pred, barriersVal, + barrierStatesVal, recipientCTAs}; AssertInfo assertInfo{ - "Barrier arrive underflow: current count would become negative", + "Barrier arrive underflow: current count or tx-count would become " + "invalid", barrierStatesType.cloneWith(std::nullopt, b.getI1Type())}; createCallToCachedFunction( b, "verify_barrier_arrive", args, assertInfo, @@ -980,11 +991,12 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value count = entryBlock->getArgument(2); - Value pred = entryBlock->getArgument(3); + Value txCount = entryBlock->getArgument(3); + Value pred = entryBlock->getArgument(4); - Value barriers = entryBlock->getArgument(4); - Value statesPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value barriers = entryBlock->getArgument(5); + Value statesPtr = entryBlock->getArgument(6); + Value recipientCTAs = entryBlock->getArgument(7); Value states = tti::createLoadScratchMemory(fb, fb.getLoc(), statesPtr, 
barrierStatesType); @@ -996,45 +1008,78 @@ void FunctionBuilder::createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, tti::createConstIntTensor(fb, fb.getLoc(), 0, barrierStatesType); Value maskFF = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::countMask, barrierStatesType); - Value shiftNineTensor = tti::createConstIntTensor( + Value shiftCurrentTensor = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); + Value shiftTxTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::txCountLsb, barrierStatesType); + Value shiftTxSignTensor = tti::createConstIntTensor( + fb, fb.getLoc(), 64 - BarrierBits::txCountBitWidth, + barrierStatesType); Value currentCount = - arith::ShRUIOp::create(fb, states, shiftNineTensor); + arith::ShRUIOp::create(fb, states, shiftCurrentTensor); currentCount = arith::AndIOp::create(fb, currentCount, maskFF); + Value currentTxCount = + arith::ShRUIOp::create(fb, states, shiftTxTensor); + currentTxCount = + arith::ShLIOp::create(fb, currentTxCount, shiftTxSignTensor); + currentTxCount = + arith::ShRSIOp::create(fb, currentTxCount, shiftTxSignTensor); Value countMask = - arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); - Value maskedCount = arith::AndIOp::create(fb, count, countMask); + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 64); + Value countWide = adjustIntegerWidth( + fb, count, cast(barrierStatesType.getElementType())); + Value maskedCount = arith::AndIOp::create(fb, countWide, countMask); Value arriveCount = triton::SplatOp::create(fb, barrierStatesType, maskedCount); + Value txCountTensor = + triton::SplatOp::create(fb, barrierStatesType, txCount); Value newCurrent = arith::SubIOp::create(fb, currentCount, arriveCount); Value newCurrentMasked = arith::SelectOp::create(fb, mask, newCurrent, zero32); - Value nonNegative = arith::CmpIOp::create(fb, arith::CmpIPredicate::sge, - newCurrentMasked, zero32); + Value newTxCount = + 
arith::AddIOp::create(fb, currentTxCount, txCountTensor); + Value newTxCountMasked = + arith::SelectOp::create(fb, mask, newTxCount, zero32); + Value arrivalsNonNegative = arith::CmpIOp::create( + fb, arith::CmpIPredicate::sge, newCurrentMasked, zero32); + Value minTxCount = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::txCountMin, barrierStatesType, + /*isSigned=*/true); + Value maxTxCount = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::txCountMax, barrierStatesType); + Value txCountInRange = arith::AndIOp::create( + fb, + arith::CmpIOp::create(fb, arith::CmpIPredicate::sge, + newTxCountMasked, minTxCount), + arith::CmpIOp::create(fb, arith::CmpIPredicate::sle, + newTxCountMasked, maxTxCount)); + Value valid = + arith::AndIOp::create(fb, arrivalsNonNegative, txCountInRange); Value vTrue = tti::createConstIntTensor( - fb, fb.getLoc(), 1, cast(nonNegative.getType())); - auto condType = cast(nonNegative.getType()); + fb, fb.getLoc(), 1, cast(valid.getType())); + auto condType = cast(valid.getType()); Value ctaMask = createRecipientCTAMask(fb, condType, recipientCTAs); - nonNegative = arith::SelectOp::create(fb, ctaMask, nonNegative, vTrue); + valid = arith::SelectOp::create(fb, ctaMask, valid, vTrue); Value predTensor = triton::SplatOp::create( - fb, cast(nonNegative.getType()), pred); - Value predicatedNonNegative = - arith::SelectOp::create(fb, predTensor, nonNegative, vTrue); + fb, cast(valid.getType()), pred); + Value predicatedValid = + arith::SelectOp::create(fb, predTensor, valid, vTrue); - triton::ReturnOp::create(fb, predicatedNonNegative); + triton::ReturnOp::create(fb, predicatedValid); }); } -void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, - Value mbar, int count, - Value pred, - Operation *insertPoint, - Value recipientCTAs) { - assert((unsigned)count <= BarrierBits::countMask && +void FunctionBuilder::createUpdateBarrierStateCall( + ImplicitLocOpBuilder &b, Value mbar, int count, Value pred, + 
Operation *insertPoint, Value recipientCTAs, int txCount) { + assert(count >= 0 && (uint64_t)count <= BarrierBits::countMask && "barrier update count exceeds barrier state capacity"); + assert(txCount >= BarrierBits::txCountMin && + txCount <= BarrierBits::txCountMax && + "barrier tx-count delta exceeds barrier state capacity"); if (auxData.barriers.empty() || auxData.barrierStates.empty()) { return; @@ -1043,6 +1088,7 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, pred = arith::ConstantIntOp::create(b, 1, 1); } Value countVal = arith::ConstantIntOp::create(b, count, 32); + Value txCountVal = arith::ConstantIntOp::create(b, txCount, 64); Value barriersVal = auxData.barriers.at(insertPoint).value; auto barriersType = cast(auxData.barriers.at(insertPoint).type); @@ -1052,8 +1098,9 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, uint32_t length = getMemDescLength(mbar); Value mbarOffset = tti::ExperimentalMemDescToI32Op::create(b, mbar); Value lengthVal = arith::ConstantIntOp::create(b, length, 32); - SmallVector args = {mbarOffset, lengthVal, countVal, pred, - barriersVal, barrierStatesVal, recipientCTAs}; + SmallVector args = {mbarOffset, lengthVal, countVal, + txCountVal, pred, barriersVal, + barrierStatesVal, recipientCTAs}; createCallToCachedFunction( b, "update_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, @@ -1062,11 +1109,12 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value count = entryBlock->getArgument(2); - Value pred = entryBlock->getArgument(3); + Value txCount = entryBlock->getArgument(3); + Value pred = entryBlock->getArgument(4); - Value barriers = entryBlock->getArgument(4); - Value statesPtr = entryBlock->getArgument(5); - Value recipientCTAs = entryBlock->getArgument(6); + Value barriers = entryBlock->getArgument(5); + 
Value statesPtr = entryBlock->getArgument(6); + Value recipientCTAs = entryBlock->getArgument(7); auto [prevBlock, ifBlock, thenBlock] = createIfBlock(fb, pred); fb.setInsertionPointToStart(ifBlock); @@ -1083,42 +1131,73 @@ void FunctionBuilder::createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, tti::createConstIntTensor(fb, fb.getLoc(), 1, barrierStatesType); Value maskFF = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::countMask, barrierStatesType); - Value shiftOneTensor = tti::createConstIntTensor( + Value shiftInitTensor = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::initCountLsb, barrierStatesType); - Value shiftNineTensor = tti::createConstIntTensor( + Value shiftCurrentTensor = tti::createConstIntTensor( fb, fb.getLoc(), BarrierBits::currentCountLsb, barrierStatesType); + Value shiftTxTensor = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::txCountLsb, barrierStatesType); + Value shiftTxSignTensor = tti::createConstIntTensor( + fb, fb.getLoc(), 64 - BarrierBits::txCountBitWidth, + barrierStatesType); Value phase = arith::AndIOp::create(fb, states, one32); - Value initCount = arith::ShRUIOp::create(fb, states, shiftOneTensor); + Value initCount = arith::ShRUIOp::create(fb, states, shiftInitTensor); initCount = arith::AndIOp::create(fb, initCount, maskFF); Value currentCount = - arith::ShRUIOp::create(fb, states, shiftNineTensor); + arith::ShRUIOp::create(fb, states, shiftCurrentTensor); currentCount = arith::AndIOp::create(fb, currentCount, maskFF); + Value currentTxCount = + arith::ShRUIOp::create(fb, states, shiftTxTensor); + currentTxCount = + arith::ShLIOp::create(fb, currentTxCount, shiftTxSignTensor); + currentTxCount = + arith::ShRSIOp::create(fb, currentTxCount, shiftTxSignTensor); Value countMask = - arith::ConstantIntOp::create(fb, BarrierBits::countMask, 32); - Value maskedCount = arith::AndIOp::create(fb, count, countMask); + arith::ConstantIntOp::create(fb, BarrierBits::countMask, 64); + Value countWide = 
adjustIntegerWidth( + fb, count, cast(barrierStatesType.getElementType())); + Value maskedCount = arith::AndIOp::create(fb, countWide, countMask); Value arriveCount = triton::SplatOp::create(fb, barrierStatesType, maskedCount); + Value txCountTensor = + triton::SplatOp::create(fb, barrierStatesType, txCount); Value newCurrent = arith::SubIOp::create(fb, currentCount, arriveCount); Value newCurrentMasked = arith::SelectOp::create(fb, mask, newCurrent, currentCount); - - Value zeroCond = arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, - newCurrentMasked, zero32); + Value newTxCount = + arith::AddIOp::create(fb, currentTxCount, txCountTensor); + Value newTxCountMasked = + arith::SelectOp::create(fb, mask, newTxCount, currentTxCount); + + Value zeroCond = arith::AndIOp::create( + fb, + arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + newCurrentMasked, zero32), + arith::CmpIOp::create(fb, arith::CmpIPredicate::eq, + newTxCountMasked, zero32)); zeroCond = arith::AndIOp::create(fb, zeroCond, mask); Value zeroCondI32 = arith::ExtUIOp::create(fb, barrierStatesType, zeroCond); Value newPhase = arith::XOrIOp::create(fb, phase, zeroCondI32); Value newCurrentValue = arith::SelectOp::create(fb, zeroCond, initCount, newCurrentMasked); + Value newTxCountValue = + arith::SelectOp::create(fb, zeroCond, zero32, newTxCountMasked); - Value initField = arith::ShLIOp::create(fb, initCount, shiftOneTensor); + Value initField = arith::ShLIOp::create(fb, initCount, shiftInitTensor); Value currentField = - arith::ShLIOp::create(fb, newCurrentValue, shiftNineTensor); + arith::ShLIOp::create(fb, newCurrentValue, shiftCurrentTensor); + Value txCountMask = tti::createConstIntTensor( + fb, fb.getLoc(), BarrierBits::txCountMask, barrierStatesType); + Value txCountField = + arith::AndIOp::create(fb, newTxCountValue, txCountMask); + txCountField = arith::ShLIOp::create(fb, txCountField, shiftTxTensor); Value newState = arith::OrIOp::create(fb, newPhase, initField); newState = 
arith::OrIOp::create(fb, newState, currentField); + newState = arith::OrIOp::create(fb, newState, txCountField); Value updated = arith::SelectOp::create(fb, mask, newState, states); createCTAScopedStoreScratchMemory(fb, fb.getLoc(), statesPtr, updated, diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index d1df49ddc9fd..4a7945383abf 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -535,8 +535,8 @@ void AuxDataMap::populateAndPassToWarpSpecialize( int numBarriers = barrierRegions.size(); barrierStates.insert( entryRegion, - {createZeroInitStateTensor(b, {numCTAs, numBarriers}, 32, fb), - getIntTensorType(entryRegion, {numCTAs, numBarriers}, 32)}); + {createZeroInitStateTensor(b, {numCTAs, numBarriers}, 64, fb), + getIntTensorType(entryRegion, {numCTAs, numBarriers}, 64)}); passToWarpSpecialize(entryPoint, barrierStates.at(entryRegion), barrierStates, captureCounter); diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index e55e4de8610d..38b24fcc2158 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -19,7 +19,7 @@ // ------------------|---------|-----------------|------------ // buffers | tensor | | Base pointers of all (sub)buffers // barriers | tensor | | Pointers to all individual mbarriers -// barrierStates | scratch | | Packed barrier phase (bit 0) and arrival counts (bits[1..10] init, [11..20] current); zero means invalid/uninitialized +// barrierStates | scratch | | Packed barrier phase (bit 0), arrival counts (bits[1..20] init, [21..40] current), and signed tx-count (bits[41..61]); zero means invalid/uninitialized // waiting | scratch | | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1) // writeVisibility | scratch | | Per-buffer 
thread-visibility bitmask (bit i => thread i visible) // readVisibility | scratch | | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks) @@ -155,23 +155,37 @@ Value currentCTAMask(ImplicitLocOpBuilder &b) { ctaId); } -Value getRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { - if (auto arriveOp = dyn_cast(op)) { - auto barrierTy = cast(arriveOp.getAlloc().getType()); - auto kBlock = StringAttr::get(op->getContext(), "block"); - uint16_t broadcastMask = - toLinearLayout(barrierTy).getFreeVariableMasks().lookup(kBlock); - if (broadcastMask) { - int numCTAs = ttg::lookupNumCTAs(b); - auto encoding = ttng::getTMAMulticastMaskEncoding(numCTAs, broadcastMask); - Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc()); - Value leaderCTA = arith::AndIOp::create( - b, ctaId, arith::ConstantIntOp::create(b, encoding.fixedBits, 32)); - return arith::ShLIOp::create(b, arith::ConstantIntOp::create(b, 1, 32), - leaderCTA); - } +uint16_t getBlockBroadcastMask(Value alloc) { + auto allocTy = cast(alloc.getType()); + auto kBlock = StringAttr::get(alloc.getContext(), "block"); + return toLinearLayout(allocTy).getFreeVariableMasks().lookup(kBlock); +} + +Value createCTABitset(ImplicitLocOpBuilder &b, uint32_t pattern, + uint32_t baseMask) { + Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc()); + Value base = arith::AndIOp::create( + b, ctaId, arith::ConstantIntOp::create(b, baseMask, 32)); + return arith::ShLIOp::create(b, arith::ConstantIntOp::create(b, pattern, 32), + base); +} + +Value getLeaderCTA(ImplicitLocOpBuilder &b, Value barrier) { + uint16_t broadcastMask = getBlockBroadcastMask(barrier); + if (!broadcastMask) return currentCTAMask(b); - } + int numCTAs = ttg::lookupNumCTAs(b); + auto encoding = ttng::getTMAMulticastMaskEncoding(numCTAs, broadcastMask); + return createCTABitset(b, /*pattern=*/1, encoding.fixedBits); +} + +Value getRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) { + if (auto expectOp = 
dyn_cast(op)) + return getLeaderCTA(b, expectOp.getAlloc()); + if (auto arriveOp = dyn_cast(op)) + return getLeaderCTA(b, arriveOp.getAlloc()); + if (auto copyOp = dyn_cast(op)) + return getLeaderCTA(b, copyOp.getBarrier()); SmallVector broadcastMasks; if (auto commitOp = dyn_cast(op)) { @@ -363,6 +377,27 @@ class ConcurrencySanitizerImpl { tti::ExperimentalLockReleaseOp::create(wb, lock, pred); } + void instrumentBarrierExpectNonLeaderArrive( + ImplicitLocOpBuilder &b, ttng::BarrierExpectOp expectOp, + Value nonLeaderPred, int thread, tti::FunctionBuilder &funcBuilder) { + Value barrier = expectOp.getAlloc(); + Value recipientCTAs = getLeaderCTA(b, barrier); + + // Match BarrierOpToLLVM's cross-CTA path: non-leader CTAs contribute a + // plain arrive of count 1 to the leader barrier. The generic barrier path + // models the leader CTA's expect_tx. + for (MemType memType : {MemType::SHARED_MEM, MemType::TENSOR_MEM}) { + funcBuilder.createTrackVisibleWritesCall( + b, barrier, thread, nonLeaderPred, memType, expectOp, recipientCTAs); + funcBuilder.createTrackVisibleReadsCall(b, barrier, thread, nonLeaderPred, + memType, expectOp, recipientCTAs); + } + funcBuilder.createVerifyBarrierArriveCall( + b, barrier, /*count=*/1, nonLeaderPred, expectOp, recipientCTAs); + funcBuilder.createUpdateBarrierStateCall( + b, barrier, /*count=*/1, nonLeaderPred, expectOp, recipientCTAs); + } + void instrumentMemEffects(ImplicitLocOpBuilder &b, Operation *op, int thread, tti::FunctionBuilder &funcBuilder) { int baseThread = getBaseThread(thread); @@ -370,7 +405,20 @@ class ConcurrencySanitizerImpl { if (!opInfo) { return; } - Value pred = tti::maybeAnd(b, opInfo->pred, hooks->getIssuerCTAPred(b, op)); + Value pred = opInfo->pred; + // Barrier expect performs an arrive on non-leader CTAs, so we need to + // instrument it separately before incorporating getIssuerCTAPred. 
+ Value issuerCTAPred = hooks->getIssuerCTAPred(b, op); + if (auto expectOp = dyn_cast(op)) { + if (issuerCTAPred) { + Value nonLeaderPred = arith::XOrIOp::create( + b, issuerCTAPred, arith::ConstantIntOp::create(b, 1, 1)); + nonLeaderPred = tti::maybeAnd(b, pred, nonLeaderPred); + instrumentBarrierExpectNonLeaderArrive(b, expectOp, nonLeaderPred, + thread, funcBuilder); + } + } + pred = tti::maybeAnd(b, pred, issuerCTAPred); Value recipientCTAs = getRecipientCTAs(b, op); for (auto effect : opInfo->operandEffects) { Value buf = effect.buf; @@ -452,11 +500,13 @@ class ConcurrencySanitizerImpl { b, barrier, effect.buf, effect.length, combinedPred, memType, op); } } - if (barrierInfo.count > 0) { + if (barrierInfo.count > 0 || barrierInfo.txCount != 0) { funcBuilder.createVerifyBarrierArriveCall( - b, barrier, barrierInfo.count, combinedPred, op, recipientCTAs); + b, barrier, barrierInfo.count, combinedPred, op, recipientCTAs, + barrierInfo.txCount); funcBuilder.createUpdateBarrierStateCall( - b, barrier, barrierInfo.count, combinedPred, op, recipientCTAs); + b, barrier, barrierInfo.count, combinedPred, op, recipientCTAs, + barrierInfo.txCount); } } if (opInfo->implicitCommit) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp index 20094a44c122..aa829c7f9739 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp @@ -85,6 +85,8 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { }; if (auto initOp = dyn_cast(op)) mask = getBarrierMask(initOp.getAlloc()); + if (auto expectOp = dyn_cast(op)) + mask = getBarrierMask(expectOp.getAlloc()); if (auto waitOp = dyn_cast(op)) mask = getBarrierMask(waitOp.getAlloc()); if (auto invalOp = dyn_cast(op)) @@ -107,19 +109,16 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { if (info) return info; if (auto expectOp = dyn_cast(op)) { - // TODO: For async TMA 
barriers, the barrier "arrive" corresponding to the - // completion mechanism is modeled by barrier_expect. Individual - // async_tma_copy ops should not decrement the barrier state, otherwise - // multiple copies using the same barrier would incorrectly advance the - // phase multiple times. This should be improved bu tracking the barrier - // expected byte count, and "arriving" the barrier when the expected byte - // count is reached. info.emplace(); info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier; info->pred = expectOp.getPred(); + auto barrierTy = expectOp.getAlloc().getType(); + int txCount = expectOp.getSize() * ttg::lookupNumCTAs(op) / + barrierTy.getNumElements(); info->barriers.push_back({expectOp.getBarrier(), nullptr, /*count=*/1, - MemEffectsOpInfo::BarrierTrackingMode::None}); + MemEffectsOpInfo::BarrierTrackingMode::Frontier, + /*txCount=*/txCount}); } if (auto loadOp = dyn_cast(op)) { info.emplace(); @@ -204,7 +203,8 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->pred = copyOp.getPred(); info->barriers.push_back( {copyOp.getBarrier(), nullptr, /*count=*/0, - MemEffectsOpInfo::BarrierTrackingMode::EffectWrites}); + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/-(int)tti::getMemDescLength(copyOp.getResult())}); info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, copyOp.getResult()); } @@ -222,7 +222,8 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks { info->pred = gatherOp.getPred(); info->barriers.push_back( {gatherOp.getBarrier(), nullptr, /*count=*/0, - MemEffectsOpInfo::BarrierTrackingMode::EffectWrites}); + MemEffectsOpInfo::BarrierTrackingMode::EffectWrites, + /*txCount=*/-(int)tti::getMemDescLength(gatherOp.getResult())}); info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write, gatherOp.getResult()); } diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 5467f0232e38..97435e55fa50 100644 --- 
a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -247,6 +247,49 @@ def kernel(a_desc, b_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](a_desc, b_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") +@pytest.mark.parametrize("EXPECT_DELTA", [-16, 16], ids=["under", "over"]) +def test_async_tma_expect_bytes_mismatch(EXPECT_DELTA, device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_async_tma_expect_bytes_mismatch, + (EXPECT_DELTA, device, False, monkeypatch, num_ctas)) + assert_expected_cuda_failure(result.exc) + assert "Deadlock detected" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(input_desc, out, EXPECT_DELTA: ttgl.constexpr): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + bar = mbarrier.allocate_mbarrier() + mbarrier.init(bar, count=1) + mbarrier.expect(bar, input_desc.nbytes_per_cta + EXPECT_DELTA) + tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0, deps=[smem]) + val = smem.load(blocked_layout) + mbarrier.invalidate(bar) + + out_m = ttgl.arange(0, block_m, ttgl.SliceLayout(1, blocked_layout))[:, None] + out_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :] + out_ptr = out + out_m * XBLOCK + out_n + ttgl.store(out_ptr, val) + + block_m = XBLOCK.value * num_ctas + input = torch.randn((block_m, XBLOCK.value), device=device, 
dtype=torch.float16) + output = torch.empty((block_m, XBLOCK.value), device=device, dtype=torch.float16) + shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=default_cga_layout(num_ctas, 2)) + input_desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input, [block_m, XBLOCK.value], shared_layout) + kernel[(1, )](input_desc, output, EXPECT_DELTA=EXPECT_DELTA, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_tma_interleave_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): @@ -1888,7 +1931,7 @@ def test_barrier_underflow(device, run_wrapper, monkeypatch, num_ctas): if run_wrapper: result = run_in_process(test_barrier_underflow, (device, False, monkeypatch, num_ctas)) assert_expected_cuda_failure(result.exc) - assert "Barrier arrive underflow: current count would become negative" in result.driver_stderr_output + assert "Barrier arrive underflow" in result.driver_stderr_output return monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") diff --git a/test/TritonGPU/consan.mlir b/test/TritonGPU/consan.mlir index 0b5c0d7273cf..9c75abf104bb 100644 --- a/test/TritonGPU/consan.mlir +++ b/test/TritonGPU/consan.mlir @@ -232,6 +232,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar ttng.barrier_expect %bar, 4096, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: tt.call @__triton_consan_init_barrier_state // CHECK: tt.call @__triton_consan_verify_barrier_initialized + // CHECK: tt.call @__triton_consan_track_visible_writes + // CHECK: tt.call @__triton_consan_track_visible_reads // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state // CHECK: tt.call @__triton_consan_verify_write_visibility @@ 
-240,9 +242,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: tt.call @__triton_consan_clear_write_tracking // CHECK: tt.call @__triton_consan_clear_read_visibility // CHECK: tt.call @__triton_consan_clear_read_tracking - // CHECK-NOT: tt.call @__triton_consan_track_visible_writes - // CHECK-NOT: tt.call @__triton_consan_track_visible_reads // CHECK: tt.call @__triton_consan_track_barrier_write_for_buffer + // CHECK: tt.call @__triton_consan_verify_barrier_arrive + // CHECK: tt.call @__triton_consan_update_barrier_state ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] %0, %bar, %true : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } @@ -297,8 +299,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65552 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @async_tma_copy_global_to_local_two_bufs_two_barriers - // CHECK-NOT: tt.call @__triton_consan_track_visible_writes - // CHECK-NOT: tt.call @__triton_consan_track_visible_reads // CHECK: %[[A_SMEM:.*]] = ttg.local_alloc {allocation.offset = 0 : i32} // CHECK: %[[B_SMEM:.*]] = ttg.local_alloc {allocation.offset = 4096 : i32} // CHECK: %[[BAR0:.*]] = ttg.local_alloc {allocation.offset = 65536 : i32} @@ -524,8 +524,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:90", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} { // CHECK-LABEL: @arrive_barrier 
tt.func public @arrive_barrier(%arg0: !tt.tensordesc>) { - // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 4 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr - // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI32(%[[BSTATE_GLOB]], %c0_i32 + // CHECK-DAG: %[[BSTATE_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation} : !tt.ptr + // CHECK: call {{.*}}fill_global_tensor{{.*}}T1x1xI64(%[[BSTATE_GLOB]], %c0_i64 %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -628,19 +628,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{.*}}, %[[TM_BUFS]], %{{[^,]+}} // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], %[[SM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : - // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %[[SM_READ_TRACKING_GLOB]] + // CHECK: tt.call 
@__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : - // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][], {{.*}}, {{.*}}, %[[BAR]] %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -703,19 +703,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : // CHECK: tt.call @__triton_consan_clear_read_visibility{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_VISIBILITY_GLOB]] // CHECK: %[[ACC_I64:.*]] = tti.experimental_memdesc_to_i32 %[[ACC]] : - // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{[^,]+}}, %{{[^,]+}}, %[[TM_BUFS]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: tt.call @__triton_consan_clear_read_tracking{{.*}}%[[ACC_I64]], {{.*}}, %[[TM_BUFS]], %{{[^,]+}} // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR:.*]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_WRITE_VISIBILITY_GLOB]], 
%[[SM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : - // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %[[SM_READ_TRACKING_GLOB]] + // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[SM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : // CHECK: tt.call @__triton_consan_track_visible_writes{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_WRITE_VISIBILITY_GLOB]], %[[TM_WRITE_TRACKING_GLOB]] // CHECK: %[[TC_BIT:.*]] = arith.constant 32 : i32 // CHECK: %[[BAR_I64:.*]] = tti.experimental_memdesc_to_i32 %[[BAR]] : - // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %[[TM_READ_TRACKING_GLOB]] + // CHECK: tt.call @__triton_consan_track_visible_reads{{.*}}%[[BAR_I64]], {{.*}}, %[[TC_BIT]], %[[BARRIERS]], %[[TM_READ_VISIBILITY_GLOB]], %{{[^,]+}} // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state // CHECK: tti.experimental_lock_release