Skip to content

Ameyn/gdn bf16 tolerance parallel reduction#2610

Merged
yzh119 merged 2 commits intoflashinfer-ai:mainfrom
ameynaik-hub:ameyn/gdn_bf16_tolerance_parallel_reduction
Feb 23, 2026
Merged

Ameyn/gdn bf16 tolerance parallel reduction#2610
yzh119 merged 2 commits intoflashinfer-ai:mainfrom
ameynaik-hub:ameyn/gdn_bf16_tolerance_parallel_reduction

Conversation

@ameynaik-hub
Copy link
Copy Markdown
Contributor

@ameynaik-hub ameynaik-hub commented Feb 21, 2026

📌 Description

  1. fma2 not supported for hopper, fix for that for bf16 h state version of gdn decode.
  2. Increase atol_kv from 0.005 to 0.016 to accommodate 1 ULP differences in BF16
    that arise from parallel warp-level reductions vs sequential reference implementation.
    This fixes seed-specific test failures (e.g., seed=0 on Blackwell) without affecting
    kernel correctness. Validated across 160 test runs (5 seeds × 32 configs) with 100% pass rate.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Improved compatibility with SM90+ GPUs for BF16 (bfloat16) operations by adopting architecture-agnostic computation methods.
    • Enhanced numeric stability and accuracy in BF16 decoding operations through adjusted tolerance thresholds.

ameynaik-hub and others added 2 commits February 20, 2026 09:40
…duction precision

Increase atol_kv from 0.005 to 0.016 to accommodate 1 ULP differences in BF16
that arise from parallel warp-level reductions vs sequential reference implementation.
This fixes seed-specific test failures (e.g., seed=0 on Blackwell) without affecting
kernel correctness. Validated across 160 test runs (5 seeds × 32 configs) with 100% pass rate.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Replace cute.arch.fma_packed_f32x2() with scalar FP32 FMA operations.
The packed F32x2 intrinsics generate PTX instructions that are not
supported on SM90 (Hopper) architecture, causing compilation failures
with error: "F32x2 intrinsics are not supported on this architecture".

Changes:
- Add FMA wrapper functions (fma_pair, fma_pair_mul) using scalar ops
- Replace all 28 occurrences of cute.arch.fma_packed_f32x2()

Testing:
- All 44 unit tests pass (T=1,2,3,4 × BS=1-128)
- Correctness validated against BF16 state reference

Signed-off-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compatibility issue with NVIDIA Hopper GPUs by refactoring FMA operations within the BF16 GDN decode kernels. It introduces scalar FMA wrappers to ensure proper execution on SM90+ architectures, which do not support packed F32x2 intrinsics. Additionally, testing tolerances have been refined to account for the numerical characteristics of BF16 parallel reductions.

Highlights

  • Hopper (SM90+) Compatibility: Implemented custom FMA (Fused Multiply-Add) wrapper functions (fma_pair and fma_pair_mul) to replace cute.arch.fma_packed_f32x2, ensuring compatibility with NVIDIA Hopper (SM90+) architectures where packed F32x2 intrinsics are not supported.
  • BF16 GDN Decode Kernels: Updated the BF16 H state version of GDN decode kernels to utilize the new scalar FMA operations, resolving a known incompatibility issue with Hopper GPUs.
  • Testing Tolerance Adjustment: Increased the absolute tolerance (atol_kv) for BF16 H state comparison in tests to accommodate for potential numerical differences arising from parallel reductions.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
    • Added fma_pair_mul and fma_pair functions to provide SM90+ compatible scalar FMA operations.
    • Replaced all instances of cute.arch.fma_packed_f32x2 with the newly introduced fma_pair and fma_pair_mul functions.
    • Updated comments to clarify SM90+ compatibility and potential future optimizations for SM100+ architectures.
  • tests/gdn/test_decode_delta_rule.py
    • Adjusted atol_kv tolerance from 0.005 to 0.016 for BF16 h state comparison, accounting for 1 ULP (Unit in the Last Place) differences from parallel reductions.
Activity
  • The author has confirmed that pre-commit checks were installed and run successfully.
  • Tests have been added or updated as needed, and all tests are reported as passing.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Feb 21, 2026

📝 Walkthrough

Walkthrough

This pull request replaces architecture-specific FMA intrinsics with portable wrappers in the BF16 GDN decode kernel to improve SM90+ compatibility. A test tolerance threshold is adjusted to accommodate BF16 precision from parallel reductions. No public API changes.

Changes

