Fix ZeroDivisionError in num_splits_heuristic for empty Q workloads #2515

Merged: Johnsonms merged 1 commit into Dao-AILab:main from shivam2199:fix/2503-zero-seqlen-q-batch on May 13, 2026
Conversation

@shivam2199 (Contributor) commented Apr 28, 2026

Summary

num_splits_heuristic in flash_attn/cute/interface.py divides num_SMs by total_mblocks, which collapses to 0 when seqlen_q == 0 or batch_size == 0. The existing seqlen_k == 0 early-exit in _flash_attn_fwd does not cover these cases, so empty-Q shapes (e.g. CUDA graph padding, empty microbatches) crash with ZeroDivisionError.
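The failure mode described above can be reproduced without a GPU. The following is a minimal sketch, not the code in interface.py: the function names, the tile size of 128, the head count, the SM count, and `max_splits` are all assumptions chosen for illustration. Only the `num_n_blocks <= 4` early return and the division of the SM count by `total_mblocks` come from the description.

```python
import math


def total_mblocks_sketch(batch_size, num_heads, seqlen_q, m_block_size=128):
    # Hypothetical tile count: one M-block per m_block_size query rows,
    # per head, per batch element. Any zero factor makes the product zero.
    return batch_size * num_heads * math.ceil(seqlen_q / m_block_size)


def unguarded_num_splits(total_mblocks, num_sms, num_n_blocks, max_splits=128):
    # Simplified stand-in for the pre-fix heuristic (assumed shape, not library code).
    if num_n_blocks <= 4:
        return 1  # small-K early return; it hides the bug for short seqlen_k
    return min(num_sms // total_mblocks, max_splits, num_n_blocks)


def crashes(batch_size, seqlen_q, num_n_blocks=32, num_heads=8, num_sms=132):
    # True when the unguarded heuristic raises ZeroDivisionError for this shape.
    blocks = total_mblocks_sketch(batch_size, num_heads, seqlen_q)
    try:
        unguarded_num_splits(blocks, num_sms, num_n_blocks)
    except ZeroDivisionError:
        return True
    return False
```

Under these assumptions, `crashes(2, 0)` and `crashes(0, 64)` both hit the division, while a non-empty shape or a small `num_n_blocks` does not.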

Fixes #2503.

Changes

  • _flash_attn_fwd early-exit (interface.py:462) — extend the existing seqlen_k == 0 guard to also cover total_q == 0, using the same zero-output / -inf-LSE contract. total_q = batch_size * seqlen_q (dense) or q.shape[0] (varlen), so a single predicate handles both paths.
  • num_splits_heuristic defensive guard (interface.py:254) — add a total_mblocks == 0 short-circuit inside the heuristic itself so it is safe in isolation. Mirrors the existing num_n_blocks <= 4 style.
  • Regression tests in tests/cute/test_flash_attn.py:
    • test_flash_attn_empty_q_dense — parametrized over (batch=2, seqlen_q=0) and (batch=0, seqlen_q=64), both causal and non-causal. Uses seqlen_k=4096 so num_n_blocks > 4 and the existing early-return in the heuristic does not mask the bug.
    • test_flash_attn_empty_q_varlen: cu_seqlens_q all zeros with total_q == 0.
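The two guards in the list above could look roughly like the sketch below. It is an illustration under stated assumptions rather than the merged diff: the helper names are hypothetical, the guard is assumed to return 1 split, and `max_splits=128` is a placeholder. The tensor layouts, (batch, seqlen_q, heads, dim) for dense and (total_q, heads, dim) for varlen, follow the usual flash-attention convention.

```python
def total_q_sketch(q_shape, varlen):
    # Dense q is (batch, seqlen_q, heads, dim); varlen q is (total_q, heads, dim).
    return q_shape[0] if varlen else q_shape[0] * q_shape[1]


def should_early_exit(q_shape, seqlen_k, varlen=False):
    # One predicate for both layouts: no keys to attend to, or no queries at all.
    return seqlen_k == 0 or total_q_sketch(q_shape, varlen) == 0


def guarded_num_splits(total_mblocks, num_sms, num_n_blocks, max_splits=128):
    # Defensive guard (assumed to return 1): no query tiles means nothing to split.
    if total_mblocks == 0:
        return 1
    if num_n_blocks <= 4:
        return 1
    return min(num_sms // total_mblocks, max_splits, num_n_blocks)
```

With the early exit in place the heuristic is never reached for empty Q, so the inner guard only matters for callers that invoke the heuristic directly.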

Scope confirmed FA4-only: FA3 (hopper/heuristics.h) and FA2 (csrc/flash_attn/flash_api.cpp) use different heuristics that do not divide by total_mblocks.

Test plan

  • pytest tests/cute/test_flash_attn.py -k "empty_q" -x passes on the fixed branch
  • Same test reverts to ZeroDivisionError on main without the interface.py changes (confirms the test actually exercises the bug path)
  • pytest tests/cute/test_flash_attn.py -k "seqlen_k_zero" -x still passes (confirms the existing early-exit path is untouched)
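The choice of shapes in the test plan can be checked with plain arithmetic. This is a sketch only: the K tile size of 128 is an assumption, and the names below are hypothetical rather than taken from tests/cute/test_flash_attn.py. It shows why seqlen_k=4096 is long enough to get past the `num_n_blocks <= 4` early return, and that every listed case has total_q == 0.

```python
import itertools
import math

SEQLEN_K = 4096
N_BLOCK_SIZE = 128  # hypothetical K tile size, an assumption for illustration

# (batch, seqlen_q) pairs named in the PR description, crossed with causal on/off.
EMPTY_Q_SHAPES = [(2, 0), (0, 64)]
CASES = [
    (batch, seqlen_q, causal)
    for (batch, seqlen_q), causal in itertools.product(EMPTY_Q_SHAPES, [False, True])
]


def num_n_blocks(seqlen_k, n_block_size=N_BLOCK_SIZE):
    # Number of K tiles needed to cover seqlen_k.
    return math.ceil(seqlen_k / n_block_size)


def reaches_division(seqlen_k):
    # The division is only reached when the small-K early return does not fire.
    return num_n_blocks(seqlen_k) > 4


def varlen_total_q(cu_seqlens_q):
    # In the varlen layout the last cumulative offset is the total query count.
    return cu_seqlens_q[-1]
```

Under this tile-size assumption, a shorter key length such as 512 yields exactly 4 blocks and would take the early return, so a test using it would pass even on the unfixed code.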

I was not able to run these locally (no CUDA host); relying on CI for validation.

num_splits_heuristic divides num_SMs by total_mblocks, which collapses
to 0 when seqlen_q == 0 or batch_size == 0 (e.g. CUDA graph padding or
empty microbatches). The existing seqlen_k == 0 early-exit in
_flash_attn_fwd does not cover these cases.

- Extend the early-exit to also cover total_q == 0, using the same
  zero-output / -inf-LSE contract. total_q is batch_size * seqlen_q
  (dense) or q.shape[0] (varlen), so a single predicate handles both
  code paths.
- Add a defensive total_mblocks == 0 guard inside num_splits_heuristic
  itself so the function is safe in isolation.
- Add regression tests covering dense (batch_size == 0 or seqlen_q == 0)
  and varlen (total_q == 0) paths under both causal and non-causal masks.

Fixes Dao-AILab#2503.

@shivam2199 (Contributor Author)

@tridao @jayhshah Can you take a look?

@shivam2199 (Contributor Author)

@Johnsonms @drisspg @jayhshah Can you please review the PR?

@Johnsonms (Collaborator)

Sure, I will repro and confirm this issue by the end of this week. Sorry, I am traveling.

@yunzhongOvO left a comment

LGTM~

@shivam2199 (Contributor Author)

Thanks @yunzhongOvO. @Johnsonms, can this be merged now?

@Johnsonms Johnsonms self-requested a review May 13, 2026 21:11
@Johnsonms Johnsonms merged commit 484b981 into Dao-AILab:main May 13, 2026
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
…ao-AILab#2515)
Development

Successfully merging this pull request may close these issues.

Bug: ZeroDivisionError in num_splits_heuristic when seqlen_q=0 or batch_size=0

3 participants