fix: use type-specific FP8 max value for clamping in RMSNorm quantization kernels #2612
Bias92 wants to merge 2 commits into flashinfer-ai:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request addresses a bug in the FP8 quantization kernels where a hardcoded E4M3 clamp value was applied universally, leading to significant data truncation for E5M2 output types. By introducing a type-aware trait, the correct maximum clamp value is now applied based on the specific FP8 format, improving numerical precision and correctness in RMSNorm and FusedAddRMSNorm quantization.
No actionable comments were generated in the recent review.
📝 Walkthrough: Introduces a templated FP8 clamp bound, `fp8_clamp_max`, used by the RMSNorm and FusedAddRMSNorm quantization kernels in place of the hardcoded E4M3 maximum.
Code Review
The pull request effectively addresses the hardcoded FP8 max value issue by introducing a type-aware `fp8_clamp_max<O>` trait. This change correctly handles both E4M3 and E5M2 FP8 types, preventing incorrect truncation of the representable range. The implementation is clean and directly resolves the identified problem.
🧹 Nitpick comments (1)
include/flashinfer/norm.cuh (1)
148-159: Dispatch coverage verified; code is correct. The `fp8_clamp_max` trait values (448.0f for E4M3, 57344.0f for E5M2) are correct. The incomplete primary template compiles safely because the `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8` macro exhaustively restricts the output type `O` to only `__nv_fp8_e4m3` and `__nv_fp8_e5m2` before kernel instantiation; any other type fails at runtime with an explicit `TVM_FFI_ICHECK` error, not a cryptic compile-time incomplete-type message. Optional improvements (not required):
- Diagnostic clarity: adding a dependent-false `static_assert` in the primary template would make unsupported types more explicit, though the practical risk is low given the dispatch guards. ♻️ Optional diff:

```diff
 template <typename T>
-struct fp8_clamp_max;
+struct fp8_clamp_max {
+  static_assert(sizeof(T) == 0,
+                "fp8_clamp_max: unsupported FP8 type; add a specialization for this type.");
+};
```
- `cuda::std::numeric_limits` alternative: CCCL issue #3349 tracks extending `cuda::std::numeric_limits` for FP8 types. If the CUDA toolkit in use supports it, the hardcoded constants can be replaced with `cuda::std::numeric_limits<O>::max()`. Verify toolkit compatibility before switching.

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `include/flashinfer/norm.cuh` around lines 148-159, add an explicit compile-time diagnostic for unsupported FP8 types by updating the primary template `fp8_clamp_max` to contain a dependent-false `static_assert` that triggers when instantiated with any type other than the specialized `__nv_fp8_e4m3` and `__nv_fp8_e5m2`; keep the existing specializations as-is and do not change the dispatch logic (`DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8`). This will make errors clearer if someone tries to instantiate `fp8_clamp_max` with an unsupported type.
Force-pushed from febb4f4 to 11088ee.
yzh119 left a comment:
LGTM, thanks for the fix.
/bot run
[FAILED] Pipeline #44589115: 13/20 passed
The 3 failing checks appear unrelated to this change: `remove-label` is a permissions issue for external contributors, and the JIT Unittest failures on T4/A10G were cancelled due to infrastructure timeouts before any tests ran.
Hi @jiahanc, @kahyunnam, @lwakuraRein, @nv-yunzheq, I hope you're all doing well! I wanted to send a gentle ping on this PR, as @yzh119 has kindly approved it. Regarding the failing CI checks, these appear to be unrelated to the actual change.
Summary
Replace the hardcoded FP8 E4M3 clamp value (448.0) with a type-aware `fp8_clamp_max<O>` trait in `RMSNormQuantKernel` and `FusedAddRMSNormQuantKernel`.

Problem
Both kernels hardcode the E4M3 max value (448.0f) for output clamping.
However, the output type `O` is dispatched via `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8` in `csrc/norm.cu`, which handles both E4M3 (max = 448) and E5M2 (max = 57344). When the output dtype is E5M2, this incorrectly truncates ~99% of the representable range.

Fix
Added a `fp8_clamp_max<T>` trait with the correct max values:

- `__nv_fp8_e4m3`: 448.0f
- `__nv_fp8_e5m2`: 57344.0f
RMSNormQuantKernelandFusedAddRMSNormQuantKernel.Summary by CodeRabbit