Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,16 @@ class FunctionBuilder {
int thread, Value pred, MemType memType,
Operation *insertPoint, Value recipientCTAs);
// trackBarrierWriteForBuffer: mark a specific buffer as tracked by a
// barrier in the write-tracking table.
// barrier in the write-tracking table. When diagonalEffectRecipientCTAs is
// false, every signaled barrier row publishes the full effectRecipientCTAs
// mask. When it is true, barrier row i publishes only bit i of that mask.
void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar,
Value buf, uint32_t length,
Value pred, MemType memType,
Operation *insertPoint,
Value barrierRecipientCTAs,
Value effectRecipientCTAs);
Value effectRecipientCTAs,
bool diagonalEffectRecipientCTAs);
// clearBarrierWriteTracking: clear all write tracking associated with the
// given barrier row.
void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar,
Expand Down
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ Auxiliary state is kept in distributed tensors and global scratch memory, with t
### Thread model

- Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions).
- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads.
- Total logical threads: 48. Bitmasks are sized to the next power of two: 64.
- Peer classes: +16 TMA threads, +16 Tensor Core (TC) threads, and +16 CLC threads to model lack of ordering with base threads.
- Total logical threads: 64.

Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience.
Indexing uses a logical thread id in [0, 64), with column vectors sized to 64 for layout convenience.

## Auxiliary data structures

Expand All @@ -21,7 +21,7 @@ All types are generated on-demand (per partition) based on:
- B: number of tracked buffers (power-of-two padded)
- K: number of mbarriers (power-of-two padded)
- T_bits: 64 (bitmask width)
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers)
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA/CLC helpers)

“tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts.

Expand Down Expand Up @@ -53,7 +53,7 @@ ConSan separates “tracking” from “visibility transfer”:
- experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer.
- experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier.
- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes.
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent.
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC, CLC) to keep the classes consistent.

### Barrier phase/count tracking

Expand Down
8 changes: 6 additions & 2 deletions include/triton/Dialect/TritonInstrument/IR/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
#include "llvm/Support/MathExtras.h"

#include <array>

Expand All @@ -22,8 +23,11 @@ constexpr int numMemTypes = getMaxEnumValForMemType() + 1;
constexpr int NUM_THREADS = 16;
constexpr int TMA_THREAD_OFFSET = NUM_THREADS;
constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS;
constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS;
constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS);
constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS;
constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS;
static_assert(TOTAL_NUM_THREADS <= 64,
"ConSan thread bitsets are stored in i64 masks");
const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS);

