[EXAMPLES] Implement multicta attention #10211
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c62f17be04
Will TAL tomorrow. Really interested to see the fp8 numbers though! At a glance this is really clean.
Now the attention example supports arbitrary `num_ctas` by sharding along M (I didn't try to shard along N, but we would probably need general LLs in TMEM for that).

We implement a reinterpret trick to be able to share the channel between K and V, as the `cga_layout` is now transposed.

The only real change in the 1CTA case is that the inits and arrives are now separated rather than interleaved, as this is a requirement in the 2CTA case.

We also drop the releases: they were not necessary in the 1CTA case, and in the 2CTA case they were incorrect (luckily caught by consan), as CTA0 could run ahead and invalidate the barrier before CTA1 had finished committing to it. The pattern is:

- CTA0 correction does its own final `arrive(o_bar)`
- CTA0 has no further producer acquire that would wait for CTA1's matching arrive
- CTA0 can leave the WS region while CTA1 still has not performed its final consumer arrive

With these two changes, the generated SASS is about 30 lines shorter and the perf before/after is within noise (perhaps minimally faster).

I mostly did this to test the new consan implementation, but a nice corollary of this multicta implementation is that we got some real wins without any tuning for the `not causal and D == 128 and bitwidth == 16` case. I was only able to bench this on a GB300, though. If this change is controversial, I'm happy to drop it and hand off to the kernel experts.
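On the choice to shard along M: each query row's softmax only depends on that row's scores, so splitting the M dimension needs no cross-shard reduction. The following is a plain-Python sketch of that property only, not the Gluon kernel; `attention`, `attention_sharded_m`, and the toy inputs are hypothetical names made up for illustration.

```python
import math

def attention(q, k, v):
    """Naive softmax attention: q is [M][D], k and v are [N][D]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        row_max = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - row_max) for s in scores]
        denom = sum(weights)
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v)) / denom
                    for d in range(len(v[0]))])
    return out

def attention_sharded_m(q, k, v, num_ctas):
    """Hypothetical helper: give each 'CTA' a contiguous slice of query rows."""
    assert len(q) % num_ctas == 0, "M must divide evenly in this toy model"
    rows = len(q) // num_ctas
    out = []
    for cta in range(num_ctas):
        # Every shard sees the full K and V; only the Q rows differ.
        out.extend(attention(q[cta * rows:(cta + 1) * rows], k, v))
    return out

# Small deterministic inputs: M=8 query rows, N=6 key/value rows, D=4.
q = [[0.1 * (i + j) for j in range(4)] for i in range(8)]
k = [[0.05 * (i * j + 1) for j in range(4)] for i in range(6)]
v = [[float(i - j) for j in range(4)] for i in range(6)]

full = attention(q, k, v)
sharded = attention_sharded_m(q, k, v, num_ctas=2)
```

Concatenating the per-shard outputs reproduces the unsharded result row for row.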
There is some chance we can squeeze out a bit more perf by changing the heuristics, but I didn't do that.

The benchmarks:

| Case | Previous 1CTA | Current 1CTA | Current 2CTA | Speedup vs prev |
|---|---:|---:|---:|---:|
| FP16, 1024 | 1038.35 | 1058.31 | 1067.81 | +2.84% |
| FP16, 2048 | 1457.05 | 1462.41 | 1514.95 | +3.97% |
| FP16, 4096 | 1604.76 | 1608.22 | 1673.11 | +4.26% |
| FP16, 8192 | 1668.91 | 1676.79 | 1728.97 | +3.60% |
| BF16, 1024 | 1077.77 | 1090.04 | 1094.33 | +1.54% |
| BF16, 2048 | 1509.09 | 1514.40 | 1561.27 | +3.46% |
| BF16, 4096 | 1643.15 | 1642.42 | 1704.08 | +3.71% |
| BF16, 8192 | 1722.90 | 1722.20 | 1786.45 | +3.69% |
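For readers checking the last column: the numbers are consistent with comparing Current 2CTA against Previous 1CTA. A minimal sketch that recomputes it from the values in the benchmark table (`results` and `speedup_pct` are names introduced here for illustration):

```python
# (previous 1CTA, current 2CTA) pairs copied from the benchmark table.
results = {
    "FP16, 1024": (1038.35, 1067.81),
    "FP16, 2048": (1457.05, 1514.95),
    "FP16, 4096": (1604.76, 1673.11),
    "FP16, 8192": (1668.91, 1728.97),
    "BF16, 1024": (1077.77, 1094.33),
    "BF16, 2048": (1509.09, 1561.27),
    "BF16, 4096": (1643.15, 1704.08),
    "BF16, 8192": (1722.90, 1786.45),
}

def speedup_pct(prev, cur):
    """Relative improvement of cur over prev, in percent."""
    return (cur / prev - 1.0) * 100.0

for case, (prev, cur) in results.items():
    print(f"{case}: {speedup_pct(prev, cur):+.2f}%")
```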
peterbell10
left a comment
Oops, left my review in draft.
USE_TMEM_RED=use_selected_tmem_red,
NUM_KV_BUFFERS=num_kv_buffers,
USE_EXP2_TURNSTILE=use_exp2_turnstile,
CTA_LAYOUT=cta_layout,
NIT: Can we use CGA consistently?
s1_chnl.release()
c0_chnl.release()
c1_chnl.release()
exp_turnstile.release()
I'm a bit sceptical about not releasing barriers properly. In this case it's fine because the warp_specialize op will increase the live range of the shared memory, but it's not always safe so I'm not sure we should encourage it by putting it in the official examples.
let's just extend the lifetime instead then
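The hazard discussed in this thread (one CTA invalidating a barrier that another CTA has yet to arrive on) can be modelled in a few lines of plain Python. This is a toy sketch under loud assumptions: `ToyBarrier`, `run`, and the two schedules are hypothetical and are not Gluon or hardware APIs; the code only replays two fixed interleavings.

```python
class ToyBarrier:
    """Toy stand-in for a barrier shared by two CTAs (not a Gluon API)."""

    def __init__(self, expected_arrivals):
        self.pending = expected_arrivals
        self.valid = True

    def arrive(self, who):
        if not self.valid:
            raise RuntimeError(f"{who} arrived on an invalidated barrier")
        self.pending -= 1

    def invalidate(self):
        self.valid = False


def run(schedule):
    """Replay (cta, action) steps in order; return the error message or None."""
    bar = ToyBarrier(expected_arrivals=2)
    try:
        for cta, action in schedule:
            if action == "arrive":
                bar.arrive(cta)
            else:
                bar.invalidate()
    except RuntimeError as err:
        return str(err)
    return None


# CTA0 runs ahead and invalidates before CTA1's final arrive.
racy = [("CTA0", "arrive"), ("CTA0", "invalidate"), ("CTA1", "arrive")]
# The barrier outlives every arrive, e.g. because its lifetime was extended.
safe = [("CTA0", "arrive"), ("CTA1", "arrive"), ("CTA0", "invalidate")]

print(run(racy))
print(run(safe))
```

In this model the racy schedule fails on CTA1's arrive, while keeping the barrier valid until both arrives have landed does not.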
o_instr_shape = get_mma_instr_shape(o_cta_shape, gl.float32)
self.qk_tmem_layout = gl.constexpr(
    TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1, cga_layout=self.CGA_LAYOUT,
                       two_ctas=bool(self.CGA_LAYOUT)))
NIT: Add a config.TWO_CTAS?
@gluon.jit
def acquire_from(self, mem):
    mem, ready_bar = self.channel.acquire_producer(self.counter, mem)
I see that mem comes from kv_chnl.mem._reinterpret(...). Would it be possible to reinterpret the acquire result instead of allowing acquire to be called with potentially unrelated memory? IMO this really breaks the abstraction.
@gluon.jit
def acquire_producer(self, counter):
def init(self, num_producers: gl.constexpr = 1, num_consumers: gl.constexpr = 1):
I understand that prime is needed to allow for a single cluster mbarrier init fence, but why split init from alloc?