Deterministic backward for blocksparse impl#2253
Conversation
stack-info: PR: #2253, branch: drisspg/stack/16
08b2922 to
129aa17
Compare
| block_size=(m_block_size, n_block_size), | ||
| subtile_factor=subtile_factor, | ||
| ) | ||
| if deterministic: |
There was a problem hiding this comment.
this file is getting very complicated and hard to read from all the error checking.. I want to refactor jus so the main entrypoint flow is easier to grok
| return dense[:, :, :, :num_cols] | ||
|
|
||
|
|
||
| def compute_dq_write_order( |
There was a problem hiding this comment.
can move this to PT call after we land there after this 😂
| spt: bool | None = None | ||
|
|
||
|
|
||
| def _ordered_to_dense_simple( |
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
129aa17 to
53969cb
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
53969cb to
d55e707
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
d55e707 to
0e74cdd
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
0e74cdd to
d6f5b49
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
d6f5b49 to
25ad72e
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
25ad72e to
fb45f7a
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3c1a5c2df0
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| dq_write_order = _check_and_expand_metadata_tensor( | ||
| "dq_write_order", | ||
| tensors.dq_write_order, | ||
| expected_index_shape, |
There was a problem hiding this comment.
Allow compact dq metadata tensors during normalization
normalize_block_sparse_tensors now accepts compact *_block_idx tensors via _check_and_expand_block (last-dim <= expected), but dq_write_order/dq_write_order_full are still expanded against the full expected_index_shape. In deterministic block-sparse backward, write-order tensors are naturally parallel to compact backward indices, so this path raises a ValueError before kernel launch even though the sparse indices themselves are accepted. Fresh evidence in this commit is the new deterministic metadata path plus compact-index support/tests, which makes this mismatch user-visible for deterministic backward.
Useful? React with 👍 / 👎.
stack-info: PR: #2253, branch: drisspg/stack/16
3c1a5c2 to
f0e5370
Compare
f0e5370 to
062e564
Compare
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 062e5647be
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| dq_write_order = _check_and_expand_metadata_tensor( | ||
| "dq_write_order", | ||
| tensors.dq_write_order, | ||
| expected_index_shape, |
There was a problem hiding this comment.
Allow compact dq_write_order metadata during normalization
normalize_block_sparse_tensors accepts compact *_block_idx tensors (last dim <= expected) via _check_and_expand_block, but dq_write_order/dq_write_order_full are still validated against the full expected_index_shape. In deterministic block-sparse backward, these metadata tensors are parallel to compact q_block_idx/full_q_idx, so valid compact inputs now fail with a shape-expansion ValueError before kernel launch. This breaks deterministic backward for long-sequence compact masks, even though compact sparse indices are explicitly supported.
Useful? React with 👍 / 👎.
stack-info: PR: #2253, branch: drisspg/stack/16
062e564 to
5cd3501
Compare
Stacked PRs:
Deterministic backward for blocksparse impl
PyTorch PR: pytorch/pytorch#174813
Summary
This make it possible to run bwd with blocksparse tensors w/ determinism.
At a very high level, the current implementation needs two things
This is a fun little problem and I made some animations: https://x.com/drisspg/status/2022438130893426731?s=20
But that being said I feel as though there are two approaches for block sparsity, we could either compute the semaphore lock value for every N tile at runtime based off of the sparse data, Or we can precompute the lock values ahead of time. I did not go with approach 1 and went with approach 2 because in terms of total work, since sparsity is typically consistent across layers we get to amortize the construction and I was concerned that doing the scan across kv_blocks and indices would be too much overhead.
So what does this look like
We introduced 3 new attributes to
BlockSparseTensors( I wish this was 2 but we will come back to this):New dq_write_order(partial|full) mirroring the q_block_idx(num_blocks) tensors.
They are used as such:
Parallel to
q_block_idx[B, H, num_n_blocks, max_q_per_n].Each KV column worker iterates from 0 to max_q_per_n, so for the i-th M-block in the CTA's iteration list:
If we we can figure out via the total counts whether I corresponds to partial or full block and thats basically it. See animation :)
Scheduling
I implemented this and perf was terrible since by default we would use the
non-sptpath. SPT is fancy word for right to left.spt=shortest processing timefor most causal masks they look lower triangular which means the amount of work for a column is highest on the left, since q_idx = 0 which is >= to all kv_indices and thus all m_blocks will contribute - and smallest on the right since the only block that has any actual dQ to compute is the last block for Nth columns iteration.Bad scheduling is terrible for perf, you turn beautiful parallel GPU into terrible serialized turtle.
So I also add 1 more flag which is SPT. And that says in stead of scheduling using the default left to right flip the lock values so that go from lowest to highest from right left. This does require coordination with the scheduler. Why? Again lets look at causal for simplicity -
Lets say your num_sms was 2 and your total KV blocks was 5 if you did the default scheduling we would throw work into the queue starting with kv_block 0 and kv_block 1 for m_tile 0 we would be good for m_tile 1 we would also be good since there is their lock values are 0 and then 1,0 but for m_row 2 we would have lock values (2, 1) and semaphore value 0 is at kv_block 2! So our workers will never be able to continue
So long story short with SPT for most causally blocks we are seeing around the 20-30 perf drop from deterministic bwd -> see attached data.
Better Scheduling
I was toiling with ideas for more optimal scheduling. Since the distribution of tiles can be arbitrary for BlockSparse, I think if we could ahead of time compute a
least blocking scheduleand then we use this schedule to submit kv_tiles, again this requires coordination with the scheduler and would probably add some slow down since get_next_tile_info will need a global read to figure out where to go. I think we would need to have 1 lock value per column in order to avoid any deadlocks and I thought of some good heuristics but afaik this is NP complete. Hard to say if all this AOT work would be worth the potentially faster bwd. Sounds like a future problem.Although lets say we did this I do kinda want
sptattribute to really bescheduleattribute which can either be non spt, spt or a permutation. Couldnt think of anything really elegant here ** ideas from reviewer welcome**.Fuzzing
Side note I did a bunch of fuzzing for this PR, one really nice aspect (or not lol) is that the ratio of compile time to gpu work is very high so both testing and fuzzing is very amenable to parallel workers.
This fuzzing lead to this PR: #2258
Performance:
Sliding window size = 2048 and eventually starts to get majorly blocked

This is with SPT = True
And here is SPT = False
