Skip to content

Disable 2CTA fwd non-causal on CUDA 12 to work around codegen regression#2461

Merged
Johnsonms merged 2 commits into
mainfrom
Johnsonms/disable-2cta-cuda12.9
Apr 15, 2026
Merged

Disable 2CTA fwd non-causal on CUDA 12 to work around codegen regression#2461
Johnsonms merged 2 commits into
mainfrom
Johnsonms/disable-2cta-cuda12.9

Conversation

@Johnsonms
Copy link
Copy Markdown
Collaborator

@Johnsonms Johnsonms commented Apr 15, 2026

Summary

  • CUDA 12.9 has a codegen regression that causes ~18% slowdown for 2CTA forward non-causal
    (hdim=128: ~1280 vs ~1542 TFLOPS)
  • Auto-disable 2CTA when CUDA 12.9 is detected via torch.version.cuda
  • Users on CUDA 13.x are unaffected — 2CTA stays enabled and performs as expected
  • The manual FA_DISABLE_2CTA=1 env var continues to work regardless of CUDA version

Confirmed by @Johnsonms and @jshahOSS on B200 with CUDA 12.9.

Test plan

  • Verify on CUDA 12.9: non-causal hdim=128 should now use 1CTA (~1542 TFLOPS)
  • Verify on CUDA 13.x: non-causal hdim=128 should still use 2CTA (~1597 TFLOPS)
  • Verify FA_DISABLE_2CTA=1 still works as manual override
python benchmarks/benchmark_attn.py \                                            
  --fwd --backend fa4 \                                                                            
  --headdim 128 --batch-size 4 --seqlen 8192 \    
  --causal both    

…egression

CUDA 12.9 has a codegen issue that causes ~18% slowdown for 2CTA
forward non-causal (hdim=128: 1280 vs 1542 TFLOPS). This is fixed
in CUDA 13.x. Auto-disable 2CTA when CUDA 12.9 is detected.

Users on CUDA 13.x are unaffected. The manual `FA_DISABLE_2CTA=1`
override continues to work regardless of CUDA version.
Comment thread flash_attn/cute/utils.py Outdated
cuda_version = torch.version.cuda
if cuda_version is not None:
major, minor = cuda_version.split(".")[:2]
return int(major) == 12 and int(minor) == 9
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just check int(major) == 12

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will change. Thanks Tri

@Johnsonms Johnsonms changed the title Disable 2CTA fwd non-causal on CUDA 12.9 to work around codegen regression Disable 2CTA fwd non-causal on CUDA 12 to work around codegen regression Apr 15, 2026
@Johnsonms Johnsonms merged commit 83b8b8f into main Apr 15, 2026
ussoewwin pushed a commit to ussoewwin/flash-attention that referenced this pull request May 13, 2026
…ion (Dao-AILab#2461)

* Disable 2CTA forward non-causal on CUDA 12.9 to work around codegen regression

CUDA 12.9 has a codegen issue that causes ~18% slowdown for 2CTA
forward non-causal (hdim=128: 1280 vs 1542 TFLOPS). This is fixed
in CUDA 13.x. Auto-disable 2CTA when CUDA 12.9 is detected.

Users on CUDA 13.x are unaffected. The manual `FA_DISABLE_2CTA=1`
override continues to work regardless of CUDA version.

* Disable 2CTA forward non-causal on all CUDA 12.x (not just 12.9)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants