
[EXAMPLES] Implement multicta attention #10211

Merged: lezcano merged 2 commits into main from multicta_attn on May 5, 2026

@lezcano lezcano commented May 3, 2026

The attention example now supports arbitrary `num_ctas` by sharding
along M. I didn't try sharding along N; we would probably need general
linear layouts (LLs) in TMEM for that.
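
As a rough illustration of what sharding along M means, the sketch below splits the rows of one M tile evenly across CTAs. The helper name `shard_rows_along_m` and the even-split assumption are hypothetical and not taken from the example's code; the kernel itself expresses this through its CGA layout rather than explicit index arithmetic.

```python
def shard_rows_along_m(block_m: int, num_ctas: int) -> list[range]:
    """Split the block_m rows of one tile into num_ctas contiguous row ranges.

    Hypothetical helper for illustration only; assumes block_m is divisible
    by num_ctas.
    """
    if num_ctas <= 0 or block_m % num_ctas != 0:
        raise ValueError("block_m must be a multiple of a positive num_ctas")
    rows_per_cta = block_m // num_ctas
    return [range(cta * rows_per_cta, (cta + 1) * rows_per_cta) for cta in range(num_ctas)]


# With 2 CTAs, each CTA owns half of the rows along M.
shards = shard_rows_along_m(block_m=256, num_ctas=2)
```

With `num_ctas=1` the helper returns a single range covering the whole tile, which matches the unsharded 1CTA case.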

We implement a reinterpret trick so that K and V can share the channel,
as the `cga_layout` is now transposed.

The only real change in the 1CTA case is that the inits and arrives are
now done separately rather than interleaved, which the 2CTA case requires.
We also drop the releases: they were not necessary in the 1CTA case, and
in the 2CTA case they were incorrect (luckily caught by consan), since
CTA0 could run ahead and invalidate the barrier before CTA1 had finished
committing to it.

The pattern is:

  • CTA0 correction does its own final arrive(o_bar)
  • CTA0 has no further producer acquire that would wait for CTA1’s matching arrive
  • CTA0 can leave the WS region while CTA1 still has not performed its final consumer arrive
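
The ordering problem can be reproduced with a toy, single-threaded model. Everything here (`ToyBarrier`, the event names) is hypothetical and only mimics the ordering described above; it is not the Gluon mbarrier API.

```python
class ToyBarrier:
    """Toy stand-in for a barrier shared by two CTAs (not a real mbarrier)."""

    def __init__(self, expected_arrivals: int):
        self.pending = expected_arrivals
        self.valid = True

    def arrive(self, who: str) -> None:
        if not self.valid:
            raise RuntimeError(f"{who} arrived on an invalidated barrier")
        self.pending -= 1

    def invalidate(self) -> None:
        self.valid = False


def run(events: list[str]) -> str:
    """Replay an ordered list of events and report whether the ordering is safe."""
    o_bar = ToyBarrier(expected_arrivals=2)
    try:
        for event in events:
            if event == "cta0_invalidate":
                o_bar.invalidate()
            else:  # "cta0_arrive" or "cta1_arrive"
                o_bar.arrive(event)
    except RuntimeError as err:
        return f"race: {err}"
    return "ok"


# CTA0 arrives, leaves the WS region and invalidates before CTA1's final arrive.
racy = run(["cta0_arrive", "cta0_invalidate", "cta1_arrive"])
# Invalidation happens only after both CTAs have arrived.
safe = run(["cta0_arrive", "cta1_arrive", "cta0_invalidate"])
```

In the racy ordering nothing forces CTA0 to wait for CTA1, which is the situation the three bullets describe.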

With these two changes, the generated SASS is about 30 lines shorter and
the before/after perf is within noise (perhaps minimally faster).

I mostly did this to test the new consan implementation, but
a nice side effect of this multi-CTA implementation is that we get some
real wins, without any tuning, for the `not causal and D == 128 and bitwidth == 16` case.

I was only able to bench this on a GB300, though. If this heuristic change is
controversial, I'm happy to drop it and hand it off to the kernel experts.

There is a chance we could squeeze out a bit more perf by changing the
heuristics, but I didn't do that.

The benchmarks

| Case | Previous 1CTA | Current 1CTA | Current 2CTA | 2CTA vs Previous |
|---|---:|---:|---:|---:|
| FP16, 1024 | 1038.35 | 1058.31 | 1067.81 | +2.84% |
| FP16, 2048 | 1457.05 | 1462.41 | 1514.95 | +3.97% |
| FP16, 4096 | 1604.76 | 1608.22 | 1673.11 | +4.26% |
| FP16, 8192 | 1668.91 | 1676.79 | 1728.97 | +3.60% |
| BF16, 1024 | 1077.77 | 1090.04 | 1094.33 | +1.54% |
| BF16, 2048 | 1509.09 | 1514.40 | 1561.27 | +3.46% |
| BF16, 4096 | 1643.15 | 1642.42 | 1704.08 | +3.71% |
| BF16, 8192 | 1722.90 | 1722.20 | 1786.45 | +3.69% |
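
The last column is the relative change of the current 2CTA number over the previous 1CTA number. A minimal sketch of that arithmetic, using two rows from the table (the function name is only for illustration):

```python
def speedup_percent(previous: float, current: float) -> float:
    """Relative change of `current` over `previous`, in percent."""
    return (current / previous - 1.0) * 100.0


# FP16, 4096: previous 1CTA vs current 2CTA
fp16_4096 = round(speedup_percent(1604.76, 1673.11), 2)
# BF16, 1024: previous 1CTA vs current 2CTA
bf16_1024 = round(speedup_percent(1077.77, 1094.33), 2)
```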

@lezcano lezcano requested a review from ptillet as a code owner May 3, 2026 17:13
@lezcano lezcano requested a review from Mogball May 3, 2026 17:14

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c62f17be04

Comment thread python/examples/gluon/01-attention-forward.py

Mogball commented May 4, 2026

Will TAL tomorrow. Really interested to see the FP8 numbers though! At a glance this is really clean.

lezcano added 2 commits May 4, 2026 19:14
@lezcano lezcano merged commit cd8e4ac into main May 5, 2026
16 of 18 checks passed
@lezcano lezcano deleted the multicta_attn branch May 5, 2026 22:36

@peterbell10 peterbell10 left a comment

Oops, left my review in draft.

USE_TMEM_RED=use_selected_tmem_red,
NUM_KV_BUFFERS=num_kv_buffers,
USE_EXP2_TURNSTILE=use_exp2_turnstile,
CTA_LAYOUT=cta_layout,

NIT: Can we use CGA consistently?

s1_chnl.release()
c0_chnl.release()
c1_chnl.release()
exp_turnstile.release()

I'm a bit sceptical about not releasing barriers properly. In this case it's fine because the warp_specialize op will increase the live range of the shared memory, but it's not always safe so I'm not sure we should encourage it by putting it in the official examples.

Contributor Author

let's just extend the lifetime instead then

o_instr_shape = get_mma_instr_shape(o_cta_shape, gl.float32)
self.qk_tmem_layout = gl.constexpr(
    TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1, cga_layout=self.CGA_LAYOUT,
                       two_ctas=bool(self.CGA_LAYOUT)))

NIT: Add a config.TWO_CTAS?


@gluon.jit
def acquire_from(self, mem):
    mem, ready_bar = self.channel.acquire_producer(self.counter, mem)

I see that mem comes from kv_chnl.mem._reinterpret(...). Would it be possible to reinterpret the acquire result instead of allowing acquire to be called with potentially unrelated memory? IMO this really breaks the abstraction.
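
One way to read this suggestion, sketched with plain Python buffers. `ToyChannel` and its methods are hypothetical stand-ins, and `memoryview.cast` merely plays the role of `_reinterpret`; none of this is the example's actual channel code.

```python
class ToyChannel:
    """Toy ring of byte buffers; acquire hands out only memory the channel owns."""

    def __init__(self, num_buffers: int, nbytes: int):
        self._slots = [bytearray(nbytes) for _ in range(num_buffers)]

    def acquire_producer(self, counter: int) -> memoryview:
        # The channel picks the slot; the caller cannot pass in unrelated memory.
        return memoryview(self._slots[counter % len(self._slots)])


kv_chnl = ToyChannel(num_buffers=2, nbytes=16)

# K view: use the slot as handed out (16 one-byte elements).
k_view = kv_chnl.acquire_producer(counter=0)

# V view: reinterpret the *result* of acquire as 8 two-byte elements.
v_view = kv_chnl.acquire_producer(counter=2).cast("H")
v_view[0] = 0x0102  # writes through to the same underlying slot as k_view
```

Because the reinterpretation happens after the acquire, the channel stays the only owner of the memory it hands out, which is the abstraction the comment wants to preserve.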


@gluon.jit
def acquire_producer(self, counter):
def init(self, num_producers: gl.constexpr = 1, num_consumers: gl.constexpr = 1):

I understand that prime is needed to allow for a single cluster mbarrier init fence, but why split init from alloc?
