Skip to content

[MULTICTA] Fix multicast pattern for tcgen05_mma_scaled#10196

Merged
lezcano merged 2 commits into
mainfrom
multicast_fix_count
May 4, 2026
Merged

[MULTICTA] Fix multicast pattern for tcgen05_mma_scaled#10196
lezcano merged 2 commits into
mainfrom
multicast_fix_count

Conversation

@lezcano
Copy link
Copy Markdown
Contributor

@lezcano lezcano commented May 1, 2026

The helper we had was missing the necessary two_ctas flag to compute
the number of arrivals correctly for an mma that goes into a multicast
TMA. I changed the API all across

I think it would be much better if we just had a TTNG_InitMmaBarrierOp
as proposed in #9957, as this
would enable to get actual perf with multicast and would make the API
much cleaner, but for now we just get this helper.

What's nice is that this bug was found using the multicta consan :D

The helper we had was missing the necessary `two_ctas` flag to compute
the number of arrivals correctly for an mma that goes into a multicast
TMA. I changed the API all across

I think it would be much better if we just had a TTNG_InitMmaBarrierOp
as proposed in #9957, as this
would enable to get actual perf with multicast and would make the API
much cleaner, but for now we just get this helper.

What's nice is that this bug was found using the multicta consan :D
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9a5ab0baf5

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@masahi masahi requested a review from adstraw May 1, 2026 20:55
adstraw
adstraw previously requested changes May 1, 2026
TypedValue<MemDescType> desc) {
if (isa<SharedEncodingTrait>(desc.getType().getEncoding()))
descs.push_back(desc);
descs.push_back(desc);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a LIT test to ensure there is no regression?

Comment thread python/examples/gluon/04-2cta-block-scale-matmul.py
@lezcano lezcano requested a review from adstraw May 2, 2026 15:01
@lezcano lezcano enabled auto-merge (squash) May 2, 2026 15:02
@lezcano lezcano disabled auto-merge May 2, 2026 15:02
@lezcano lezcano dismissed adstraw’s stale review May 4, 2026 17:10

there are no regressions

@lezcano lezcano merged commit e77dbcd into main May 4, 2026
9 checks passed
@lezcano lezcano deleted the multicast_fix_count branch May 4, 2026 17:10
lezcano added a commit that referenced this pull request May 8, 2026
### Reviewers
This PR includes #10167 and
#10196 which I'll kill after
they are merged.

It is also separated in logical commits so that reviewing is simpler, as
I had to add a few optimisations to generate less code as it was taking
too long.

We left out an optimisation where we sliced the indices to avoid loading
the whole tensor from HBM when we know statically which rows we need to
load (e.g. if we just slice current_cta along a dimension). This should
help with compilation of num_ctas > 2 kernels as those take quite a bit
of time with the new model. This is not pressing as those programs are
not very performant without
#9957

### Idea
A CTA is modelled as a set of independent logical threads, as if we had
multiple
warp-specialised threads running in parallel.

Everything pretty much follows from that rule.

A multicast-layout barrier has one live barrier row, owned by the lead
CTA.
Every CTA in the barrier group may arrive / expect on that row, but only
the
lead CTA initializes, waits, and invalidates it.

This is the same model as several independent logical threads arriving
on one
barrier while only one logical thread waits on it.

### Shadow tables
The buffer and barrier tables are CTA-agnostic
We add a CTA dimension to go with every buffer/barrier/thread/mask
dimension:
```text
buffers                 | tensor  | <B x i64>
barriers                | tensor  | <K x i64>
barrierStates           | scratch | <Cbar x K x i64>
waiting                 | scratch | <Cbar x K x Cthr x i32>
writeVisibility         | scratch | <Cbuf x B x Cmask x i64>
readVisibility          | scratch | <Cbuf x B x Cthr x T x Cmask x i64>
writeTracking           | scratch | <Cbuf x B x Cbar x K x i8>
readTracking            | scratch | <Cbuf x B x Cbar x K x Cmask x i64>
outstandingCommits      | scratch | <C x B  x P x i8>
```

`Cbar`, `Cbuf`, `Cthr`, and `Cmask` are CTA dimensions qualifying
barriers,
buffers, threads, and thread masks respectively. C in outsatndingCommits
folds
the CTA dim for B and P, as there is no cross-CTA work in wgmma.
Each CTA dimension is placed
immediately before the dimension it qualifies. This keeps the multiCTA
lift
regular: a pre-existing dimension at position `pos` moves to `2 * pos`.

Aliasing happens per-CTA

### CTA Issuers and Receivers and their Representation

We need two pieces of CTA information:

1. Issuer predication: a 4-bit mask `m` such that the canonical issuer
of a
   group satisfies `(cta_id & m) == 0`.
2. Receiver sets: a runtime `cta_id`-dependent  `uint16_t` CTA bitmask 
(i.e., a `Value`) computed with the same lowering helpers used by the
    real implementation.

Example:
`mask == 0x1` in a 4-CTA kernel it means that:
CTA0 accessed CTA0 and CTA1
CTA2 accessed CTA2 and CTA3

### Cross-CTA memory effects

TMA and CLC multicast
  Their mask is the multicas group (all the CTAs in the case of CLC)

MMA / TMEMCopy, 2CTA
The lead CTA must observe all the inputs and write to all outputs of
both CTAs in the pair.
  In other words, the mask == 0x1
The idea here is that even though CTA0 and CTA1 collaborate, since it is
launched from CTA0
and its synchronisation is emitted from CTA0, we can model it as if CTA0
did all the work.

### Barrier semantics

TMA barriers and MMA completion barriers are dual.

```text
Operation                         Issuer                         Receiver
TMA, 1CTA, no multicast           cta_id                         cta_id
TMA, 2CTA, no multicast           cta_id                         cta_id & ~1
TMA, 1CTA, multicast              multicast-group leader         multicast-group
TMA, 2CTA, multicast              multicast-group leader         even cta_ids in the multicast-group

MMA, 1CTA, no multicast           cta_id                         0x0
MMA, 2CTA, no multicast           even CTA                       0x1
MMA, multicast                    see lowering                   broadcastBits_d
```

For TMA, barrier receivers are obtained by applying the barrier-leader
map to
the TMA data-receiver rows. The relevant lowering is
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp`.

For MMA completion, the multicast receiver mask is

```text
broadcastBits_d = getBlockBroadcastMask(d) | (twoCTAs ? 0x1 : 0x0)
```

using the same logic as
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp`.
Intuitively, an MMA completion waits for all CTAs that may write data
consumed
by the CTA performing the commit.

CLC is the 1-CTA TMA multicast special case whose multicast group
contains all
CTAs. The CLC layout makes this explicit through an all-zero
`cga_layout`.

For ordinary mbarrier ops, use the semantics in
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp`:

- `init`, `wait`, and `invalidate` are predicated on the barrier leader
CTA.
- `expect`, `arrive` are executed by each CTA.
- All operations target the leader barrier address.
- Therefore only leader barriers are live; non-leader barriers is as if
they didn't
exist and should never be accessed. In particular, non leader CTAs do
not block on
  waits.

A non-relaxed cluster barrier publishes all generic-proxy inflight
events to all
CTAs.

### Testing
I run this on `test_consan.py` and all the multicta tests / kernels that
we have, to make sure they all pass.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants