[BACKEND] Add preferred cluster fallback#9957
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 91e45fd1da
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if (isa<ttng::AsyncTMAReduceOp, ttng::AsyncTMAGatherOp, | ||
| ttng::AsyncTMAScatterOp>(op)) | ||
| return unsupported(); |
There was a problem hiding this comment.
Reject cluster sync ops before enabling fallback
The safety walk only rejects a narrow set of ops and allows ttng.cluster_arrive/wait/barrier to pass through. With preferred-fallback enabled, a full-cluster barrier can be weakened into per-fallback-cluster barriers (e.g., 8→2 CTAs), changing synchronization semantics and potentially causing wrong results or hangs. These cluster sync ops should be treated as unsupported for fallback.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
I don't think this is the case, as we prove via the layout of the ops of the program that we don't do any cross-CTA work, so in these programs you should never need these ops...
There was a problem hiding this comment.
Reviewers, can you have a look at this one?
There was a problem hiding this comment.
@lezcano does a cluster sync imply a cluster membar (similar to what happens at the CTA level)? If so then a kernel could rely on cluster synchronisation if it uses global memory as a scratchpad for example, even if there are no explicit cross-CTA layout conversions.
There was a problem hiding this comment.
Ah, right. I was also wondering about global atomics, whether we should ban those as well. cc @peterbell10
There was a problem hiding this comment.
Not just atomics, but any access to global memory could be used to cross cta boundaries. e.g. each CTA stores to their own scratch, then does a cluster barrier and reads from its neighbours.
peterbell10
left a comment
There was a problem hiding this comment.
I haven't looked at the code, but I still can't see how this could possibly be correct. Can you give an example of a program where running it on fewer CTAs still gives you the same result?
|
03-matmul-multicta.py is an example of such a program. In general, the prototypical program that you would want to support would be a matmul TMA multicast. Omitting all the necessary synchronisation, something of the form: |
|
|
peterbell10
left a comment
There was a problem hiding this comment.
Discussed with Mario offline. I think the semantics do make sense and are useful. Still have a few concerns about the details though.
| // Convert ctaid to clusterid, which is the real program id | ||
| // Note that all cluster CTAs are distributed in the X dim | ||
| if (op.getDim() == ProgramIDDim::X) { | ||
| auto numCTAs = ttg::lookupNumCTAs(op); | ||
| if (numCTAs > 1) { | ||
| TritonLLVMOpBuilder b(loc, rewriter); | ||
| result = b.sdiv(result, b.i32_val(numCTAs)); | ||
| } | ||
| } |
There was a problem hiding this comment.
I'm not quite convinced this works with CLC. Since the cta in cluster id is implemented as ctaid % numCTAs, I think this implies we will process the work for cancelled_ctaid // numCTAs + ctaid % numCTAs, where ctaid doesn't come from the clc result, so may be wrong.
There was a problem hiding this comment.
Yeah, you are right. And actually solving the issue is a bit tricky with the current design, so I am just disabling it for now. It's not clear to me how to best fix it without having full control of the scheduler tbh.
|
review addressed. |
|
I'm still running some benchmarks see how they look. |
|
While running the benchmarks, I realised that there is a big issue with the approach. The user in gluon can inspect the layouts and perform constexpr computations using them. There is no way we can know whether these computations are invariant under layout changes. On the other hand, hacking around that and lowering the This is not autotuned, I'm using an optimal config for 2CTAs. So this shows that this method makes If we wanted, we could hide this method behind a kernel flag that could be turned on by the user. It would be direct to generalise consan to support this mode I believe. |
very interesting analysis and results. I'm also a bit concerned about changing the semantic without user knowing it. But exposing a flag that let the user opt in if we can clearly describe the rules that need to be followed + good sanitizer sounds like a reasonable solution. That being said I would wait until we find use cases where this bring enough performance boost to justify the extra language feature to merge this. |
ed6a55e to
693e8db
Compare
|
Actually, I think that the issue is not as problematic as I initially thought. The issue I found was not part of "we are using different layouts", there is not much that you can fuck up there as the gluon API is rather narrow. For multicast ops in this mode we naturally need to support it in the LLVM lowering, while this part we were handling in the frontend before. As such, the natural fix is to push this to a With this, I think this patch is safe to land really, as we cover all the necessary infra to support the only cross-CTA op we support in this mode, TMA multicast. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 693e8db006
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
peterbell10
left a comment
There was a problem hiding this comment.
I agree with Thomas that it would be good to show that the multicast setup gives a significant perf increase first, before introducing new complexity. I'm concerned that this will leak into user code and give surprising bugs.
| mma_bar = mbarrier.allocate_mbarrier() | ||
| mma_bar_count: ttgl.constexpr = blackwell.tcgen05_mma_barrier_count([smemA, smemB], True) | ||
| mbarrier.init(mma_bar, count=mma_bar_count) | ||
| mbarrier.init_tcgen05_mma(mma_bar, [smemA, smemB]) |
There was a problem hiding this comment.
I feel like this represents a big abstraction leak. If the user calculates and sets the arrival count manually, they shouldn't get unexpected hangs.
Also, one could imagine a use case where you need to arrive on the same mbarrier from multiple different sources (mma, tma, manual arrive, etc..) whereas this limits you to only a single mma.
There was a problem hiding this comment.
Not sure if you can mis arrives and tcgen05_commit s on the same barrier, but if you can there's no reason why this abstraction shouldn't allow it.
And sure, we need to use this op and not the other for this specific pattern, but exactly the same happens with things like having to pass the descriptors to tcgen05_commit or the multicast flag to mma. These patterns are tricky and there is so much you can represent natively at a language level...
| return_compiled=True, | ||
| ) | ||
|
|
||
| assert compiled.metadata.preferred_cluster_fallback_ctas == 0 |
There was a problem hiding this comment.
I guess we're now missing a test that exercises this PR positively?
There was a problem hiding this comment.
yes, sorry, didn't add it after realising this one was off
|
I'll write the mma kernel generic on the cta layout and benchmark it, see how the multicast configs look lime vs the 2cta ones |
fa9405b to
3267a5c
Compare
|
I have tried to find some shapes for which the best multicast config is noticeably better than simply a 2CTA config but I haven't found anything. The configs are mostly competitive with each other within ~10-20TFLOPS. With this in mind, I'd say we put this in the back burner and revisit it when rubin comes, or when we have more multiCTA kernels to try it on. cc @Mogball this might still be worth trying in your kernels just in case. |
The helper we had was missing the necessary `two_ctas` flag to compute the number of arrivals correctly for an mma that goes into a multicast TMA. I changed the API all across I think it would be much better if we just had a TTNG_InitMmaBarrierOp as proposed in #9957, as this would enable to get actual perf with multicast and would make the API much cleaner, but for now we just get this helper. What's nice is that this bug was found using the multicta consan :D
When running multiCTA kernels for `numCTAs > 2`, we are leaving some perf on the table as GPUs may be able not use every single SM. This is because SMs are grouped in TPCs (pairs of SMs) which then are grouped on GPCs (sets of TPCs). Every SM is part of a TPC, but not every TPC is part of a GPC of size 8, so if we launch a kernel with `numCTAs == 16` where each CTA takes a full SM, this may only be run on GPCs with size at least 8, leaving every other SM unused. To account for this, NVIDIA exposes an API ``` CU_LAUNCH_ATTRIBUTE_PREFERRED_CLUSTER_DIMENSION ``` starting on SM100. When invoking this API, we tell the GPU that we want to execute the kernel with `numCTAs` if possible, but if not, use a fallback number of CTAs and run it on less SMs. In this PR, we add a pass that checks for cross-CTA data movement that would make this invariant not hold. If all the ops in the kernel are alright, then we apply the optimisation. Then we change the lowerings with the following invariants: In other words, this will always return the same number regardless of whether the kernel was split or not. For example, if we have a grid of 2 CGAs with 4 CTAs each and the second launch is split into 2 launches, then we'll get Launch 00 4CTAs: ClusterCTAIdOp: 0-3 Launch 10 2CTAs: ClusterCTAIdOp: 0-1 Launch 11 2CTAs: ClusterCTAIdOp: 2-3 In other words, this should be used to compute addresses in global memory. This will depend on the runtime launch size of the program In the example above Launch 00 4CTAs: NVVM::ClusterId: 0-3 Launch 10 2CTAs: NVVM::ClusterId: 0-1 Launch 11 2CTAs: NVVM::ClusterId: 0-1 This should be used to generate masks for multicast, for example, or to compute predicates. Under this model the pid is invariant under the splits, so the program semantics under this transformation don't vary. We also support CLC. All this is E2E tested in 03-matmul-multicta.py
### Reviewers This PR includes #10167 and #10196 which I'll kill after they are merged. It is also separated in logical commits so that reviewing is simpler, as I had to add a few optimisations to generate less code as it was taking too long. We left out an optimisation where we sliced the indices to avoid loading the whole tensor from HBM when we know statically which rows we need to load (e.g. if we just slice current_cta along a dimension). This should help with compilation of num_ctas > 2 kernels as those take quite a bit of time with the new model. This is not pressing as those programs are not very performant without #9957 ### Idea A CTA is modelled as a set of independent logical threads, as if we had multiple warp-specialised threads running in parallel. Everything pretty much follows from that rule. A multicast-layout barrier has one live barrier row, owned by the lead CTA. Every CTA in the barrier group may arrive / expect on that row, but only the lead CTA initializes, waits, and invalidates it. This is the same model as several independent logical threads arriving on one barrier while only one logical thread waits on it. ### Shadow tables The buffer and barrier tables are CTA-agnostic We add a CTA dimension to go with every buffer/barrier/thread/mask dimension: ```text buffers | tensor | <B x i64> barriers | tensor | <K x i64> barrierStates | scratch | <Cbar x K x i64> waiting | scratch | <Cbar x K x Cthr x i32> writeVisibility | scratch | <Cbuf x B x Cmask x i64> readVisibility | scratch | <Cbuf x B x Cthr x T x Cmask x i64> writeTracking | scratch | <Cbuf x B x Cbar x K x i8> readTracking | scratch | <Cbuf x B x Cbar x K x Cmask x i64> outstandingCommits | scratch | <C x B x P x i8> ``` `Cbar`, `Cbuf`, `Cthr`, and `Cmask` are CTA dimensions qualifying barriers, buffers, threads, and thread masks respectively. C in outsatndingCommits folds the CTA dim for B and P, as there is no cross-CTA work in wgmma. Each CTA dimension is placed immediately before the dimension it qualifies. This keeps the multiCTA lift regular: a pre-existing dimension at position `pos` moves to `2 * pos`. Aliasing happens per-CTA ### CTA Issuers and Receivers and their Representation We need two pieces of CTA information: 1. Issuer predication: a 4-bit mask `m` such that the canonical issuer of a group satisfies `(cta_id & m) == 0`. 2. Receiver sets: a runtime `cta_id`-dependent `uint16_t` CTA bitmask (i.e., a `Value`) computed with the same lowering helpers used by the real implementation. Example: `mask == 0x1` in a 4-CTA kernel it means that: CTA0 accessed CTA0 and CTA1 CTA2 accessed CTA2 and CTA3 ### Cross-CTA memory effects TMA and CLC multicast Their mask is the multicas group (all the CTAs in the case of CLC) MMA / TMEMCopy, 2CTA The lead CTA must observe all the inputs and write to all outputs of both CTAs in the pair. In other words, the mask == 0x1 The idea here is that even though CTA0 and CTA1 collaborate, since it is launched from CTA0 and its synchronisation is emitted from CTA0, we can model it as if CTA0 did all the work. ### Barrier semantics TMA barriers and MMA completion barriers are dual. ```text Operation Issuer Receiver TMA, 1CTA, no multicast cta_id cta_id TMA, 2CTA, no multicast cta_id cta_id & ~1 TMA, 1CTA, multicast multicast-group leader multicast-group TMA, 2CTA, multicast multicast-group leader even cta_ids in the multicast-group MMA, 1CTA, no multicast cta_id 0x0 MMA, 2CTA, no multicast even CTA 0x1 MMA, multicast see lowering broadcastBits_d ``` For TMA, barrier receivers are obtained by applying the barrier-leader map to the TMA data-receiver rows. The relevant lowering is `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp`. For MMA completion, the multicast receiver mask is ```text broadcastBits_d = getBlockBroadcastMask(d) | (twoCTAs ? 0x1 : 0x0) ``` using the same logic as `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp`. Intuitively, an MMA completion waits for all CTAs that may write data consumed by the CTA performing the commit. CLC is the 1-CTA TMA multicast special case whose multicast group contains all CTAs. The CLC layout makes this explicit through an all-zero `cga_layout`. For ordinary mbarrier ops, use the semantics in `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp`: - `init`, `wait`, and `invalidate` are predicated on the barrier leader CTA. - `expect`, `arrive` are executed by each CTA. - All operations target the leader barrier address. - Therefore only leader barriers are live; non-leader barriers is as if they didn't exist and should never be accessed. In particular, non leader CTAs do not block on waits. A non-relaxed cluster barrier publishes all generic-proxy inflight events to all CTAs. ### Testing I run this on `test_consan.py` and all the multicta tests / kernels that we have, to make sure they all pass.
When running multiCTA kernels for
numCTAs > 2, we are leaving someperf on the table as GPUs may be able not use every single SM. This is
because SMs are grouped in TPCs (pairs of SMs) which then are grouped on
GPCs (sets of TPCs). Every SM is part of a TPC, but not every TPC is
part of a GPC of size 8, so if we launch a kernel with
numCTAs == 16where each CTA takes a full SM, this may only be run on GPCs with size
at least 8, leaving every other SM unused.
To account for this, NVIDIA exposes an API
starting on SM100.
When invoking this API, we tell the GPU that we want to execute the
kernel with
numCTAsif possible, but if not, use a fallback number ofCTAs and run it on less SMs.
Simplest use case to keep in mind for this feature:
In this PR, we add a pass that checks for cross-CTA data movement that
would make this invariant not hold. If all the ops in the kernel are
alright, then we apply the optimisation.
Then we change the lowerings with the following invariants:
nvgpu::ClusterCTAIdOpreturns the global ctaIdIn other words, this will always return the same number regardless of
whether the kernel was split or not.
For example, if we have a grid of 2 CGAs with 4 CTAs each and the second
launch is split into 2 launches, then we'll get
In other words, this should be used to compute addresses in global
memory.
NVVM::ClusterIdgives the relative cta_id wrt. the launched CGA sizeThis will depend on the runtime launch size of the program In the
example above
This should be used to generate masks for multicast, for example, or to
compute predicates.
Under this model the pid is invariant under the splits, so the program
semantics under this transformation don't vary.