[Cute,Bwd,Sm100] fix incorrect calculation of n_block global max for bwd deterministic #2549

Merged
jayhshah merged 1 commit into main from jshah/bwd-det-fix on May 8, 2026

@jayhshah (Collaborator) commented May 8, 2026

PR #2253 refactored the bwd dQ deterministic code by introducing a new get_n_block_max_for_m_block method. However, that method took in an incorrectly calculated n_block_global_max, computed from the CTA n tile rather than the cluster n tile (thus doubling the value), which sometimes breaks deterministic mode for 2-CTA backward. Note that BlockInfo's tile_n is indeed the cluster tile_n by definition in the kernel.

This PR fixes the bug by determining n_block_global_max inside get_n_block_max_for_m_block, matching the existing logic in get_n_block_min_max.
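
To illustrate the doubling described above, here is a minimal standalone sketch, not the kernel code. It assumes n_block_global_max is a ceiling division of the key sequence length by an n tile size, and that in 2-CTA mode the cluster n tile is twice the CTA n tile. The helper names, tile sizes, and sequence length are hypothetical and chosen only for illustration.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return (a + b - 1) // b


def n_block_global_max(seqlen_k: int, tile_n: int) -> int:
    """Number of n blocks needed to cover seqlen_k keys (hypothetical helper)."""
    return ceil_div(seqlen_k, tile_n)


# Assumed sizes for illustration only.
cta_tile_n = 128                  # n tile handled by a single CTA
ctas_per_cluster = 2              # 2-CTA backward
cluster_tile_n = cta_tile_n * ctas_per_cluster
seqlen_k = 1024

# Intended bound: count blocks in units of the cluster n tile.
correct = n_block_global_max(seqlen_k, cluster_tile_n)

# Bound as described for the bug: counting in units of the CTA n tile.
buggy = n_block_global_max(seqlen_k, cta_tile_n)

print(f"cluster tile bound: {correct}, CTA tile bound: {buggy}")
```

Under these assumptions the CTA-based bound is twice the cluster-based one, which matches the description of the bug. Computing the bound in one place from the cluster tile_n keeps it consistent with whatever get_n_block_min_max already uses.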

@drisspg

@jayhshah jayhshah requested a review from drisspg May 8, 2026 06:59
@drisspg (Collaborator) left a review comment:

Ahh, good find. This is another reason why the tiling+cluster meaning can be hard to grok. I wonder if non-colliding names would be helpful here.

@jayhshah jayhshah merged commit ab66326 into main May 8, 2026
2 of 3 checks passed
@jayhshah jayhshah deleted the jshah/bwd-det-fix branch May 8, 2026 15:36
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026