fix: Fix NaN output in mxfp8_quantize for very small input values #2441
yzh119 merged 2 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough
A fix to the MXFP8 quantization kernel prevents division-by-zero errors by checking whether the computed SFValue is zero before computing its reciprocal. Corresponding test cases validate the fix with denormal inputs, all-zeros inputs, mixed magnitudes, and isolated denormal values to ensure no NaN or Inf propagation.
Code Review
This pull request provides a solid fix for an issue where mxfp8_quantize() could produce NaN values for very small denormalized inputs. The root cause was correctly identified as an underflow problem during scale factor computation, and the fix in quantization_utils.cuh is precise and well-commented. The accompanying unit tests are comprehensive, covering various edge cases like denormal inputs, all-zeros, and mixed magnitudes, which is excellent for ensuring the bug is resolved and preventing regressions. I have one suggestion to improve the new tests by refactoring duplicated code.
major, _ = get_compute_capability(torch.device("cuda:0"))
if major < 10:
    pytest.skip("mxfp8 quantization is not supported on compute capability < 10")
This compute capability check is duplicated across all four new tests and some existing ones. To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function or a pytest fixture.
For example, you could define a helper function:

def _require_sm100():
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

And then call _require_sm100() at the start of each relevant test. This would make the tests cleaner and avoid repeating the same logic.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)
173-177: Make denormal fixtures dtype-aware so float16 cases don't underflow to zero.

The constants 1e-38 / 1e-40 become 0 after casting to torch.float16, so those parameterizations won't exercise the denormal path. Consider deriving values from torch.finfo(dtype).tiny so both float16 and bfloat16 cases contain non-zero subnormals.

♻️ Suggested pattern

- a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()
+ denorm = torch.finfo(dtype).tiny / 2  # ensures subnormal after cast
+ a = (torch.randn([m, k], dtype=torch.float32) * denorm).to(dtype).cuda().contiguous()

Apply the same idea to the mixed-magnitude and single-denormal fixtures.
Also applies to: 225-235, 264-273
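The reviewer's concern can be sanity-checked with plain float arithmetic, without torch (the constants below are the IEEE 754 binary16 limits that torch.float16 uses; torch.finfo(torch.float16).tiny equals 2⁻¹⁴):

```python
# IEEE 754 binary16 limits (what torch.float16 uses):
FP16_TINY = 2.0 ** -14           # smallest normal float16 (= torch.finfo(torch.float16).tiny)
FP16_SUBNORMAL_MIN = 2.0 ** -24  # smallest representable float16 subnormal, ~5.96e-8

# The hard-coded 1e-38 is far below the float16 subnormal range, so the
# cast in the fixture rounds it to exactly 0 and the denormal path is skipped.
assert 1e-38 < FP16_SUBNORMAL_MIN

# Deriving the fixture scale from the dtype's tiny keeps it representable:
denorm = FP16_TINY / 2  # a float16 subnormal value
assert FP16_SUBNORMAL_MIN <= denorm < FP16_TINY
```

For bfloat16 the same derivation works, since its tiny (~1.18e-38) is still above 1e-40 by many orders of magnitude in the subnormal range.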
# These values caused NaN in the original buggy implementation
a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()

a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)
Silence Ruff RUF059 by marking a_sf as intentionally unused.
Ruff flags these unpacked a_sf bindings as unused. Rename them to _a_sf (or _) to avoid lint failures.
🧹 Minimal fix (apply to all four occurrences)
- a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)
+ a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

Also applies to: 201-201, 237-237, 276-276
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 179-179: Unpacked variable a_sf is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
/bot run
yzh119 left a comment

LG, should be ready after GitLab CI finishes.
[FAILED] Pipeline #42860088: 11/20 passed |
📌 Description
Problem

mxfp8_quantize() produces NaN values when input tensors contain very small denormalized values (~1e-35 to 1e-40).

Root Cause

In cvt_warp_fp16_to_mxfp8(), the output scale computation checks vecMax != 0.f instead of SFValue != 0.f. When inputs are tiny, vecMax is non-zero but SFValue underflows to zero after E8M0 conversion (minimum ~5.88e-39). This causes reciprocal(0) = infinity, and when float32 denormals are flushed to zero by FTZ mode, we get 0 × ∞ = NaN.

Fix

Change the condition from vecMax != 0.f to SFValue != 0.f, so that underflowed scale factors produce 0 outputs instead of NaN.
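The failure mode and the fix can be sketched in plain Python. This is a simplified model of the kernel's scale/multiply step, not the real CUDA code: quantize_elem, the FTZ threshold, and the hard-coded sf_value of 0 are illustrative stand-ins, and CUDA's reciprocal(0), which returns +inf rather than trapping, is modeled explicitly.

```python
import math

# E8M0 scale factors are powers of two in [2^-127, 2^127]; the smallest is ~5.88e-39.
E8M0_MIN = 2.0 ** -127
assert abs(E8M0_MIN - 5.88e-39) / 5.88e-39 < 0.01

def quantize_elem(x_f32, vec_max, sf_value, fixed):
    """Model the scale/multiply step of the kernel (illustrative only)."""
    # The guard condition is the one-line difference between buggy and fixed code.
    guard = (sf_value != 0.0) if fixed else (vec_max != 0.0)
    # CUDA's reciprocal(0.f) yields +inf; model that instead of dividing in Python.
    recip = (1.0 / sf_value if sf_value != 0.0 else math.inf) if guard else 0.0
    # FTZ mode flushes float32 denormals (below ~1.18e-38) to zero before the multiply.
    x_ftz = 0.0 if abs(x_f32) < 1.18e-38 else x_f32
    return x_ftz * recip

# Tiny input: vec_max is non-zero, but the E8M0-converted scale underflowed to 0.
x, vec_max, sf = 1e-38, 1e-38, 0.0
assert math.isnan(quantize_elem(x, vec_max, sf, fixed=False))  # buggy: 0 * inf = NaN
assert quantize_elem(x, vec_max, sf, fixed=True) == 0.0        # fixed: underflow -> 0
```

With the guard on sf_value, an underflowed scale short-circuits to a zero output, which is the intended behavior for values too small to represent.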
Testing

Added 4 unit tests in tests/utils/test_fp8_quantize.py covering denormal inputs, all-zeros, mixed magnitudes, and the exact bug reproduction scenario.

On B200, they pass as 136 passed in 1.07s.

🔍 Related Issues
#2440
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
Tests have been added or updated as needed.
All tests are passing (unittest, etc.).

Reviewer Notes