[CONSAN] Multi CTA model v2 #10212
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ca16eb7640
jeffniu-openai
left a comment
lgtm but @pawelszczerbuk should take a look
Mogball
left a comment
I love having to manage 2 accounts
auto ptrTy = triton::getPointerType(memTy.getElementType());
auto allocOp = createThirdPartyScratchAlloc(rewriter, loc, ptrTy,
                                            sizeInBytes, alignment);
allocOp->setDiscardableAttr("tt.divisibility",
where did these changes come from?
b5604f2
these were lifted to createThirdPartyScratchAlloc so that everyone benefits from the vectorisation part.
This PR includes #10167 and #10196 which I'll kill after they are merged.
It is also split into logical commits so that reviewing is simpler. I had to add a few optimisations to generate less code, as compilation was taking too long.
We left out an optimisation where we slice the indices to avoid loading the whole tensor from HBM when we know statically which rows we need to load (e.g. if we just slice current_cta along a dimension). This should help with compilation of num_ctas > 2 kernels, as those take quite a bit of time with the new model. This is not pressing, as those programs are not very performant without #9957.
Idea
A CTA is modelled as a set of independent logical threads, as if we had multiple
warp-specialised threads running in parallel.
Everything pretty much follows from that rule.
A multicast-layout barrier has one live barrier row, owned by the lead CTA.
Every CTA in the barrier group may arrive / expect on that row, but only the
lead CTA initializes, waits, and invalidates it.
This is the same model as several independent logical threads arriving on one
barrier while only one logical thread waits on it.
Shadow tables
The buffer and barrier tables are CTA-agnostic.
We add a CTA dimension to go with every buffer/barrier/thread/mask dimension:
`Cbar`, `Cbuf`, `Cthr`, and `Cmask` are CTA dimensions qualifying barriers, buffers, threads, and thread masks respectively. C in outstandingCommits folds
the CTA dim for B and P, as there is no cross-CTA work in wgmma.
Each CTA dimension is placed
immediately before the dimension it qualifies. This keeps the multi-CTA lift
regular: a pre-existing dimension at position
`pos` moves to `2 * pos`.

Aliasing happens per-CTA.
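A minimal sketch of the lift described above. It assumes 1-based positions, which is the reading under which `pos` maps to `2 * pos` when a CTA dimension is inserted immediately before every existing dimension. The helper `lift_dims` and the short dimension names are hypothetical and do not come from the PR.

```python
def lift_dims(dims):
    """Insert a CTA dimension immediately before each existing dimension.

    Hypothetical helper for illustration only; the names are not from the PR.
    """
    lifted = []
    for name in dims:
        lifted.append("C" + name)  # CTA dimension qualifying `name`
        lifted.append(name)        # the pre-existing dimension itself
    return lifted


dims = ["bar", "buf"]
lifted = lift_dims(dims)

# With 1-based positions (an assumption), every pre-existing dimension
# at position `pos` ends up at position `2 * pos` after the lift.
for pos, name in enumerate(dims, start=1):
    assert lifted.index(name) + 1 == 2 * pos
```

Under a 0-based convention the same interleaving would place the dimension at `2 * pos + 1`, so the indexing convention matters when reading the rule.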
CTA Issuers and Receivers and their Representation
We need two pieces of CTA information:
- A mask `m` such that the canonical issuer of a group satisfies
  `(cta_id & m) == 0`.
- A `cta_id`-dependent `uint16_t` CTA bitmask (i.e., a
  `Value`) computed with the same lowering helpers used by the real implementation.

Example:
`mask == 0x1` in a 4-CTA kernel means that:
- CTA0 accessed CTA0 and CTA1
- CTA2 accessed CTA2 and CTA3
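The example above can be reproduced with a small host-side model. This is only a sketch: it assumes a group consists of the CTAs that agree with the issuer on every bit outside the mask, which matches the 4-CTA example but is not confirmed as the general rule. The function names are hypothetical, and the real bitmask is a `Value` produced by the lowering helpers, not by Python code.

```python
def is_canonical_issuer(cta_id, m):
    """A CTA is the canonical issuer of its group when (cta_id & m) == 0."""
    return (cta_id & m) == 0


def accessed_ctas(issuer, m, num_ctas):
    """CTAs that match the issuer on every bit outside `m` (assumed grouping)."""
    return [c for c in range(num_ctas) if (c & ~m) == issuer]


def cta_bitmask(issuer, m, num_ctas):
    """Pack the accessed CTAs into a 16-bit mask, one bit per CTA."""
    bits = 0
    for c in accessed_ctas(issuer, m, num_ctas):
        bits |= 1 << c
    return bits & 0xFFFF


num_ctas, m = 4, 0x1
issuers = [c for c in range(num_ctas) if is_canonical_issuer(c, m)]
groups = {c: accessed_ctas(c, m, num_ctas) for c in issuers}
```

Here `groups` maps CTA0 to CTA0 and CTA1, and CTA2 to CTA2 and CTA3, as in the example.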
Cross-CTA memory effects
TMA and CLC multicast
Their mask is the multicast group (all the CTAs in the case of CLC).
MMA / TMEMCopy, 2CTA
The lead CTA must observe all the inputs and write to all outputs of both CTAs in the pair.
In other words, `mask == 0x1`.
The idea here is that even though CTA0 and CTA1 collaborate, the operation is launched from CTA0
and its synchronisation is emitted from CTA0, so we can model it as if CTA0 did all the work.
### Barrier semantics
TMA barriers and MMA completion barriers are dual.
For TMA, barrier receivers are obtained by applying the barrier-leader map to
the TMA data-receiver rows. The relevant lowering is
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp`.

For MMA completion, the multicast receiver mask is
computed using the same logic as
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp`.
Intuitively, an MMA completion waits for all CTAs that may write data consumed
by the CTA performing the commit.
CLC is the 1-CTA TMA multicast special case whose multicast group contains all
CTAs. The CLC layout makes this explicit through an all-zero
`cga_layout`.

For ordinary mbarrier ops, use the semantics in
`third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp`:
- `init`, `wait`, and `invalidate` are predicated on the barrier leader CTA.
- `expect` and `arrive` are executed by each CTA.
- Barrier rows of non-leader CTAs do not exist and should never be accessed. In particular, non-leader CTAs do not block on
  waits.
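The predication rules for ordinary mbarrier ops can be modelled as a small table lookup. This is a sketch under one loud assumption: the barrier leader is identified with the same `(cta_id & m) == 0` test used for canonical issuers, which the description does not state explicitly. The function `executes` is hypothetical and is not part of `BarrierOpToLLVM.cpp`.

```python
LEADER_ONLY = {"init", "wait", "invalidate"}  # predicated on the leader CTA
EVERY_CTA = {"expect", "arrive"}              # executed by each CTA


def executes(op, cta_id, m):
    """Return True if `cta_id` actually performs mbarrier op `op`.

    Assumption: the leader is the CTA with (cta_id & m) == 0.
    """
    is_leader = (cta_id & m) == 0
    if op in LEADER_ONLY:
        return is_leader
    if op in EVERY_CTA:
        return True
    raise ValueError(f"unknown mbarrier op: {op}")


# In a 4-CTA kernel with m == 0x1, only CTA0 and CTA2 wait;
# every CTA arrives on the single live barrier row of its leader.
waiters = [c for c in range(4) if executes("wait", c, 0x1)]
arrivers = [c for c in range(4) if executes("arrive", c, 0x1)]
```

This mirrors the model from the Idea section: several independent logical threads arrive on one barrier while only one logical thread waits on it.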
A non-relaxed cluster barrier publishes all generic-proxy inflight events to all
CTAs.
Testing
I ran this on
`test_consan.py` and all the multi-CTA tests / kernels that we have, to make sure they all pass.