
feat(cute): implement softcap backward pass, correct math formula, and resolve JIT cache bug#2402

Merged
drisspg merged 4 commits into Dao-AILab:main from CaesarG:fix-fa4-softcap-scoremod
Apr 11, 2026

Conversation

@CaesarG
Contributor

@CaesarG CaesarG commented Mar 28, 2026

Description

This PR initially addressed the numerical deviation in FA4's softcap score modulation and aligned its output with the native FA2 and FA3 implementations. During the review and testing phase, it was expanded to fully implement the softcap backward pass and resolve underlying JIT compilation issues.

1. Forward Pass Corrections (Original Scope)

  1. Mathematical Bug: The create_softcap_scoremod function previously computed scores * tanh(scores). This PR corrects the formula to softcap_val * tanh(scores), properly bounding the logits to the [-softcap_val, softcap_val] range.
  2. Signature Mismatch: As noted in Fix missing seqlen_info param in softcap scoremod Fix missing seqlen_info param in softcap scoremod #2366, the seqlen_info argument was missing from the scoremod_premask_fn signature, causing a runtime error during @cute.jit execution. This parameter has now been added.
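
As a plain-math reference (not the CUTE kernel code), the corrected transform and the buggy one can be compared directly. The helper names below are hypothetical; note that FA kernels commonly fold the division by softcap_val into the softmax scale beforehand, which is why the score mod itself reduces to softcap_val * tanh(scores):

```python
import math

def softcap_ref(score: float, softcap_val: float) -> float:
    # Corrected behavior: softcap_val * tanh(score / softcap_val),
    # which smoothly bounds the logit to (-softcap_val, softcap_val).
    return softcap_val * math.tanh(score / softcap_val)

def softcap_buggy(score: float) -> float:
    # The old formula, scores * tanh(scores): unbounded for large |score|.
    return score * math.tanh(score)

print(softcap_ref(100.0, 50.0))   # ~48.2, stays inside the cap
print(softcap_buggy(100.0))       # ~100.0, escapes the cap
```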

Relates to / Fixes #2396

2. Backward Pass & Infrastructure Expansions (New)

  1. Full Backward Pass Implementation:
    • Implemented softcap backward pass support across flash_attn/cute/interface.py and the respective backend bwd files for various SM architectures (SM80/SM90/SM100).
  2. Robust Unit Test Coverage:
    • Updated softcap parametrization in tests/cute/test_flash_attn.py from a hardcoded 0.0 to [0.0, 15.0].
    • Replaced the previous logic that implicitly skipped backward tests when softcap != 0.0. Forward and backward numerical validations are now actively checking the softcapping logic.
  3. TVM JIT Compile Cache Bug Fix:
    • Resolved a TVM JIT tracing artifact that caused a Mismatched mdK_semaphore.strides[2] expected to be 1 error during the full test suite run.
    • Fix: Appended (seqlen_q_rounded // m_block_size == 1) and (seqlen_k_rounded // n_block_size == 1) boolean flags to the compile_key in _flash_attn_bwd to prevent cache poisoning between single-block and multi-block sequences.
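
The cache-key change can be sketched as follows. This is a minimal illustration: `make_compile_key` and the base-key contents are hypothetical, and only the two appended booleans mirror the actual fix.

```python
# Appending shape-derived booleans to the compile key keeps single-block and
# multi-block sequence lengths in separate cache entries, so a kernel traced
# with a degenerate (stride == 1) semaphore layout is never reused for
# multi-block shapes.
def make_compile_key(base_key, seqlen_q_rounded, seqlen_k_rounded,
                     m_block_size, n_block_size):
    return base_key + (
        seqlen_q_rounded // m_block_size == 1,
        seqlen_k_rounded // n_block_size == 1,
    )

key_small = make_compile_key(("bwd", "sm90"), 128, 128, 128, 128)
key_large = make_compile_key(("bwd", "sm90"), 1024, 1024, 128, 128)
assert key_small != key_large  # distinct cache entries, no poisoning
```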

Verification

I wrote a test script to directly compare the outputs of FA2, FA3, and FA4 (dtype=torch.bfloat16, SOFTCAP_VAL=50.0).

Click to expand the reproducible test script

(Paste your original python script here)

Results

Before this PR (Original FA4):

  • ✅ FA2 vs FA3 Passed: Max Absolute Error 0.003906
  • ❌ FA2 vs FA4 Failed: Max Absolute Error 1.367188
  • ❌ FA3 vs FA4 Failed: Max Absolute Error 1.367188

After this PR:

  • ✅ FA2 vs FA3 Passed: Max Absolute Error 0.003906
  • ✅ FA2 vs FA4 Passed: Max Absolute Error 0.003906
  • ✅ FA3 vs FA4 Passed: Max Absolute Error 0.003906

Pytest Validation: All updated Pytest suites (forward and backward with softcap=15.0) now pass successfully on local SM90 (Hopper) testing.
(Note: As I only have local access to an SM90 GPU, the SM80 and SM100 backward kernels might require additional verification via CI).

@tridao
Member

tridao commented Mar 28, 2026

cc @drisspg: It seems the softcap score mod wasn't right before?

@drisspg drisspg self-requested a review March 28, 2026 21:05
Collaborator

@drisspg drisspg left a comment


Yeah, this is for sure a bug and the new equation is correct

@drisspg
Collaborator

drisspg commented Mar 28, 2026

@CaesarG did you look at the unit tests? Do we have tests for softcapping, and if so, can we fix them so that they would have caught this?

@CaesarG CaesarG force-pushed the fix-fa4-softcap-scoremod branch from d029a8f to c1baeda on March 29, 2026 04:37
@CaesarG
Contributor Author

CaesarG commented Mar 29, 2026

Hi @drisspg, thanks for pointing that out! I just took a close look at the unit tests and you are absolutely right.

Here is what I found and what I am currently working on:

  1. Forward Tests: The softcap parameter in the test files was previously set to 0.0, which effectively bypassed the softcapping logic. I have changed it to parametrize over [0.0, 15.0] so it will catch these numerical deviations moving forward.
  2. Backward Implementation & Tests: I noticed that the backward pass tests explicitly skip execution if softcap != 0.0 because the backward support was missing/incomplete. Instead of just modifying the tests, I went ahead and updated flash_attn/cute/interface.py and the respective backward implementation files for different SM architectures to fully support softcap in the backward pass.
  3. Current Status & Hardware Limitation: I am currently running the updated full test suite locally to ensure both forward and backward passes are correct. Since it covers a massive amount of parameter combinations, it will take some time. Please note that I only have access to an SM90 GPU locally, so I can only fully verify the SM90 kernels. The SM80 and SM100 backward implementations might need to be validated by your CI or on different hardware.
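
For context on the backward support mentioned in point 2: with f(s) = c · tanh(s / c), the derivative is f′(s) = 1 − tanh²(s / c), so the backward pass just rescales the incoming gradient elementwise. A minimal pure-Python sketch (not the kernel code), checked against a finite difference:

```python
import math

def softcap_fwd(s: float, c: float) -> float:
    return c * math.tanh(s / c)

def softcap_bwd(s: float, c: float, dout: float) -> float:
    # d/ds [c * tanh(s / c)] = 1 - tanh(s / c)**2
    t = math.tanh(s / c)
    return dout * (1.0 - t * t)

# Finite-difference check of the analytic derivative
s, c, eps = 3.0, 15.0, 1e-5
fd = (softcap_fwd(s + eps, c) - softcap_fwd(s - eps, c)) / (2 * eps)
assert abs(fd - softcap_bwd(s, c, 1.0)) < 1e-6
```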

By the way, when running the full test suite, I noticed a sporadic TVM FFI stride mismatch error (Mismatched mdK_semaphore.strides[2] ... expected to be 1) on dK_semaphore. It appears to be a JIT tracing artifact: when seqlen_k_rounded // n_block_size == 1, TVM hardcodes a strict stride check, which then fails on subsequent tests with larger seqlen_k under the same compile_key. Running the tests individually bypasses it.

I will push the updated implementation and test cases to the PR as soon as my local SM90 run finishes.

@CaesarG CaesarG force-pushed the fix-fa4-softcap-scoremod branch from 7a7de9e to d34f0a8 on March 29, 2026 14:51
@CaesarG
Contributor Author

CaesarG commented Mar 29, 2026

Hi @drisspg, a quick update on the progress:

I've pushed the new commits containing the full backward pass implementation for softcap, the updated test coverage, and the TVM FFI cache fix we discussed.

Regarding the tests: While the isolated tests and the large seqlen subsets I ran yesterday passed perfectly, I'm actually seeing a relatively high failure rate when running the full test suite sequentially today.

I am currently investigating these failures to figure out the exact root cause. I'll keep you posted as soon as I have it fully resolved and a completely green sequential run.

@CaesarG CaesarG changed the title from "fix(cute): correct softcap math formula and add missing seqlen_info" to "feat(cute): implement softcap backward pass, correct math formula, and resolve JIT cache bug" Mar 29, 2026
CaesarG added 3 commits April 1, 2026 20:39
- Fixed the mathematical formula in `create_softcap_scoremod` to compute `softcap_val * tanh(scores)` instead of `scores * tanh(scores)`, properly bounding the logits.
- Added the missing `seqlen_info` parameter to `scoremod_premask_fn` signature to prevent runtime errors during @cute.jit execution.
- Verified that these changes perfectly align FA4 softcap outputs with native FA2 and FA3 implementations.
- **Backward Softcap Support**: Implemented full backward pass support for `softcap` across different SM architectures in `interface.py` and respective `bwd` backend files.
- **Unit Tests Coverage**: Updated `softcap` parametrization in `tests/cute/test_flash_attn.py` from a hardcoded `0.0` to `[0.0, 15.0]`. This ensures proper forward and backward numerical validation, replacing the previous logic that skipped backward tests when `softcap != 0.0`.
- **Compile Cache Fix**: Fixed a TVM JIT tracing artifact that caused `Mismatched mdK_semaphore.strides[2] expected to be 1` during full test suite execution. Appended `(seqlen_q_rounded // m_block_size == 1)` and `(seqlen_k_rounded // n_block_size == 1)` to `compile_key` in `_flash_attn_bwd` to prevent cache poisoning between single-block and multi-block sequence lengths.
- hash generated softcap score modifiers after softcap conversion
- simplify backward softcap/score_mod handling
- retry selected CUTE tests after clearing compile caches on CUDA OOM
@CaesarG CaesarG force-pushed the fix-fa4-softcap-scoremod branch from d34f0a8 to 473dabb on April 1, 2026 13:00
@CaesarG
Contributor Author

CaesarG commented Apr 1, 2026

Hi @drisspg,

The full CUTE test suite (tests/cute/test_flash_attn.py) is now 100% green! 🟢

Here is a quick summary of the latest fixes and the overall PR:

Latest Fixes:

  • Fwd JIT Cache Collision: Moved the score_mod_hash computation after the softcap conversion block. This fixes the cache poisoning between softcap=0.0 and softcap=15.0 kernels without changing the compile_key.
  • Test Suite OOMs: Added a @retry_on_oom decorator for the long sequential tests (such as varlen). It catches torch.OutOfMemoryError only, clearing compile caches and fragmented PyTorch memory before retrying.
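
The retry decorator might look roughly like this (a hypothetical sketch: the real decorator catches torch.OutOfMemoryError and its cleanup step clears the CUTE compile caches and calls torch.cuda.empty_cache(); here the exception type and cleanup hook are parameters so the shape is self-contained):

```python
import functools

def retry_on_oom(exc_type=MemoryError, retries=1, cleanup=None):
    """Rerun the wrapped test after an OOM, invoking a cleanup hook
    (cache clearing / memory defragmentation) between attempts."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(retries + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_type:
                    if attempt == retries:
                        raise  # out of retries: surface the OOM
                    if cleanup is not None:
                        cleanup()
        return wrapper
    return decorator
```

Catching only the OOM exception type (rather than a bare except) keeps genuine numerical or assertion failures visible instead of silently retrying them.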

Full PR Recap:

  1. Corrected softcap math formula and added missing seqlen_info parameter.
  2. Implemented full backward pass support for softcap.
  3. Fixed bwd JIT cache stride bug (mdK_semaphore.strides[2]) for multi-block sequences.
  4. Fixed fwd JIT cache collision.
  5. Fortified test stability against sequential OOMs.

Everything is numerically validated and completely stable. Ready for your review! 🚀

@tridao
Member

tridao commented Apr 3, 2026

Great, let's merge when it's ready

Review comment thread on flash_attn/cute/flash_bwd.py:
- Fwd Kernel Fix: Added `apply_score_mod` wrapper in the forward kernel.
- Interface Guards: Added explicit `NotImplementedError` early-exits in `interface.py` for custom `score_mod` on non-SM90+ architectures.
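
An early-exit guard of the kind described could look like this (names are illustrative; the actual check lives in flash_attn/cute/interface.py):

```python
def check_score_mod_support(score_mod, sm_arch: int) -> None:
    # Fail fast instead of generating an unsupported kernel: in this sketch,
    # custom score_mod callables require SM90 (Hopper) or newer.
    if score_mod is not None and sm_arch < 90:
        raise NotImplementedError(
            f"custom score_mod is not supported on SM{sm_arch}; requires SM90+"
        )

check_score_mod_support(None, 80)            # no score_mod: fine on SM80
check_score_mod_support(lambda s: s, 90)     # SM90: supported
```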
@CaesarG
Contributor Author

CaesarG commented Apr 11, 2026

Hi @tridao @drisspg,

Great, let's merge when it's ready

Yes, it's fully ready to go from my end! Feel free to merge whenever you're ready. Thanks!

@drisspg drisspg merged commit 14f3627 into Dao-AILab:main Apr 11, 2026
1 check passed


Successfully merging this pull request may close these issues.

[Bug] Maybe incorrect mathematical formula in create_softcap_scoremod (FA4 Score Modulation)
