Fix ZeroDivisionError in num_splits_heuristic for empty Q workloads #2515
Merged
Johnsonms merged 1 commit on May 13, 2026
num_splits_heuristic divides num_SMs by total_mblocks, which collapses to 0 when seqlen_q == 0 or batch_size == 0 (e.g. CUDA graph padding or empty microbatches). The existing seqlen_k == 0 early-exit in _flash_attn_fwd does not cover these cases.

- Extend the early-exit to also cover total_q == 0, using the same zero-output / -inf-LSE contract. total_q is batch_size * seqlen_q (dense) or q.shape[0] (varlen), so a single predicate handles both code paths.
- Add a defensive total_mblocks == 0 guard inside num_splits_heuristic itself so the function is safe in isolation.
- Add regression tests covering dense (batch=0, seqlen_q=0) and varlen (total_q=0) paths under both causal and non-causal masks.

Fixes Dao-AILab#2503.
@Johnsonms @drisspg @jayhshah Could you please review this PR?
Sure, I will repro and confirm this issue by the end of this week. Sorry, I am traveling.
Thanks @yunzhongOvO. @Johnsonms, can this be merged now?
Johnsonms approved these changes on May 13, 2026
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request on Jun 2, 2026
Summary
`num_splits_heuristic` in `flash_attn/cute/interface.py` divides `num_SMs` by `total_mblocks`, which collapses to 0 when `seqlen_q == 0` or `batch_size == 0`. The existing `seqlen_k == 0` early-exit in `_flash_attn_fwd` does not cover these cases, so empty-Q shapes (e.g. CUDA graph padding, empty microbatches) crash with `ZeroDivisionError`.

Fixes #2503.
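The failure mode described above can be reproduced without a GPU. The following is a minimal sketch, not the code in `interface.py`: the function name `num_splits_sketch`, its parameter list, the `guard` flag, and the choice to return 1 from the guard are all assumptions made for illustration. Only the division of the SM count by `total_mblocks`, the `num_n_blocks <= 4` early return, and the `total_mblocks == 0` guard come from the description.

```python
def num_splits_sketch(total_mblocks, num_sms, num_n_blocks, max_splits, guard=True):
    """Hypothetical stand-in for num_splits_heuristic; not the library's code."""
    # Early return described in the PR: with few K/V blocks, do not split.
    if num_n_blocks <= 4:
        return 1
    # Defensive guard described in the PR. Returning 1 here is an assumption.
    if guard and total_mblocks == 0:
        return 1
    # The division that fails when the Q side is empty (total_mblocks == 0).
    return min(num_sms // total_mblocks, max_splits, num_n_blocks)


def try_call(**kwargs):
    """Return the result, or the exception name if the division fails."""
    try:
        return num_splits_sketch(**kwargs)
    except ZeroDivisionError:
        return "ZeroDivisionError"


# Illustrative values only: 132 SMs, 32 K/V blocks, up to 128 splits.
unguarded = try_call(total_mblocks=0, num_sms=132, num_n_blocks=32,
                     max_splits=128, guard=False)
guarded = try_call(total_mblocks=0, num_sms=132, num_n_blocks=32,
                   max_splits=128, guard=True)
# With few K/V blocks the early return hides the problem even without the guard.
masked = try_call(total_mblocks=0, num_sms=132, num_n_blocks=4,
                  max_splits=128, guard=False)
```

The `masked` case illustrates why a regression test needs a long enough `seqlen_k`: if `num_n_blocks <= 4`, the function returns before reaching the division and the bug is never exercised.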
Changes
- `_flash_attn_fwd` early-exit (`interface.py:462`): extend the existing `seqlen_k == 0` guard to also cover `total_q == 0`, using the same zero-output / `-inf`-LSE contract. `total_q = batch_size * seqlen_q` (dense) or `q.shape[0]` (varlen), so a single predicate handles both paths.
- `num_splits_heuristic` defensive guard (`interface.py:254`): add a `total_mblocks == 0` short-circuit inside the heuristic itself so it is safe in isolation. Mirrors the existing `num_n_blocks <= 4` style.
- `tests/cute/test_flash_attn.py`:
  - `test_flash_attn_empty_q_dense`: parametrized over `(batch=2, seqlen_q=0)` and `(batch=0, seqlen_q=64)`, both causal and non-causal. Uses `seqlen_k=4096` so `num_n_blocks > 4` and the existing early-return in the heuristic does not mask the bug.
  - `test_flash_attn_empty_q_varlen`: `cu_seqlens_q` all zeros with `total_q == 0`.

Scope confirmed FA4-only: FA3 (`hopper/heuristics.h`) and FA2 (`csrc/flash_attn/flash_api.cpp`) use different heuristics that do not divide by `total_mblocks`.

Test plan

- `pytest tests/cute/test_flash_attn.py -k "empty_q" -x` passes on the fixed branch
- `ZeroDivisionError` on `main` without the `interface.py` changes (confirms the test actually exercises the bug path)
- `pytest tests/cute/test_flash_attn.py -k "seqlen_k_zero" -x` still passes (confirms the existing early-exit path is untouched)

I was not able to run these locally (no CUDA host); relying on CI for validation.
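The single-predicate early exit listed under Changes can be sketched on plain shape tuples. This is an illustration under stated assumptions, not the code in `_flash_attn_fwd`: the helper names `total_q_rows` and `takes_early_exit` are hypothetical, and the tensor layouts (dense Q as `(batch_size, seqlen_q, num_heads, head_dim)`, varlen Q as `(total_q, num_heads, head_dim)`) are assumed for the example.

```python
def total_q_rows(q_shape, varlen):
    """Number of query rows, following the PR's definition of total_q."""
    if varlen:
        # Assumed varlen layout: (total_q, num_heads, head_dim).
        return q_shape[0]
    # Assumed dense layout: (batch_size, seqlen_q, num_heads, head_dim).
    batch_size, seqlen_q = q_shape[0], q_shape[1]
    return batch_size * seqlen_q


def takes_early_exit(q_shape, seqlen_k, varlen=False):
    """True when the forward pass would skip the kernel launch entirely."""
    # Pre-existing condition (seqlen_k == 0) plus the new empty-Q condition.
    return seqlen_k == 0 or total_q_rows(q_shape, varlen) == 0


# The shapes named in the regression tests, plus two controls.
cases = {
    "dense seqlen_q=0": takes_early_exit((2, 0, 8, 128), seqlen_k=4096),
    "dense batch=0": takes_early_exit((0, 64, 8, 128), seqlen_k=4096),
    "varlen total_q=0": takes_early_exit((0, 8, 128), seqlen_k=4096, varlen=True),
    "non-empty": takes_early_exit((2, 64, 8, 128), seqlen_k=4096),
    "seqlen_k=0": takes_early_exit((2, 64, 8, 128), seqlen_k=0),
}
```

Because `batch_size * seqlen_q` is zero whenever either factor is zero, the dense `batch=0` and `seqlen_q=0` cases need no separate checks, which is what lets one predicate cover both code paths.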