Cohort / File(s) Summary
GDN Decode BF16 Kernel FMA Refactoring
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
Replaced fma_packed_f32x2 intrinsic calls with architecture-agnostic fma_pair and fma_pair_mul wrappers throughout multiple kernel paths (normalize_and_store_qk_to_smem, decay_h_from_smem_and_compute_pred, update_h_with_delta, compute_output, decay_h_in_place, and variants) to improve compatibility while maintaining functional behavior.
Test Tolerance Adjustment
tests/gdn/test_decode_delta_rule.py
Increased atol_kv tolerance from 0.005 to 0.016 in BF16 gdn_decode_klast test to account for 1 ULP (unit in last place) BF16 precision loss from parallel reductions.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐰✨ A kernel reborn, from packed intrinsics free,
Scalar pairs now dance where architectures agree!
SM90 smiles, broader skies await,
Portable magic—no hardware gate.
BF16 sings true with tolerance's gentle sway,
Hop forward, dear code, in a portable way! 🎉

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title mentions BF16 tolerance and parallel reduction, which directly relates to the main changes: adjusting test tolerances for BF16 operations and replacing packed FMA intrinsics with scalar operations for SM90 compatibility.
Description check ✅ Passed The description provides clear context for both main changes (fma2 support fix and tolerance adjustment), includes detailed validation metrics, and follows the provided template with pre-commit and test checklists completed.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
  • 📝 Generate docstrings (stacked PR)
  • 📝 Generate docstrings (commit on current branch)
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request correctly addresses the lack of support for packed FP32 FMA instructions on the Hopper (SM90) architecture by introducing scalar FMA wrapper functions. These wrappers (fma_pair and fma_pair_mul) replace cute.arch.fma_packed_f32x2 calls throughout the gdn_decode_bf16_state.py kernel, ensuring compatibility while maintaining numerical stability. Additionally, the test tolerance atol_kv has been increased to 0.016 to account for the precision limits of BF16 (approximately 1 ULP at magnitude 2.0) during parallel reductions. The changes are well-documented and improve the robustness of the kernel across different GPU architectures.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

138-145: fma_pair_mul name is misleading — it performs plain multiplication, not FMA.

The function computes a*b with no addend, making the fma prefix misleading. Consider renaming to mul_pair to better reflect the operation. The docstring note about equivalence to fma_packed_f32x2 with c=(0,0) is mathematically accurate (since fma(a,b,0)==a*b in IEEE 754), but the name still confuses intent.

♻️ Rename proposal
-def fma_pair_mul(a1, a2, b1, b2):
-    """Multiply two pairs: (a1, a2) * (b1, b2).
-
-    Equivalent to fma_packed_f32x2 with c=(0,0), but compatible with SM90+.
-    """
+def mul_pair(a1, a2, b1, b2):
+    """Multiply two pairs element-wise: returns (a1*b1, a2*b2).
+
+    Scalar replacement for fma_packed_f32x2 with c=(0,0), compatible with SM90+.
+    """
     result1 = a1 * b1
     result2 = a2 * b2
     return result1, result2

And update all 9 call sites from fma_pair_mul(...) to mul_pair(...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 138 - 145,
Rename the misleading function fma_pair_mul to mul_pair and update its docstring
to reflect that it performs element-wise multiplication (a1*b1, a2*b2) rather
than an FMA; modify the function definition name from fma_pair_mul to mul_pair
and update all 9 call sites that invoke fma_pair_mul(...) to mul_pair(...),
ensuring references (imports/exports, tests, and any uses in
gdn_decode_bf16_state.py and related modules) are updated to the new symbol.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 138-145: Rename the misleading function fma_pair_mul to mul_pair
and update its docstring to reflect that it performs element-wise multiplication
(a1*b1, a2*b2) rather than an FMA; modify the function definition name from
fma_pair_mul to mul_pair and update all 9 call sites that invoke
fma_pair_mul(...) to mul_pair(...), ensuring references (imports/exports, tests,
and any uses in gdn_decode_bf16_state.py and related modules) are updated to the
new symbol.

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@yzh119
Copy link
Copy Markdown
Collaborator

yzh119 commented Feb 22, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !337 has been created, and the CI pipeline #44542374 is currently running. I'll report back once the pipeline job completes.

@ameynaik-hub
Copy link
Copy Markdown
Contributor Author

how can I merge?

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #44542374: 14/20 passed

@yzh119 yzh119 merged commit 26ef055 into flashinfer-ai:main Feb 23, 2026
20 checks passed
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

1. fma2 not supported for hopper, fix for that for bf16 h state version
of gdn decode.
2. Increase atol_kv from 0.005 to 0.016 to accommodate 1 ULP differences
in BF16
that arise from parallel warp-level reductions vs sequential reference
implementation.
This fixes seed-specific test failures (e.g., seed=0 on Blackwell)
without affecting
kernel correctness. Validated across 160 test runs (5 seeds × 32
configs) with 100% pass rate.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved compatibility with SM90+ GPUs for BF16 (bfloat16) operations
by adopting architecture-agnostic computation methods.
* Enhanced numeric stability and accuracy in BF16 decoding operations
through adjusted tolerance thresholds.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Signed-off-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants