[CONSAN] Add read before any write check#10167
Conversation
We add a check that a buffer had been written before its first read. This fixes a multicast repro that had been flaky for sometime. Note that this is analogous to the "see the previous write" race conditions we might find if mbarriers are not correct only that this one catches when races when there was no previous write.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2e9fbdb188
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| (uint64_t)memType}, | ||
| buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/false)); | ||
| } | ||
| createCallToCachedFunction( |
There was a problem hiding this comment.
why do we have to run both "verify_write_initialized" and "verify_write_visibility" when !allowNoWrite? Won't verify_write_initialized already check the visibility?
There was a problem hiding this comment.
this is a minor thing, and it's because these functions just accept one assert, so if we want to have its own nice assert, we have to run both. Should I make it so that we can have several asserts per function you reckon?
There was a problem hiding this comment.
One of the problems is that I think "verify_write_initialized" checks also visibility, so in case of buffer not being visible it will emit the "uninitialized" assert, which will be confusing. I think we either need separate function to check for no-write, or multiple possible asserts, but then we also need the inner function to be able to return richer information than just bool
There was a problem hiding this comment.
I think it's fine, unless I misunderstood you.
We have:
rowInitialized = !noOneIsWriting
writeVisible = noOneIsWriting || bufferHasVisibility
This is already tested in
python/test/gluon/test_consan.py::test_async_tma_kernel[1ctas-True]
which triggers the visiblity point but not the initialised one.
There was a problem hiding this comment.
Makes sense, re-reading the code I see I was wrong. Thanks!
### Reviewers This PR includes #10167 and #10196 which I'll kill after they are merged. It is also separated in logical commits so that reviewing is simpler, as I had to add a few optimisations to generate less code as it was taking too long. We left out an optimisation where we sliced the indices to avoid loading the whole tensor from HBM when we know statically which rows we need to load (e.g. if we just slice current_cta along a dimension). This should help with compilation of num_ctas > 2 kernels as those take quite a bit of time with the new model. This is not pressing as those programs are not very performant without #9957 ### Idea A CTA is modelled as a set of independent logical threads, as if we had multiple warp-specialised threads running in parallel. Everything pretty much follows from that rule. A multicast-layout barrier has one live barrier row, owned by the lead CTA. Every CTA in the barrier group may arrive / expect on that row, but only the lead CTA initializes, waits, and invalidates it. This is the same model as several independent logical threads arriving on one barrier while only one logical thread waits on it. ### Shadow tables The buffer and barrier tables are CTA-agnostic We add a CTA dimension to go with every buffer/barrier/thread/mask dimension: ```text buffers | tensor | <B x i64> barriers | tensor | <K x i64> barrierStates | scratch | <Cbar x K x i64> waiting | scratch | <Cbar x K x Cthr x i32> writeVisibility | scratch | <Cbuf x B x Cmask x i64> readVisibility | scratch | <Cbuf x B x Cthr x T x Cmask x i64> writeTracking | scratch | <Cbuf x B x Cbar x K x i8> readTracking | scratch | <Cbuf x B x Cbar x K x Cmask x i64> outstandingCommits | scratch | <C x B x P x i8> ``` `Cbar`, `Cbuf`, `Cthr`, and `Cmask` are CTA dimensions qualifying barriers, buffers, threads, and thread masks respectively. C in outsatndingCommits folds the CTA dim for B and P, as there is no cross-CTA work in wgmma. Each CTA dimension is placed immediately before the dimension it qualifies. This keeps the multiCTA lift regular: a pre-existing dimension at position `pos` moves to `2 * pos`. Aliasing happens per-CTA ### CTA Issuers and Receivers and their Representation We need two pieces of CTA information: 1. Issuer predication: a 4-bit mask `m` such that the canonical issuer of a group satisfies `(cta_id & m) == 0`. 2. Receiver sets: a runtime `cta_id`-dependent `uint16_t` CTA bitmask (i.e., a `Value`) computed with the same lowering helpers used by the real implementation. Example: `mask == 0x1` in a 4-CTA kernel it means that: CTA0 accessed CTA0 and CTA1 CTA2 accessed CTA2 and CTA3 ### Cross-CTA memory effects TMA and CLC multicast Their mask is the multicas group (all the CTAs in the case of CLC) MMA / TMEMCopy, 2CTA The lead CTA must observe all the inputs and write to all outputs of both CTAs in the pair. In other words, the mask == 0x1 The idea here is that even though CTA0 and CTA1 collaborate, since it is launched from CTA0 and its synchronisation is emitted from CTA0, we can model it as if CTA0 did all the work. ### Barrier semantics TMA barriers and MMA completion barriers are dual. ```text Operation Issuer Receiver TMA, 1CTA, no multicast cta_id cta_id TMA, 2CTA, no multicast cta_id cta_id & ~1 TMA, 1CTA, multicast multicast-group leader multicast-group TMA, 2CTA, multicast multicast-group leader even cta_ids in the multicast-group MMA, 1CTA, no multicast cta_id 0x0 MMA, 2CTA, no multicast even CTA 0x1 MMA, multicast see lowering broadcastBits_d ``` For TMA, barrier receivers are obtained by applying the barrier-leader map to the TMA data-receiver rows. The relevant lowering is `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp`. For MMA completion, the multicast receiver mask is ```text broadcastBits_d = getBlockBroadcastMask(d) | (twoCTAs ? 0x1 : 0x0) ``` using the same logic as `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp`. Intuitively, an MMA completion waits for all CTAs that may write data consumed by the CTA performing the commit. CLC is the 1-CTA TMA multicast special case whose multicast group contains all CTAs. The CLC layout makes this explicit through an all-zero `cga_layout`. For ordinary mbarrier ops, use the semantics in `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp`: - `init`, `wait`, and `invalidate` are predicated on the barrier leader CTA. - `expect`, `arrive` are executed by each CTA. - All operations target the leader barrier address. - Therefore only leader barriers are live; non-leader barriers is as if they didn't exist and should never be accessed. In particular, non leader CTAs do not block on waits. A non-relaxed cluster barrier publishes all generic-proxy inflight events to all CTAs. ### Testing I run this on `test_consan.py` and all the multicta tests / kernels that we have, to make sure they all pass.
We add a check that a buffer had been written before its first read.
This fixes a multicast repro that had been flaky for sometime.
Note that this is analogous to the "see the previous write" race
conditions we might find if mbarriers are not correct only that this one
catches when races when there was no previous write.
We update consan tests that were not initialising their shmem before.