
fix: Fix NaN output in mxfp8_quantize for very small input values#2441

Merged
yzh119 merged 2 commits into flashinfer-ai:main from bkryu:mxfp8_quant_fix
Jan 30, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Jan 29, 2026

📌 Description

Problem

mxfp8_quantize() produces NaN values when input tensors contain very small denormalized values (~1e-35 to 1e-40).

Root Cause
In cvt_warp_fp16_to_mxfp8(), the output scale computation checks vecMax != 0.f instead of SFValue != 0.f. When inputs are tiny, vecMax is non-zero but SFValue underflows to zero after E8M0 conversion (minimum ~5.88e-39). This causes reciprocal(0) = infinity, and when float32 denormals are flushed to zero by FTZ mode, we get 0 × ∞ = NaN.

Fix
Change the condition from vecMax != 0.f to SFValue != 0.f, so that underflowed scale factors produce 0 outputs instead of NaN.
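The failure mode described above can be reproduced in plain Python with a toy model of the per-element scaling step (a sketch only, not the CUDA kernel: `reciprocal` mimics a hardware rcp that returns infinity for zero, and the function and variable names here are illustrative):

```python
import math

E8M0_MIN = 2.0 ** -127  # smallest E8M0 scale, ~5.88e-39

def reciprocal(x: float) -> float:
    # Hardware-style rcp: returns +inf for zero instead of raising.
    return math.inf if x == 0.0 else 1.0 / x

def scale_elem(vec_max: float, sf_value: float, guard_on_sf: bool) -> float:
    # Buggy path guards on vec_max (non-zero for tiny inputs);
    # the fixed path guards on sf_value (zero after E8M0 underflow).
    nonzero = (sf_value != 0.0) if guard_on_sf else (vec_max != 0.0)
    inv_scale = reciprocal(sf_value) if nonzero else 0.0
    elem = 0.0  # FTZ mode flushes the denormal float32 element to zero
    return elem * inv_scale

vec_max = 1e-40   # non-zero input, but far below E8M0_MIN
sf_value = 0.0    # scale factor underflowed during E8M0 conversion
print(scale_elem(vec_max, sf_value, guard_on_sf=False))  # nan (0 * inf)
print(scale_elem(vec_max, sf_value, guard_on_sf=True))   # 0.0
```

With the old guard, the non-zero `vec_max` lets `reciprocal(0)` produce infinity and the FTZ-flushed element yields `0 * inf = NaN`; guarding on the scale factor instead short-circuits to a zero output.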

Testing
Added 4 unit tests in tests/utils/test_fp8_quantize.py covering denormal inputs, all-zeros, mixed magnitudes, and the exact bug reproduction scenario.

On B200, all tests pass: 136 passed in 1.07s.

🔍 Related Issues

#2440

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Fixed invalid reciprocal calculation in MXFP8 quantization that could produce NaN results with certain input patterns.
  • Tests

    • Added comprehensive test coverage for extreme values, zero inputs, denormal values, and mixed magnitude data to ensure robust quantization behavior across edge cases and layouts.


@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical numerical stability issue within the mxfp8_quantize function, which previously generated NaN outputs for extremely small, denormalized input values. The core problem stemmed from an incorrect check during scale factor calculation, leading to undefined results. The solution involves a precise adjustment to this calculation, ensuring that underflow conditions are handled gracefully and produce expected zero outputs. This change is thoroughly validated by a new suite of unit tests designed to cover edge cases and prevent regressions.

Highlights

  • Bug Fix: NaN Output in mxfp8_quantize: Resolved an issue where the mxfp8_quantize function produced NaN values when processing very small denormalized input tensors (e.g., ~1e-35 to 1e-40).
  • Root Cause Addressed: The bug was traced to an incorrect condition (vecMax != 0.f instead of SFValue != 0.f) in cvt_warp_fp16_to_mxfp8. This led to scale factor underflow to zero, causing a 0 * infinity = NaN scenario when float32 denormals were flushed to zero by FTZ mode.
  • Corrected Scale Factor Calculation: The fix involves changing the output scale computation condition to SFValue != 0.f, ensuring that underflowed scale factors correctly result in zero outputs instead of NaNs.
  • Comprehensive Unit Testing: Four new unit tests have been added to tests/utils/test_fp8_quantize.py to cover various scenarios, including denormal inputs, all-zero inputs, mixed magnitude inputs, and the exact bug reproduction scenario, all passing on B200.


@coderabbitai
Contributor

coderabbitai bot commented Jan 29, 2026

📝 Walkthrough

A fix to the MXFP8 quantization kernel prevents division-by-zero errors by checking if the computed SFValue is zero before computing its reciprocal. Corresponding test cases validate the fix with denormal inputs, all-zeros inputs, mixed magnitudes, and isolated denormal values to ensure no NaN or Inf propagation.

Changes

  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh — Modified the reciprocal computation condition from vecMax != 0.f to SFValue != 0.f to directly guard against division by zero and prevent NaN propagation when the computed SFValue is zero.
  • tests/utils/test_fp8_quantize.py — Added four parameterized pytest functions (test_mxfp8_quantize_denormal_inputs, test_mxfp8_quantize_all_zeros, test_mxfp8_quantize_mixed_magnitude, test_mxfp8_quantize_single_denormal_in_block) covering edge cases with denormal values, zero inputs, and mixed magnitudes across dtype and layout variants.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested reviewers

  • yzh119
  • djmmoss
  • wenscarl
  • jimmyzho
  • nvmbreughe
  • yongwww

Poem

🐰 A tale of zero-guard with care,
Where MXFP8 quantized with flair,
SFValue checks before divide,
Denormals tested far and wide,
No NaNs now through the warren's flight! ✨

🚥 Pre-merge checks | ✅ 3 passed

  • Title check ✅ Passed — The title clearly and specifically describes the main fix: addressing NaN output in mxfp8_quantize for very small input values.
  • Description check ✅ Passed — The description comprehensively covers the problem statement, root cause analysis, the applied fix, and testing verification, following the template structure with all required sections completed.
  • Docstring Coverage ✅ Passed — Docstring coverage is 80.00%, which meets the required threshold of 80.00%.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a solid fix for an issue where mxfp8_quantize() could produce NaN values for very small denormalized inputs. The root cause was correctly identified as an underflow problem during scale factor computation, and the fix in quantization_utils.cuh is precise and well-commented. The accompanying unit tests are comprehensive, covering various edge cases like denormal inputs, all-zeros, and mixed magnitudes, which is excellent for ensuring the bug is resolved and preventing regressions. I have one suggestion to improve the new tests by refactoring duplicated code.

Comment on lines +169 to +171
major, _ = get_compute_capability(torch.device("cuda:0"))
if major < 10:
    pytest.skip("mxfp8 quantization is not supported on compute capability < 10")


medium

This compute capability check is duplicated across all four new tests and some existing ones. To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function or a pytest fixture.

For example, you could define a helper function:

def _require_sm100():
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

And then call _require_sm100() at the start of each relevant test. This would make the tests cleaner and avoid repeating the same logic.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)

173-177: Make denormal fixtures dtype‑aware so float16 cases don’t underflow to zero.

The constants 1e-38/1e-40 become 0 after casting to torch.float16, so those parameterizations won’t exercise the denormal path. Consider deriving values from torch.finfo(dtype).tiny so both float16 and bfloat16 cases contain non‑zero subnormals.

♻️ Suggested pattern
-    a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()
+    denorm = torch.finfo(dtype).tiny / 2  # ensures subnormal after cast
+    a = (torch.randn([m, k], dtype=torch.float32) * denorm).to(dtype).cuda().contiguous()

Apply the same idea to the mixed‑magnitude and single‑denormal fixtures.

Also applies to: 225-235, 264-273
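The underflow claim in this nitpick can be checked without torch by round-tripping through Python's IEEE half-precision struct format (a sketch; the float16 minimum-normal value, matching torch.finfo(torch.float16).tiny, is hardcoded here as an assumption):

```python
import struct

def to_float16(x: float) -> float:
    # Round-trip a Python float through IEEE binary16 ('e' struct format).
    return struct.unpack('e', struct.pack('e', x))[0]

F16_TINY = 2.0 ** -14  # min normal float16, ~6.10e-5

print(to_float16(1e-38))         # 0.0 -- the 1e-38 fixture constant vanishes in float16
print(to_float16(F16_TINY / 2))  # 3.0517578125e-05 -- a non-zero subnormal survives
```

So fixtures derived from the dtype's tiny value exercise the denormal path in both float16 and bfloat16, while the hardcoded 1e-38 only does so in bfloat16.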

# These values caused NaN in the original buggy implementation
a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()

a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)


⚠️ Potential issue | 🟡 Minor

Silence Ruff RUF059 by marking a_sf as intentionally unused.

Ruff flags these unpacked a_sf bindings as unused. Rename them to _a_sf (or _) to avoid lint failures.

🧹 Minimal fix (apply to all four occurrences)
-    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)
+    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

Also applies to: 201-201, 237-237, 276-276

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 179-179: Unpacked variable a_sf is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


@bkryu
Collaborator Author

bkryu commented Jan 29, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !275 has been created, and the CI pipeline #42860088 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu self-assigned this Jan 29, 2026

@yzh119 yzh119 left a comment


LG, should be ready after gitlab CI finished.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42860088: 11/20 passed

@yzh119 yzh119 merged commit 560a9de into flashinfer-ai:main Jan 30, 2026
20 of 26 checks passed
@bkryu bkryu deleted the mxfp8_quant_fix branch February 2, 2026 18:08
