[MULTICTA] Fix multicast pattern for tcgen05_mma_scaled#10196
Merged
Conversation
The helper we had was missing the necessary `two_ctas` flag to compute the number of arrivals correctly for an mma that goes into a multicast TMA. I changed the API all across I think it would be much better if we just had a TTNG_InitMmaBarrierOp as proposed in #9957, as this would enable to get actual perf with multicast and would make the API much cleaner, but for now we just get this helper. What's nice is that this bug was found using the multicta consan :D
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9a5ab0baf5
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
jeffniu-openai
approved these changes
May 1, 2026
adstraw
previously requested changes
May 1, 2026
| TypedValue<MemDescType> desc) { | ||
| if (isa<SharedEncodingTrait>(desc.getType().getEncoding())) | ||
| descs.push_back(desc); | ||
| descs.push_back(desc); |
Collaborator
There was a problem hiding this comment.
Can you add a LIT test to ensure there is no regression?
This was referenced May 3, 2026
Mogball
approved these changes
May 4, 2026
lezcano
added a commit
that referenced
this pull request
May 8, 2026
### Reviewers This PR includes #10167 and #10196 which I'll kill after they are merged. It is also separated in logical commits so that reviewing is simpler, as I had to add a few optimisations to generate less code as it was taking too long. We left out an optimisation where we sliced the indices to avoid loading the whole tensor from HBM when we know statically which rows we need to load (e.g. if we just slice current_cta along a dimension). This should help with compilation of num_ctas > 2 kernels as those take quite a bit of time with the new model. This is not pressing as those programs are not very performant without #9957 ### Idea A CTA is modelled as a set of independent logical threads, as if we had multiple warp-specialised threads running in parallel. Everything pretty much follows from that rule. A multicast-layout barrier has one live barrier row, owned by the lead CTA. Every CTA in the barrier group may arrive / expect on that row, but only the lead CTA initializes, waits, and invalidates it. This is the same model as several independent logical threads arriving on one barrier while only one logical thread waits on it. ### Shadow tables The buffer and barrier tables are CTA-agnostic We add a CTA dimension to go with every buffer/barrier/thread/mask dimension: ```text buffers | tensor | <B x i64> barriers | tensor | <K x i64> barrierStates | scratch | <Cbar x K x i64> waiting | scratch | <Cbar x K x Cthr x i32> writeVisibility | scratch | <Cbuf x B x Cmask x i64> readVisibility | scratch | <Cbuf x B x Cthr x T x Cmask x i64> writeTracking | scratch | <Cbuf x B x Cbar x K x i8> readTracking | scratch | <Cbuf x B x Cbar x K x Cmask x i64> outstandingCommits | scratch | <C x B x P x i8> ``` `Cbar`, `Cbuf`, `Cthr`, and `Cmask` are CTA dimensions qualifying barriers, buffers, threads, and thread masks respectively. C in outsatndingCommits folds the CTA dim for B and P, as there is no cross-CTA work in wgmma. Each CTA dimension is placed immediately before the dimension it qualifies. This keeps the multiCTA lift regular: a pre-existing dimension at position `pos` moves to `2 * pos`. Aliasing happens per-CTA ### CTA Issuers and Receivers and their Representation We need two pieces of CTA information: 1. Issuer predication: a 4-bit mask `m` such that the canonical issuer of a group satisfies `(cta_id & m) == 0`. 2. Receiver sets: a runtime `cta_id`-dependent `uint16_t` CTA bitmask (i.e., a `Value`) computed with the same lowering helpers used by the real implementation. Example: `mask == 0x1` in a 4-CTA kernel it means that: CTA0 accessed CTA0 and CTA1 CTA2 accessed CTA2 and CTA3 ### Cross-CTA memory effects TMA and CLC multicast Their mask is the multicas group (all the CTAs in the case of CLC) MMA / TMEMCopy, 2CTA The lead CTA must observe all the inputs and write to all outputs of both CTAs in the pair. In other words, the mask == 0x1 The idea here is that even though CTA0 and CTA1 collaborate, since it is launched from CTA0 and its synchronisation is emitted from CTA0, we can model it as if CTA0 did all the work. ### Barrier semantics TMA barriers and MMA completion barriers are dual. ```text Operation Issuer Receiver TMA, 1CTA, no multicast cta_id cta_id TMA, 2CTA, no multicast cta_id cta_id & ~1 TMA, 1CTA, multicast multicast-group leader multicast-group TMA, 2CTA, multicast multicast-group leader even cta_ids in the multicast-group MMA, 1CTA, no multicast cta_id 0x0 MMA, 2CTA, no multicast even CTA 0x1 MMA, multicast see lowering broadcastBits_d ``` For TMA, barrier receivers are obtained by applying the barrier-leader map to the TMA data-receiver rows. The relevant lowering is `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp`. For MMA completion, the multicast receiver mask is ```text broadcastBits_d = getBlockBroadcastMask(d) | (twoCTAs ? 0x1 : 0x0) ``` using the same logic as `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp`. Intuitively, an MMA completion waits for all CTAs that may write data consumed by the CTA performing the commit. CLC is the 1-CTA TMA multicast special case whose multicast group contains all CTAs. The CLC layout makes this explicit through an all-zero `cga_layout`. For ordinary mbarrier ops, use the semantics in `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp`: - `init`, `wait`, and `invalidate` are predicated on the barrier leader CTA. - `expect`, `arrive` are executed by each CTA. - All operations target the leader barrier address. - Therefore only leader barriers are live; non-leader barriers is as if they didn't exist and should never be accessed. In particular, non leader CTAs do not block on waits. A non-relaxed cluster barrier publishes all generic-proxy inflight events to all CTAs. ### Testing I run this on `test_consan.py` and all the multicta tests / kernels that we have, to make sure they all pass.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The helper we had was missing the necessary
two_ctasflag to computethe number of arrivals correctly for an mma that goes into a multicast
TMA. I changed the API all across
I think it would be much better if we just had a TTNG_InitMmaBarrierOp
as proposed in #9957, as this
would enable to get actual perf with multicast and would make the API
much cleaner, but for now we just get this helper.
What's nice is that this bug was found using the multicta consan :D