Skip to content

[CONSAN] Multi CTA model v2#10212

Merged
lezcano merged 10 commits into
mainfrom
consan_multicta_final
May 8, 2026
Merged

[CONSAN] Multi CTA model v2#10212
lezcano merged 10 commits into
mainfrom
consan_multicta_final

Conversation

@lezcano
Copy link
Copy Markdown
Contributor

@lezcano lezcano commented May 4, 2026

Reviewers

This PR includes #10167 and #10196 which I'll kill after they are merged.

It is also separated in logical commits so that reviewing is simpler, as I had to add a few optimisations to generate less code as it was taking too long.

We left out an optimisation where we sliced the indices to avoid loading the whole tensor from HBM when we know statically which rows we need to load (e.g. if we just slice current_cta along a dimension). This should help with compilation of num_ctas > 2 kernels as those take quite a bit of time with the new model. This is not pressing as those programs are not very performant without #9957

Idea

A CTA is modelled as a set of independent logical threads, as if we had multiple
warp-specialised threads running in parallel.

Everything pretty much follows from that rule.

A multicast-layout barrier has one live barrier row, owned by the lead CTA.
Every CTA in the barrier group may arrive / expect on that row, but only the
lead CTA initializes, waits, and invalidates it.

This is the same model as several independent logical threads arriving on one
barrier while only one logical thread waits on it.

Shadow tables

The buffer and barrier tables are CTA-agnostic
We add a CTA dimension to go with every buffer/barrier/thread/mask dimension:

buffers                 | tensor  | <B x i64>
barriers                | tensor  | <K x i64>
barrierStates           | scratch | <Cbar x K x i64>
waiting                 | scratch | <Cbar x K x Cthr x i32>
writeVisibility         | scratch | <Cbuf x B x Cmask x i64>
readVisibility          | scratch | <Cbuf x B x Cthr x T x Cmask x i64>
writeTracking           | scratch | <Cbuf x B x Cbar x K x i8>
readTracking            | scratch | <Cbuf x B x Cbar x K x Cmask x i64>
outstandingCommits      | scratch | <C x B  x P x i8>

Cbar, Cbuf, Cthr, and Cmask are CTA dimensions qualifying barriers,
buffers, threads, and thread masks respectively. C in outsatndingCommits folds
the CTA dim for B and P, as there is no cross-CTA work in wgmma.
Each CTA dimension is placed
immediately before the dimension it qualifies. This keeps the multiCTA lift
regular: a pre-existing dimension at position pos moves to 2 * pos.

Aliasing happens per-CTA

CTA Issuers and Receivers and their Representation

We need two pieces of CTA information:

  1. Issuer predication: a 4-bit mask m such that the canonical issuer of a
    group satisfies (cta_id & m) == 0.
  2. Receiver sets: a runtime cta_id-dependent uint16_t CTA bitmask
    (i.e., a Value) computed with the same lowering helpers used by the
    real implementation.

Example:
mask == 0x1 in a 4-CTA kernel it means that:
CTA0 accessed CTA0 and CTA1
CTA2 accessed CTA2 and CTA3

Cross-CTA memory effects

TMA and CLC multicast
Their mask is the multicas group (all the CTAs in the case of CLC)

MMA / TMEMCopy, 2CTA
The lead CTA must observe all the inputs and write to all outputs of both CTAs in the pair.
In other words, the mask == 0x1
The idea here is that even though CTA0 and CTA1 collaborate, since it is launched from CTA0
and its synchronisation is emitted from CTA0, we can model it as if CTA0 did all the work.

### Barrier semantics

TMA barriers and MMA completion barriers are dual.

Operation                         Issuer                         Receiver
TMA, 1CTA, no multicast           cta_id                         cta_id
TMA, 2CTA, no multicast           cta_id                         cta_id & ~1
TMA, 1CTA, multicast              multicast-group leader         multicast-group
TMA, 2CTA, multicast              multicast-group leader         even cta_ids in the multicast-group

MMA, 1CTA, no multicast           cta_id                         0x0
MMA, 2CTA, no multicast           even CTA                       0x1
MMA, multicast                    see lowering                   broadcastBits_d

For TMA, barrier receivers are obtained by applying the barrier-leader map to
the TMA data-receiver rows. The relevant lowering is
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp.

For MMA completion, the multicast receiver mask is

broadcastBits_d = getBlockBroadcastMask(d) | (twoCTAs ? 0x1 : 0x0)

using the same logic as
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp.
Intuitively, an MMA completion waits for all CTAs that may write data consumed
by the CTA performing the commit.

CLC is the 1-CTA TMA multicast special case whose multicast group contains all
CTAs. The CLC layout makes this explicit through an all-zero cga_layout.

For ordinary mbarrier ops, use the semantics in
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp:

  • init, wait, and invalidate are predicated on the barrier leader CTA.
  • expect, arrive are executed by each CTA.
  • All operations target the leader barrier address.
  • Therefore only leader barriers are live; non-leader barriers is as if they didn't
    exist and should never be accessed. In particular, non leader CTAs do not block on
    waits.

A non-relaxed cluster barrier publishes all generic-proxy inflight events to all
CTAs.

Testing

I run this on test_consan.py and all the multicta tests / kernels that we have, to make sure they all pass.

@lezcano lezcano requested a review from pawelszczerbuk May 4, 2026 10:32
@lezcano lezcano changed the title [CONSAN] Multi CTA from first principles [CONSAN] Multi CTA rewrite May 4, 2026
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ca16eb7640

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp
@lezcano lezcano changed the title [CONSAN] Multi CTA rewrite [CONSAN] Multi CTA model v2 May 4, 2026
@lezcano lezcano force-pushed the consan_multicta_final branch 4 times, most recently from 5cc79d2 to 53c47ed Compare May 5, 2026 15:00
@peterbell10 peterbell10 removed their request for review May 5, 2026 19:49
Copy link
Copy Markdown
Collaborator

@jeffniu-openai jeffniu-openai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm but @pawelszczerbuk should take a look

Copy link
Copy Markdown
Collaborator

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love having to manage 2 accounts

Comment thread include/triton/Dialect/TritonInstrument/IR/Utility.h Outdated
auto ptrTy = triton::getPointerType(memTy.getElementType());
auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy,
sizeInBytes, alignment);
allocOp->setDiscardableAttr("tt.divisibility",
Copy link
Copy Markdown
Contributor

@pawelszczerbuk pawelszczerbuk May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where did these changes come from?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b5604f2
these were lifted to createThirdPartyScratchAlloc so that everyone benefits from the vectorisation part.

@lezcano lezcano force-pushed the consan_multicta_final branch from 53c47ed to aa9e664 Compare May 8, 2026 22:07
@lezcano lezcano enabled auto-merge (squash) May 8, 2026 22:07
@lezcano lezcano merged commit d80d286 into main May 8, 2026
9 checks passed
@lezcano lezcano deleted the consan_multicta_final branch May 8, 2026 22:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants