Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,51 +142,46 @@ class FunctionBuilder {
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint,
Value recipientCTAs);
Operation *insertPoint, Value effectCTAs);
// setReadVisibility: add the threads set in threadMask to the buffer's read
// visibility bitmask.
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, uint64_t threadMask,
Value pred, MemType memType,
Operation *insertPoint, Value recipientCTAs);
Operation *insertPoint, Value effectCTAs);
// clearWriteTracking: clear all the information about threads writing to a
// buffer.
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, Value pred,
MemType memType, Operation *insertPoint,
Value recipientCTAs);
Value effectCTAs);
// clearReadVisibility: clear the read visibility for a buffer.
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, Value pred,
MemType memType, Operation *insertPoint,
Value recipientCTAs);
Value effectCTAs);
// clearReadTracking: clear the read tracking for a buffer.
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, Value pred, MemType memType,
Operation *insertPoint, Value recipientCTAs);
Operation *insertPoint, Value effectCTAs);
// trackVisibleWrites: snapshot buffers currently visible to the thread into
// the tracking table for a barrier.
void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
int thread, Value pred, MemType memType,
Operation *insertPoint,
Value recipientCTAs);
Operation *insertPoint, Value barrierCTAs);
// trackVisibleReads: snapshot buffers currently visible to the thread into
// the read tracking table for a barrier.
void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
int thread, Value pred, MemType memType,
Operation *insertPoint, Value recipientCTAs);
Operation *insertPoint, Value barrierCTAs);
// trackBarrierWriteForBuffer: mark a specific buffer as tracked by a
// barrier in the write-tracking table. When diagonalEffectRecipientCTAs is
// false, every signaled barrier row publishes the full effectRecipientCTAs
// mask. When it is true, barrier row i publishes only bit i of that mask.
// barrier in the write-tracking table.
void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar,
Value buf, uint32_t length,
Value pred, MemType memType,
Operation *insertPoint,
Value barrierRecipientCTAs,
Value effectRecipientCTAs,
bool diagonalEffectRecipientCTAs);
Value barrierCTAs,
Value effectCTAs);
// clearBarrierWriteTracking: clear all write tracking associated with the
// given barrier row.
void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar,
Expand All @@ -213,14 +208,14 @@ class FunctionBuilder {
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint,
Value recipientCTAs, bool allowNoWrite);
Value effectCTAs, bool allowNoWrite);
// verifyReadVisibility: ensure all reads from the buffer are visible to the
// thread.
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint,
Value recipientCTAs);
Value effectCTAs);
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
// every destination thread in destMask.
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
Expand All @@ -231,6 +226,11 @@ class FunctionBuilder {
void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
uint64_t destMask, Value pred,
MemType memType, Operation *insertPoint);
// publishClusterVisibility: after a non-relaxed cluster barrier, make
// synchronous facts visible to every CTA in the cluster.
// The signature carries no buffer, barrier, thread, or CTA-bitset operand;
// only the builder, a predicate, the memory type, and an insertion point.
// NOTE(review): pred presumably predicates the emitted call and insertPoint
// presumably anchors where it is created -- confirm against the
// implementation.
void createPublishClusterVisibilityCall(ImplicitLocOpBuilder &b, Value pred,
MemType memType,
Operation *insertPoint);
// stageAccessForCommit: mark the buffer as staged (value -1) in the
// outstanding commit table for this thread.
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
Expand Down Expand Up @@ -273,7 +273,7 @@ class FunctionBuilder {
void createCheckOutstandingCommitsCall(
ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread,
StringRef pendingAccessType, Value pred, MemType memType,
CommitKind::Kind commitKind, Operation *insertPoint, Value recipientCTAs,
CommitKind::Kind commitKind, Operation *insertPoint, Value effectCTAs,
bool excludeSelf = false);

private:
Expand Down
61 changes: 42 additions & 19 deletions include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ implementation:
ConSan currently supports one public entry point in the module. It uses
BufferRegion analysis to collect shared-memory buffers, tensor-memory buffers,
and barrier allocations, then creates auxiliary state in distributed tensors and
shared-cluster global scratch memory. Most state has a leading CTA dimension so
shared-cluster global scratch memory. Most scratch state is CTA-qualified so
cluster and multicast effects can be modeled explicitly.

## Visibility Model
Expand Down Expand Up @@ -58,6 +58,24 @@ At a `ttg.warp_specialize`, the pass copies the default thread's read and write
visibility to the destination partition peer masks so partition-local execution
starts with the visibility frontier that existed before specialization.

## CTA Model

The single-CTA case is the degenerate form of the multi-CTA model: every
CTA-qualified axis has one row, so the usual per-buffer and per-barrier rules
apply unchanged.

For multi-CTA kernels, each CTA is modeled as its own set of logical threads.
Buffer and barrier descriptors stay CTA-agnostic, while shadow state records the
CTA whose buffer row, barrier row, logical thread, or visibility mask a fact
belongs to. This keeps the single-CTA visibility rules intact and adds only the
question of which CTA rows an operation reads, writes, or synchronizes.

A multicast-layout barrier has one live barrier row per multicast group, owned
by the group's lead CTA. Every CTA in the group may `arrive` or `expect` on that
row, but only the lead CTA initializes, waits on, and invalidates it. This is
the same model as several independent logical threads arriving on one barrier
while only one logical thread waits on it; non-leader barrier rows are not live.

## Runtime State

ConSan keeps enough runtime state to answer two questions at each instrumented
Expand All @@ -74,7 +92,7 @@ At a high level, the pass maintains:
- Optional alias metadata when tracked buffer regions overlap.
- A shared-cluster lock that serializes instrumentation updates.

Most runtime state is CTA-indexed so cluster, multicast, and cross-CTA effects
Most runtime state is CTA-qualified so cluster, multicast, and cross-CTA effects
can be represented directly. Scratch state is zero-initialized once before the
instrumented body runs, and the initialization is followed by a CTA or cluster
barrier before any instrumented operation can use it.
Expand All @@ -94,7 +112,8 @@ selected buffer.
The runtime checks account for aliasing and CTA recipients. A check against one
buffer is expanded through the alias metadata when BufferRegion analysis found
overlapping tracked regions, and multi-CTA operations only inspect the CTA rows
that the operation can affect.
that the operation can affect. Aliasing remains intra-CTA: overlapping
descriptors may alias within one CTA row, but not across different CTA rows.

After a barrier-tracked read, ConSan records that the current peer thread mask
can see that read. After a barrier-tracked write, ConSan records the current
Expand All @@ -105,28 +124,28 @@ All normal instrumentation emitted around one IR operation is wrapped in the
ConSan lock. Barrier waits are split into a locked pre-wait section and a locked
post-wait section.

## CTA Recipients and Multicast
## CTA Issuers, Effects, and Recipients

Most multi-CTA instrumentation computes a CTA recipient bitset. That bitset is
converted to a mask over the CTA dimension so only relevant CTA rows are checked
or updated.
Single-CTA operations implicitly use the current CTA for all three roles. In a
multi-CTA kernel, those roles can differ:

The target hooks compute recipients from the operation:
- The issuer predicate selects which CTA actually executes the instrumented op.
- The memory-effect CTA bitset selects the buffer rows that the op reads or
writes.
- The barrier-recipient CTA bitset selects the live barrier rows that arrivals,
expectations, and completion signals update.

- Non-multicast operations usually target the current CTA.
- Multicast TMA loads update all result-recipient CTAs, while barrier arrivals
route to the leader barrier CTA.
- NVIDIA two-CTA Tensor Core operations are predicated to the issuing CTA pair
leader.
- TMA load effects that write one set of CTAs and signal a different leader
barrier remember the effect-recipient CTA rows so a later wait can transfer
write visibility to both the waiting CTA and the written CTA rows.
- CLC try-cancel effects use all CTA rows for the barrier and memory effect,
with diagonal recipient handling for the written result rows.
For example, a multicast TMA load is issued by the multicast-group leader,
writes every result-recipient CTA row, and signals the leader barrier row. A
two-CTA Tensor Core operation is issued by the even CTA in the pair, but its
memory effects cover both CTA rows in that pair. CLC try-cancel is issued once
for the cluster and touches all CTA rows.

## Barrier Synchronization

ConSan separates barrier tracking from visibility transfer.
ConSan separates barrier tracking from visibility transfer. Ordinary mbarrier
operations follow the live-row rule from the CTA model above: all participating
CTAs address the lead barrier row, and only the lead CTA performs the wait.

For frontier-tracked barriers, an arrive or commit snapshots the current
thread's visible writes and reads into the barrier's tracking state. A later
Expand Down Expand Up @@ -154,6 +173,10 @@ Write transfers also consult the recorded effect-recipient CTA rows, which lets
TMA-style and CLC cross-CTA writes become visible in the CTA rows reached by the
memory effect. Read transfers update the current CTA row.

A non-relaxed cluster barrier is different from an mbarrier wait: it directly
publishes synchronous work from base threads (i.e., only work done through the
generic proxy) to all CTA rows.

## Barrier Lifecycle and Deadlock Checks

The barrier state table models initialized, invalidated, phase, arrival-count,
Expand Down
65 changes: 37 additions & 28 deletions include/triton/Dialect/TritonInstrument/IR/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,7 @@ class FunctionBuilder;

constexpr int numMemTypes = getMaxEnumValForMemType() + 1;

constexpr int NUM_THREADS = 16;
constexpr int TMA_THREAD_OFFSET = NUM_THREADS;
constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS;
constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS;
constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS;
static_assert(TOTAL_NUM_THREADS <= 64,
"ConSan thread bitsets are stored in i64 masks");
const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS);
constexpr int MAX_NUM_BASE_THREADS = 16;