namespace CommitKind {
enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,19 @@
namespace mlir::triton::instrument {

struct MemEffectsOpInfo {
// Frontier: snapshot thread-visible frontier into barrier tracking.
// EffectWrites: track only buffers written by op effects.
// Controls which memory effects become visible to a CTA after it waits on
// this barrier.
//
// Frontier snapshots the issuing thread's current visibility frontier into
// the barrier. A later wait publishes whatever shared/tensor memory writes
// and reads were visible to that logical thread before the arrive/commit. Use
// this for ordering operations whose semantics are a release of prior work.
//
// EffectWrites does not snapshot the whole thread frontier. Instead, it
// attaches only the explicit write effects of this op to the barrier. A later
// wait publishes those op-local writes and nothing else. Use this for PTX ops
// that perform the write and also signal the barrier via
// `mbarrier::complete_tx`.
enum class BarrierTrackingMode {
Frontier,
EffectWrites,
Expand All @@ -34,6 +45,18 @@ struct MemEffectsOpInfo {
int count;
BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier;
int txCount = 0;
// For EffectWrites, effectRecipientCTAs identifies the CTA rows where the
// op wrote its explicit result. By default, for
// diagonalEffectRecipientCTAs=false, waiting on a barrier publishes the CTA
// rows in effectRecipientCTAs, which is the full mask. This is the
// behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true,
// waiting on a barrier publishes only the CTA rows in effectRecipientCTAs,
// which is the diagonal mask. e.g. effectRecipientCTAs = 0b1101 If
// DiagonalEffectRecipientCTAs is false, waiting on the barrier publishes
// the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101
// If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes
// the following CTA rows: CTA0 0b1000 CTA1 0b0100 CTA2 0b0000 CTA3 0b0001
bool diagonalEffectRecipientCTAs = false;
};
enum class TrackingKind {
None,
Expand Down Expand Up @@ -83,6 +106,8 @@ class ConSanTargetHooks {

virtual bool isTMAOp(Operation *op) const = 0;

virtual bool isCLCOp(Operation *op) const { return false; }

virtual std::optional<BarrierInitInfo>
getBarrierInitInfo(Operation *op) const = 0;

Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/BufferRegion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) {
bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) {
if (isa<ttg::LocalLoadOp, ttg::LocalStoreOp, ttng::TMEMLoadOp,
ttng::TMEMStoreOp, ttng::TMEMCopyOp, ttg::AsyncCopyGlobalToLocalOp,
ttng::TMAOpInterface>(op)) {
ttng::TMAOpInterface, ttng::CLCLoadResultOp>(op)) {
return true;
}
if (isa<ttg::MBarrierOpInterface>(op)) {
Expand Down
45 changes: 34 additions & 11 deletions lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1674,13 +1674,13 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b,
void FunctionBuilder::createTrackBarrierWriteForBufferCall(
ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred,
MemType memType, Operation *insertPoint, Value barrierRecipientCTAs,
Value effectRecipientCTAs) {
Value effectRecipientCTAs, bool diagonalEffectRecipientCTAs) {
if (auxData.barriers.empty() || auxData.buffers[(int)memType].empty() ||
auxData.writeTracking[(int)memType].empty()) {
return;
}
assert(!auxData.barrierWriteRecipients.empty() &&
"barrier write recipients must exist when tracking TMA writes");
"barrier write recipients must exist when tracking EffectWrites");
if (!pred)
pred = arith::ConstantIntOp::create(b, 1, 1);
Value barriersVal = auxData.barriers.at(insertPoint).value;
Expand Down Expand Up @@ -1717,10 +1717,10 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall(
b, "track_barrier_write_for_buffer", args,
/*assertInfo=*/std::nullopt,
{barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType,
(uint64_t)memType},
[barriersType, buffersType, writeTrackingType,
barrierWriteRecipientsType](ImplicitLocOpBuilder &fb,
Block *entryBlock) {
(uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs},
[barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType,
diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb,
Block *entryBlock) {
Value mbarOffset = entryBlock->getArgument(0);
Value mbarLengthVal = entryBlock->getArgument(1);
Value pred = entryBlock->getArgument(2);
Expand All @@ -1747,6 +1747,30 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall(
createCmpIntTensorScalar(fb, barriers, barrierDescriptor);
Value effectRecipientCTAsTensor = triton::SplatOp::create(
fb, barrierWriteRecipientsType, effectRecipientCTAs);
if (diagonalEffectRecipientCTAs) {
// Expand the effect CTA mask diagonally: barrier row i publishes only
// bit i. This models per-CTA results, while the default replicated
// mask models TMA multicast where a barrier publishes all result
// rows.
auto encoding = cast<ttg::DistributedEncodingTrait>(
barrierWriteRecipientsType.getEncoding());
auto rowSliceEncoding =
tti::getSingleDimSliceEncoding(encoding, /*dim=*/0);
int numCTAs = barrierWriteRecipientsType.getShape()[0];
auto rowType = RankedTensorType::get({numCTAs}, fb.getI32Type(),
rowSliceEncoding);
Value rowIdx = triton::MakeRangeOp::create(fb, rowType,
/*start=*/0,
/*end=*/numCTAs);
auto indexType =
cast<RankedTensorType>(barrierWriteRecipientsType.cloneWith(
std::nullopt, fb.getI32Type()));
rowIdx = convertAndBroadcast(fb, rowIdx, {0}, indexType);
Value one = tti::createConstIntTensor(fb, fb.getLoc(), 1, indexType);
Value rowBit = arith::ShLIOp::create(fb, one, rowIdx);
effectRecipientCTAsTensor =
arith::AndIOp::create(fb, effectRecipientCTAsTensor, rowBit);
}
Value updatedBarrierWriteRecipients = arith::OrIOp::create(
fb, barrierWriteRecipients, effectRecipientCTAsTensor);
updatedBarrierWriteRecipients = arith::SelectOp::create(
Expand Down Expand Up @@ -2365,11 +2389,10 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b,
Value zeroTensor =
tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType);

constexpr uint64_t fullMask =
tti::THREADS_BITMASK_SIZE == 64
? std::numeric_limits<uint64_t>::max()
: (std::numeric_limits<uint64_t>::max() >>
(64 - tti::THREADS_BITMASK_SIZE));
uint64_t fullMask = tti::THREADS_BITMASK_SIZE == 64
? std::numeric_limits<uint64_t>::max()
: (std::numeric_limits<uint64_t>::max() >>
(64 - tti::THREADS_BITMASK_SIZE));
Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64);
Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType);
Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType);
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/TritonInstrument/IR/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,7 @@ void AuxDataMap::populateAndPassToWarpSpecialize(
if (numCTAs > 1) {
ClusterBarrierOp::create(b, b.getLoc());
} else {
BarrierOp::create(b, b.getLoc(),
AddrSpace::GlobalRead | AddrSpace::GlobalWrite);
BarrierOp::create(b, b.getLoc(), AddrSpace::Local);
}
lock.insert(entryRegion, {lockVal, lockVal.getType()});
passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock, captureCounter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
// buffers | tensor | <C x B x i64> | Base pointers of all (sub)buffers
// barriers | tensor | <C x K x i64> | Pointers to all individual mbarriers
// barrierStates | scratch | <C x K x i64> | Packed barrier phase (bit 0), arrival counts (bits[1..20] init, [21..40] current), and signed tx-count (bits[41..61]); zero means invalid/uninitialized
// barrierWriteRecipients | scratch | <C x K x i32> | CTA bitsets of write-tracking rows reached by outstanding TMA effects on each barrier
// barrierWriteRecipients | scratch | <C x K x i32> | CTA bitsets of EffectWrites rows published by each barrier
// waiting | scratch | <C x K x i32> | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1)
// writeVisibility | scratch | <C x B x i64> | Per-buffer thread-visibility bitmask (bit i => thread i visible)
// readVisibility | scratch | <C x B x T x i64> | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks)
Expand Down Expand Up @@ -120,12 +120,16 @@ int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks) {
thread += TC_THREAD_OFFSET;
return thread;
}
if (hooks->isCLCOp(op)) {
thread += CLC_THREAD_OFFSET;
return thread;
}
return thread;
}

int getBaseThread(int thread) { return thread % NUM_THREADS; }

// Peer threads are the equivalent threads in the TMA, TC and normal
// Peer threads are the equivalent threads in the TMA, TC, CLC and normal
// thread classes.
// If a thread is a base thread, return the mask with the peers, otherwise
// return the mask with the thread itself.
Expand All @@ -134,6 +138,7 @@ uint64_t getThreadPeersMask(int thread) {
if (thread < NUM_THREADS) {
mask |= 1ULL << (thread + TMA_THREAD_OFFSET);
mask |= 1ULL << (thread + TC_THREAD_OFFSET);
mask |= 1ULL << (thread + CLC_THREAD_OFFSET);
}
return mask;
}
Expand All @@ -159,6 +164,12 @@ Value currentCTAMask(ImplicitLocOpBuilder &b) {
ctaId);
}

Value allCTAsMask(ImplicitLocOpBuilder &b) {
int numCTAs = ttg::lookupNumCTAs(b);
assert(numCTAs <= 16 && "ConSan CTA bitsets assume at most 16 CTAs");
return arith::ConstantIntOp::create(b, (1u << numCTAs) - 1, 32);
}

uint16_t getBlockBroadcastMask(Value alloc) {
auto allocTy = cast<ttg::MemDescType>(alloc.getType());
auto kBlock = StringAttr::get(alloc.getContext(), "block");
Expand Down Expand Up @@ -259,6 +270,8 @@ Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
return getMulticastRecipientCTAs(b, tmaLoad.getResult());
return currentCTAMask(b);
}
if (isa<ttng::CLCTryCancelOp>(op))
return allCTAsMask(b);
if (isTensorCoreOp(op))
return getRecipientCTAsForBroadcastMasks(
b, ttng::getCTABroadcastMasks(ttng::getModuleTwoCTAs(op), {}));
Expand All @@ -278,6 +291,8 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
tmaLoad.getBarrier());
return getLeaderCTA(b, tmaLoad.getBarrier());
}
if (isa<ttng::CLCTryCancelOp>(op))
return allCTAsMask(b);

if (isTensorCoreOp(op))
return getRecipientCTAsForBroadcastMasks(
Expand Down Expand Up @@ -560,7 +575,8 @@ class ConcurrencySanitizerImpl {
memType = MemType::SHARED_MEM;
funcBuilder.createTrackBarrierWriteForBufferCall(
b, barrier, effect.buf, effect.length, combinedPred, memType, op,
recipientCTAs, effectRecipientCTAs);
recipientCTAs, effectRecipientCTAs,
barrierInfo.diagonalEffectRecipientCTAs);
}
}
if (barrierInfo.count > 0 || barrierInfo.txCount != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op,
if (auto commit = dyn_cast<ttng::TCGen5CommitOp>(op)) {
return ttng::getModuleTwoCTAs(op) && aliasesTracked(commit.getBarrier());
}
if (auto tma = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
if (auto tma = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
}
if (auto tma = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
if (auto clc = dyn_cast<ttng::CLCTryCancelOp>(op)) {
return aliasesTracked(clc.getMbarrier());
}
return false;
}
Expand Down
27 changes: 27 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
return isa<ttng::TMAOpInterface>(op);
}

bool isCLCOp(Operation *op) const override {
return isa<ttng::CLCTryCancelOp>(op);
}

std::optional<BarrierInitInfo>
getBarrierInitInfo(Operation *op) const override {
if (auto initOp = dyn_cast<ttng::InitBarrierOp>(op)) {
Expand Down Expand Up @@ -101,6 +105,11 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
}
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op))
mask = getBlockBroadcastMask(storeOp.getSrc().getType());
if (isa<ttng::CLCTryCancelOp>(op) && ttg::lookupNumCTAs(op) > 1) {
Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc());
return arith::CmpIOp::create(b, arith::CmpIPredicate::eq, ctaId,
arith::ConstantIntOp::create(b, 0, 32));
}

// In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair
// issues the op.
Expand Down Expand Up @@ -226,6 +235,24 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
loadOp.getResult());
}
if (auto tryCancelOp = dyn_cast<ttng::CLCTryCancelOp>(op)) {
info.emplace();
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
info->barriers.push_back(
{tryCancelOp.getMbarrier(), nullptr, /*count=*/0,
MemEffectsOpInfo::BarrierTrackingMode::EffectWrites,
/*txCount=*/
-static_cast<int>(tti::getMemDescLength(tryCancelOp.getResult())),
/*diagonalEffectRecipientCTAs=*/true});
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
tryCancelOp.getResult());
}
if (auto loadResultOp = dyn_cast<ttng::CLCLoadResultOp>(op)) {
info.emplace();
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read,
loadResultOp.getSrc());
}
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op)) {
info.emplace();
info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount;
Expand Down
Loading