Deterministic backward for blocksparse impl #2253

Merged
drisspg merged 1 commit into main from drisspg/stack/16
May 4, 2026

@drisspg drisspg commented Feb 12, 2026

Stacked PRs:


Deterministic backward for blocksparse impl

PyTorch PR: pytorch/pytorch#174813

Summary

This makes it possible to run bwd with blocksparse tensors deterministically.

At a very high level, the current implementation needs two things:

  1. For every M x N block, a way to figure out, for N_i, whether it is this column's turn to accumulate dQ into global memory. For the dense path it can infer this from the predefined sparsity, i.e. dense, causal, or sliding window.
  2. Some coordination with the tile scheduler to ensure that we throw work tiles into the queue in a way that will not block any given row from making progress.

This is a fun little problem and I made some animations: https://x.com/drisspg/status/2022438130893426731?s=20

That being said, I feel there are two approaches for block sparsity: (1) compute the semaphore lock value for every N tile at runtime based on the sparse data, or (2) precompute the lock values ahead of time. I went with approach 2 rather than approach 1: since sparsity is typically consistent across layers, we get to amortize the construction cost, and I was concerned that doing the scan across kv_blocks and indices at runtime would be too much overhead.

So what does this look like?

We introduced 3 new attributes to BlockSparseTensors (I wish this was 2, but we will come back to this):
New dq_write_order (partial | full) tensors, mirroring the q_block_idx(num_blocks) tensors.

They are used as such:

dq_write_order: int32[B, H, num_n_blocks, max_q_per_n]

Parallel to q_block_idx[B, H, num_n_blocks, max_q_per_n].
Each KV column worker iterates from 0 to max_q_per_n, so for the i-th M-block in the CTA's iteration list:

m_block       = q_block_idx[b, h, n_block, i]
my_lock_value = dq_write_order[b, h, n_block, i]

We can figure out via the total counts whether i corresponds to a partial or a full block, and that's basically it. See animation :)
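To make the layout concrete, here is a minimal sketch of how the write order could be precomputed for a single (batch, head) slice from a dense boolean block mask. It is not the PR's compute_dq_write_order: the function name build_write_order, the -1 padding value, and the list-of-lists representation are assumptions for illustration, and the partial/full split is ignored.

```python
def build_write_order(mask, spt=False):
    """mask[m][n] is True when Q block m attends to KV block n (one batch/head).

    Returns (q_block_idx, dq_write_order), both shaped [num_n_blocks][max_q_per_n].
    Padding slots hold -1 (an assumption made for this sketch).
    """
    num_m, num_n = len(mask), len(mask[0])

    # q_block_idx[n] lists the active M-blocks of KV column n, in iteration order.
    q_block_idx = [[m for m in range(num_m) if mask[m][n]] for n in range(num_n)]
    max_q_per_n = max((len(col) for col in q_block_idx), default=0)

    # For each dQ row m, rank its active columns. The column with rank 0 writes
    # first. Default is left to right; spt=True flips it to right to left.
    rank = {}
    for m in range(num_m):
        cols = [n for n in range(num_n) if mask[m][n]]
        if spt:
            cols.reverse()
        for turn, n in enumerate(cols):
            rank[(m, n)] = turn

    pad = -1
    dq_write_order = [
        [rank[(m, n)] for m in col] + [pad] * (max_q_per_n - len(col))
        for n, col in enumerate(q_block_idx)
    ]
    q_block_idx = [col + [pad] * (max_q_per_n - len(col)) for col in q_block_idx]
    return q_block_idx, dq_write_order


# 3x3 causal block mask: Q block m attends to KV block n when n <= m.
causal = [[n <= m for n in range(3)] for m in range(3)]
q_idx, order = build_write_order(causal)
_, order_spt = build_write_order(causal, spt=True)
```

With this toy mask, column 0 visits m_blocks [0, 1, 2] with lock values [0, 0, 0] in the default order and [0, 1, 2] with spt=True, so the i-th entry of dq_write_order lines up with the i-th entry of q_block_idx.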

Scheduling

I implemented this and perf was terrible, since by default we would use the non-SPT path. SPT is a fancy word for right to left (SPT = shortest processing time). Most causal masks look lower triangular, which means the amount of work for a column is highest on the left, since kv_idx = 0 is <= every q_idx and thus all m_blocks contribute, and smallest on the right, since the only block that has any actual dQ to compute is the last m_block in the Nth column's iteration.

         n_block (KV)
         0    1    2    3    4
       ┌────┬────┬────┬────┬────┐
   0   │ ██ │    │    │    │    │
       ├────┼────┼────┼────┼────┤
   1   │ ██ │ ██ │    │    │    │
m      ├────┼────┼────┼────┼────┤
(Q) 2  │ ██ │ ██ │ ██ │    │    │
       ├────┼────┼────┼────┼────┤
   3   │ ██ │ ██ │ ██ │ ██ │    │
       ├────┼────┼────┼────┼────┤
   4   │ ██ │ ██ │ ██ │ ██ │ ██ │
       └────┴────┴────┴────┴────┘
work:    5    4    3    2    1

Bad scheduling is terrible for perf: you turn a beautiful parallel GPU into a terrible serialized turtle.
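The work row in the diagram can be reproduced in a few lines. This is only an illustrative sketch: column_work is a hypothetical helper, and the mask is the same 5x5 causal block pattern drawn above.

```python
def column_work(mask):
    """Number of active M-blocks per KV column for a mask[m][n] block pattern."""
    num_m, num_n = len(mask), len(mask[0])
    return [sum(1 for m in range(num_m) if mask[m][n]) for n in range(num_n)]


causal = [[n <= m for n in range(5)] for m in range(5)]
work = column_work(causal)

# Columns ordered by ascending work. For this causal mask that is exactly
# right to left, which is why SPT and "right to left" coincide here.
shortest_first = sorted(range(len(work)), key=lambda n: work[n])
```

For an arbitrary block-sparse mask the two orderings need not coincide, so this only illustrates the causal case.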

So I also added 1 more flag, which is SPT. It says: instead of scheduling using the default left to right, flip the lock values so that they go from lowest to highest from right to left. This does require coordination with the scheduler. Why? Again, let's look at causal for simplicity.
Let's say num_sms is 2 and the total number of KV blocks is 5. With the default scheduling we would throw work into the queue starting with kv_block 0 and kv_block 1. For m_tile 0 we would be good. For m_tile 1 we would also be good, since the lock values there are (1, 0) and the lock=0 column is running. But for m_tile 2 we would have lock values (2, 1), and semaphore value 0 is at kv_block 2! So our workers will never be able to continue.

         n_block
         0    1    2    3    4
       ┌────┬────┬────┬────┬────┐
   0   │ 0  │    │    │    │    │  ✓ only CTA 0, lock=0
       ├────┼────┼────┼────┼────┤
   1   │ 1  │ 0  │    │    │    │  ✓ CTA 1 lock=0 goes first, then CTA 0
m      ├────┼────┼────┼────┼────┤
(Q) 2  │ 2  │ 1  │ 0  │    │    │  ✗ lock=0 is at n_block=2 — not running!
       ├────┼────┼────┼────┼────┤
   3   │ 3  │ 2  │ 1  │ 0  │    │
       ├────┼────┼────┼────┼────┤
   4   │ 4  │ 3  │ 2  │ 1  │ 0  │
       └────┴────┴────┴────┴────┘
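The deadlock in the diagram can be reproduced with a toy simulation. The sketch below assumes a simplified model that is not taken from the kernel: each SM holds one KV column until it finishes, a column walks its m_blocks in ascending order, and it can only write dQ row m when the row's semaphore equals its lock value. The names simulate, spt_lock and default_lock are hypothetical, and the two lock functions are only valid for a causal mask.

```python
from collections import deque


def simulate(mask, launch_order, lock, num_sms):
    """Return True if every column finishes, False if the resident workers deadlock."""
    num_m = len(mask)
    semaphore = [0] * num_m        # next lock value allowed to write dQ row m
    queue = deque(launch_order)    # KV columns in the order the scheduler submits them
    running = []                   # entries are [n_block, remaining m_blocks]
    while queue or running:
        # Fill idle SMs with the next columns from the queue.
        while queue and len(running) < num_sms:
            n = queue.popleft()
            running.append([n, deque(m for m in range(num_m) if mask[m][n])])
        progressed = False
        for worker in list(running):
            n, todo = worker
            # Advance this column as far as its lock values allow.
            while todo and semaphore[todo[0]] == lock(todo[0], n):
                semaphore[todo.popleft()] += 1
                progressed = True
            if not todo:
                running.remove(worker)
                progressed = True
        if not progressed:
            return False           # every resident column waits on one not yet launched
    return True


def spt_lock(m, n):
    """Right-to-left lock values for a causal mask (matches the diagram)."""
    return m - n


def default_lock(m, n):
    """Left-to-right lock values for a causal mask."""
    return n


causal = [[n <= m for n in range(5)] for m in range(5)]
left_to_right = simulate(causal, [0, 1, 2, 3, 4], spt_lock, num_sms=2)  # False: stuck at m=2
right_to_left = simulate(causal, [4, 3, 2, 1, 0], spt_lock, num_sms=2)  # True
```

In this model the flipped lock values only complete when the launch order is flipped as well, which is the coordination with the scheduler described above.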

So long story short, with SPT, for most causal-style block masks we are seeing a perf drop of around 20-30 from deterministic bwd (see attached data).

Better Scheduling

I was toying with ideas for more optimal scheduling. Since the distribution of tiles can be arbitrary for BlockSparse, I think we could compute a least-blocking schedule ahead of time and then use this schedule to submit kv_tiles. Again, this requires coordination with the scheduler and would probably add some slowdown, since get_next_tile_info will need a global read to figure out where to go. I think we would need to have 1 lock value per column in order to avoid any deadlocks. I thought of some good heuristics, but afaik this is NP-complete. Hard to say if all this AOT work would be worth the potentially faster bwd. Sounds like a future problem.

Although, let's say we did this: I do kinda want the spt attribute to really be a schedule attribute, which can be either non-SPT, SPT, or a permutation. Couldn't think of anything really elegant here. **Ideas from reviewers welcome.**

Fuzzing

Side note: I did a bunch of fuzzing for this PR. One really nice aspect (or not lol) is that the ratio of compile time to GPU work is very high, so both testing and fuzzing are very amenable to parallel workers.

image image

This fuzzing led to this PR: #2258

Performance:

With sliding window size = 2048, it eventually starts to get majorly blocked.
This is with SPT = True
image

And here is SPT = False
image

drisspg added a commit that referenced this pull request Feb 12, 2026
stack-info: PR: #2253, branch: drisspg/stack/16
block_size=(m_block_size, n_block_size),
subtile_factor=subtile_factor,
)
if deterministic:
Collaborator Author

This file is getting very complicated and hard to read from all the error checking. I want to refactor just so the main entrypoint flow is easier to grok.

Comment thread flash_attn/cute/flash_bwd_sm100.py Outdated
Comment thread flash_attn/cute/flash_bwd_sm100.py Outdated
return dense[:, :, :, :num_cols]


def compute_dq_write_order(
Collaborator Author

can move this to PT call after we land there after this 😂

spt: bool | None = None


def _ordered_to_dense_simple(
Collaborator Author

ditto here

drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 12, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as draft February 12, 2026 18:20
@drisspg drisspg marked this pull request as ready for review February 12, 2026 18:20
drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 12, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as draft February 12, 2026 18:43
@drisspg drisspg marked this pull request as ready for review February 12, 2026 18:43
@drisspg drisspg marked this pull request as draft February 13, 2026 04:03
drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 14, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 14, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as ready for review February 15, 2026 06:00
drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 16, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as draft February 16, 2026 00:44
@drisspg drisspg marked this pull request as ready for review February 16, 2026 00:44
drisspg added a commit to drisspg/flash-attention that referenced this pull request Feb 16, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as draft February 16, 2026 02:23
@drisspg drisspg marked this pull request as ready for review February 16, 2026 02:23
drisspg added a commit to drisspg/flash-attention that referenced this pull request Mar 2, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3c1a5c2df0

Comment thread flash_attn/cute/block_sparsity.py Outdated
Comment on lines +437 to +440
dq_write_order = _check_and_expand_metadata_tensor(
"dq_write_order",
tensors.dq_write_order,
expected_index_shape,

P1: Allow compact dq metadata tensors during normalization

normalize_block_sparse_tensors now accepts compact *_block_idx tensors via _check_and_expand_block (last-dim <= expected), but dq_write_order/dq_write_order_full are still expanded against the full expected_index_shape. In deterministic block-sparse backward, write-order tensors are naturally parallel to compact backward indices, so this path raises a ValueError before kernel launch even though the sparse indices themselves are accepted. Fresh evidence in this commit is the new deterministic metadata path plus compact-index support/tests, which makes this mismatch user-visible for deterministic backward.

@drisspg drisspg marked this pull request as draft April 27, 2026 21:37
@drisspg drisspg changed the base branch from drisspg/stack/36 to main April 27, 2026 21:37
drisspg added a commit that referenced this pull request Apr 27, 2026
stack-info: PR: #2253, branch: drisspg/stack/16
@drisspg drisspg changed the base branch from main to drisspg/stack/36 April 27, 2026 21:38
@drisspg drisspg marked this pull request as ready for review April 27, 2026 21:38
@drisspg drisspg marked this pull request as draft April 28, 2026 02:35
@drisspg drisspg changed the base branch from drisspg/stack/36 to main April 28, 2026 02:35
@drisspg drisspg changed the base branch from main to drisspg/stack/36 April 28, 2026 02:35
@drisspg drisspg marked this pull request as ready for review April 28, 2026 02:35
@drisspg drisspg marked this pull request as draft April 28, 2026 16:56
@drisspg drisspg changed the base branch from drisspg/stack/36 to main April 28, 2026 16:56
@drisspg drisspg marked this pull request as ready for review April 28, 2026 16:56
@drisspg drisspg marked this pull request as draft April 28, 2026 17:03
@drisspg drisspg marked this pull request as ready for review April 28, 2026 17:04
drisspg added a commit to drisspg/flash-attention that referenced this pull request Apr 28, 2026
stack-info: PR: Dao-AILab#2253, branch: drisspg/stack/16
@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 062e5647be

Comment thread flash_attn/cute/block_sparsity.py Outdated
Comment on lines +437 to +440
dq_write_order = _check_and_expand_metadata_tensor(
"dq_write_order",
tensors.dq_write_order,
expected_index_shape,

P1: Allow compact dq_write_order metadata during normalization

normalize_block_sparse_tensors accepts compact *_block_idx tensors (last dim <= expected) via _check_and_expand_block, but dq_write_order/dq_write_order_full are still validated against the full expected_index_shape. In deterministic block-sparse backward, these metadata tensors are parallel to compact q_block_idx/full_q_idx, so valid compact inputs now fail with a shape-expansion ValueError before kernel launch. This breaks deterministic backward for long-sequence compact masks, even though compact sparse indices are explicitly supported.
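As a rough illustration of the suggested behavior, the sketch below accepts a metadata tensor whose last dimension is at most the expected length and pads it, mirroring the "last dim <= expected" rule described for the index tensors. It uses plain nested lists rather than real tensors. expand_last_dim and the fill value of 0 are hypothetical, and whether the actual helper pads or only validates is not shown in this thread.

```python
def expand_last_dim(name, rows, expected_len, fill=0):
    """Pad each row of a compact metadata tensor (list of lists) to expected_len.

    Rows longer than expected_len are rejected. Shorter rows are padded with
    `fill`, an assumed placeholder value for unused slots.
    """
    expanded = []
    for row in rows:
        if len(row) > expected_len:
            raise ValueError(
                f"{name}: last dim {len(row)} exceeds expected {expected_len}"
            )
        expanded.append(list(row) + [fill] * (expected_len - len(row)))
    return expanded


# A compact write-order slice with at most 2 entries per column, expanded to 3.
compact = [[0, 1], [0]]
full = expand_last_dim("dq_write_order", compact, expected_len=3)
```

Applying the same rule to dq_write_order as to the compact q_block_idx it runs parallel to would keep the two tensors aligned entry by entry.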

stack-info: PR: #2253, branch: drisspg/stack/16
@drisspg drisspg marked this pull request as draft April 28, 2026 17:43
@drisspg drisspg marked this pull request as ready for review April 28, 2026 17:43
@drisspg drisspg marked this pull request as draft April 28, 2026 17:47
@drisspg drisspg marked this pull request as ready for review April 28, 2026 17:47