namespace CommitKind {
enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
Expand All @@ -42,9 +35,8 @@ enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
// writeVisibility + readVisibility per active memory type.
constexpr int kCapturesPerMemType = 2;

// barrierStates + waiting + activeMasks + barrierWriteRecipients (only when
// barriers exist).
constexpr int kBarrierBaseCaptures = 4;
// barrierStates + waiting + activeMasks (only when barriers exist).
constexpr int kBarrierBaseCaptures = 3;

// writeTracking + readTracking per active memory type (only when barriers
// exist and the memory type has buffers).
Expand Down Expand Up @@ -122,6 +114,20 @@ struct ValueType {
// that pointer. For tensor descriptors and constants, ValueType::value is the
// tensor itself and ValueType::type is its type.
struct AuxDataMap {
// Logical-thread numbering record. Every count defaults to 1 and every peer
// offset defaults to -1; the has*Threads() queries report whether the
// matching offset is non-negative.
// NOTE(review): the *Slots fields are presumably the padded counterparts of
// the plain counts -- confirm where the layout is populated.
struct ThreadLayout {
  // Base-thread counts.
  int numBaseThreads{1};
  int numBaseThreadSlots{1};

  // Starting offsets of the TMA / TC / CLC ranges; negative means "not set".
  int tmaThreadOffset{-1};
  int tcThreadOffset{-1};
  int clcThreadOffset{-1};

  // Overall thread counts.
  int totalNumThreads{1};
  int numThreadSlots{1};

  // True once a non-negative TMA offset has been stored.
  bool hasTMAThreads() const { return !(tmaThreadOffset < 0); }

  // True once a non-negative TC offset has been stored.
  bool hasTCThreads() const { return !(tcThreadOffset < 0); }

  // True once a non-negative CLC offset has been stored.
  bool hasCLCThreads() const { return !(clcThreadOffset < 0); }
};

struct RegionToValueMap {
DenseMap<Region *, ValueType> values;
ValueType at(Region *region) {
Expand All @@ -142,52 +148,49 @@ struct AuxDataMap {

// Shape notation:
// C = CTAs in the cluster.
// Cbar, Cbuf, Cthr, Cmask = CTA dimensions qualifying barriers, buffers,
// threads, and thread masks respectively. Each has extent C.
// B = tracked buffers for one memory type, power-of-two padded.
// K = tracked mbarriers, power-of-two padded.
// T = logical ConSan thread bit slots, padded to 64.
// P = base-thread commit columns, currently 16.
// T = logical ConSan thread bit slots used by this module, power-of-two
// padded for the distributed layout.
// P = base-thread commit columns used by this module, power-of-two padded.
//
// Storage notation:
// tensor = distributed tensor value.
// scratch = pointer to shared-cluster global scratch memory.

// tensor, <C x B x i64>
// tensor, <B x i64>
// Per-memory-type packed buffer descriptors. Each i64 stores the 32-bit base
// offset and 32-bit length of one shared-memory or tensor-memory region.
RegionToValueMap buffers[numMemTypes];

// tensor, <C x K x i64>
// tensor, <K x i64>
// Packed descriptors for tracked mbarrier allocations. Barriers are shared
// memory descriptors.
RegionToValueMap barriers;

// scratch, <C x K x i64>
// scratch, <Cbar x K x i64>
// Packed barrier lifecycle state. Zero means invalid/uninitialized. Bit 0 is
// phase, bits [1..20] are the initial arrival count, bits [21..40] are the
// current arrival count, and bits [41..61] hold a signed tx-count.
RegionToValueMap barrierStates;

// scratch, <C x K x i32>
// Per-barrier CTA bitsets of write-recipient rows reached by outstanding
// EffectWrites operations such as TMA and CLC. Used when a later wait
// transfers tracked writes.
RegionToValueMap barrierWriteRecipients;

// scratch, <C x B x i64>
// scratch, <Cbuf x B x Cmask x i64>
// Per-memory-type write frontier. Bit i means logical ConSan thread i can see
// the latest write to the buffer row.
RegionToValueMap writeVisibility[numMemTypes];

// scratch, <C x B x K x i8>
// scratch, <Cbuf x B x Cbar x K x i8>
// Per-memory-type buffer/barrier map for writes that a barrier tracks.
RegionToValueMap writeTracking[numMemTypes];

// scratch, <C x B x T x i64>
// scratch, <Cbuf x B x Cthr x T x Cmask x i64>
// Per-memory-type read frontier. For each buffer and logical thread lane, the
// i64 value is a bitmask of reads visible to that lane's thread.
RegionToValueMap readVisibility[numMemTypes];

// scratch, <C x B x K x i64>
// scratch, <Cbuf x B x Cbar x K x Cmask x i64>
// Per-memory-type buffer/barrier map for read visibility masks that a barrier
// tracks.
RegionToValueMap readTracking[numMemTypes];
Expand All @@ -196,9 +199,11 @@ struct AuxDataMap {
// Per-commit-kind outstanding commit counters for shared-memory buffers.
// Entries are 0 for none, -1 for staged but uncommitted, and positive for a
// committed access with an outstanding-group distance.
// Only one C dimension is needed because Ampere async_copy, WGMMA, and TMA
// store are intra-CTA.
RegionToValueMap commits[CommitKind::NumCommitKinds];

// tensor, <C x B x B x i1>
// tensor, <B x B x i1>
// Optional per-memory-type alias matrix. Created only when BufferRegion
// analysis finds cross-buffer aliasing; checks expand selected buffer rows
// through this matrix.
Expand All @@ -208,7 +213,7 @@ struct AuxDataMap {
// Shared-cluster lock used to serialize ConSan instrumentation updates.
RegionToValueMap lock;

// scratch, <C x K x i32>
// scratch, <Cbar x K x Cthr x i32>
// Deadlock-detection bitfield. Each base thread uses two bits: waiting flag
// and stored phase.
RegionToValueMap waiting;
Expand All @@ -223,6 +228,10 @@ struct AuxDataMap {
// aliasMatrices to make visibility and commit checks conservative.
std::array<bool, numMemTypes> hasNonTrivialAliasing{};

// Dense logical-thread numbering for this module. Base threads are always
// present; TMA/TC/CLC peer ranges are added only when the module uses them.
ThreadLayout threadLayout;

LogicalResult populateAndPassToWarpSpecialize(ModuleOp module,
FunctionBuilder &funcBuilder,
const ConSanTargetHooks *hooks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,6 @@ struct MemEffectsOpInfo {
int count;
BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier;
int txCount = 0;
// For EffectWrites, effectRecipientCTAs identifies the CTA rows where the
// op wrote its explicit result. By default, for
// diagonalEffectRecipientCTAs=false, waiting on a barrier publishes the CTA
// rows in effectRecipientCTAs, which is the full mask. This is the
// behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true,
// waiting on a barrier publishes only the CTA rows in effectRecipientCTAs,
// which is the diagonal mask. e.g. effectRecipientCTAs = 0b1101 If
// DiagonalEffectRecipientCTAs is false, waiting on the barrier publishes
// the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101
// If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes
// the following CTA rows: CTA0 0b1000 CTA1 0b0100 CTA2 0b0000 CTA3 0b0001
bool diagonalEffectRecipientCTAs = false;
};
enum class TrackingKind {
None,
Expand Down
Loading
Loading