[Cute,Bwd,Sm100] fix incorrect calculation of n_block global max for bwd deterministic#2549
Merged
Conversation
drisspg
approved these changes
May 8, 2026
Collaborator
drisspg
left a comment
There was a problem hiding this comment.
Ahh good find, another reason why the tiling+cluster meaning can be hard to grok. I wonder if non colliding names would be helpful here
reubenconducts
pushed a commit
to reubenconducts/flash-attention
that referenced
this pull request
Jun 2, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR #2253 refactored the bwd dQ deterministic code by introducing a new
get_n_block_max_for_m_blockmethod, however this took in an incorrect calculation ofn_block_global_maxby using the cta rather than the cluster n tile (thus doubling the value), which sometimes breaks deterministic mode for 2cta backward. Note that BlockInfo's tile_n is the indeed the cluster tile_n by definition in the kernel.This PR fixes the bug by determining
n_block_global_maxinsideget_n_block_max_for_m_block, matching the existing logic inget_n_block_min_max.@drisspg