fix: use type-specific FP8 max value for clamping in RMSNorm and RoPE quantization kernels #2639
Bias92 wants to merge 2 commits into flashinfer-ai:main from
Conversation
… quantization kernels

The hardcoded 448.0f clamp range is only correct for FP8 e4m3fn (max=448). When the output type is e5m2 (max=57344), this silently clips ~99.2% of the representable range, destroying information.

This commit:
- Introduces fp8_clamp_max<T> trait in a shared header (fp8_types.cuh)
- Fixes norm.cuh: RMSNormQuantKernel and FusedAddRMSNormQuantKernel
- Fixes pos_enc.cuh: RopeQuantizeKernel (q_rope, k_rope, k_nope paths) and RopeQuantizeAppendPagedKVCacheKernel (all 5 paths including V cache)
- Adds clamping before cast_store in all FP8 quantization paths that previously had no overflow protection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary of Changes

Hello @Bias92, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses critical issues in FP8 quantization by ensuring that values are correctly clamped according to their specific FP8 type. It introduces a mechanism to retrieve type-aware maximum values for clamping, thereby preventing data loss from incorrect clipping in RMSNorm kernels and mitigating the risk of NaN/Inf propagation and KV cache corruption in RoPE quantization kernels. The changes enhance the robustness and accuracy of FP8 operations across the codebase.
Code Review
This pull request correctly addresses issues with FP8 value clamping in RMSNorm and RoPE quantization kernels. It introduces a new fp8_types.cuh header for type-specific maximum values, which is a good approach. My review focuses on improving code readability and maintainability by reducing code duplication. I've suggested creating a helper function for the clamping logic, which is repeated in many places.
```cuda
template <>
struct fp8_clamp_max<__nv_fp8_e5m2> {
  static constexpr float value = 57344.0f;
};
```
To improve code readability and reduce repetition across norm.cuh and pos_enc.cuh, consider adding a helper function for clamping values before FP8 conversion. This pattern is used in multiple places.
You could add the following function to this file, which would make the call sites much cleaner (e.g., output_vec[j] = clamp<O>(output_vec[j]);).
```cuda
template <typename T>
__device__ __forceinline__ float clamp(float val) {
  constexpr float max_val = fp8_clamp_max<T>::value;
  return fmaxf(-max_val, fminf(val, max_val));
}
```
include/flashinfer/norm.cuh (outdated)

```diff
   output_vec[j] =
       float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+  output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
```
include/flashinfer/norm.cuh (outdated)

```diff
   for (uint32_t j = 0; j < VEC_SIZE; j++) {
     output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-    output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+    output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
```
include/flashinfer/pos_enc.cuh (outdated)

```cuda
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
  vec[i] = vec[i] * scale;
  vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
```

include/flashinfer/pos_enc.cuh (outdated): the same clamp pattern is added in each remaining quantization path (q_rope in both kernels, k_rope, k_nope, v, and q_nope vectors):

```cuda
q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
                      fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value));
k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
                      fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value));
k_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
                      fminf(k_nope_vec[i], fp8_clamp_max<QuantType>::value));
v_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
                 fminf(v_vec[i], fp8_clamp_max<QuantType>::value));
q_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
                      fminf(q_nope_vec[i], fp8_clamp_max<QuantType>::value));
```
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/norm.cuh`:
- Around line 218-219: Run clang-format on the header to fix
whitespace/formatting issues in the edited regions: reformat the block
containing the output_vec assignment and usages of fp8_clamp_max (the line with
output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j],
fp8_clamp_max<O>::value)); and the nearby lines around the other flagged region
(lines referencing output_vec and fp8_clamp_max at the second location). Ensure
pre-commit passes by applying the project's clang-format style (or running the
repo's format script) and commit the normalized file.
In `@include/flashinfer/pos_enc.cuh`:
- Line 25: The file fails clang-format checks; run the repository's clang-format
(or pre-commit) on pos_enc.cuh to fix whitespace, include ordering and style
issues—specifically reformat the include line referencing "fp8_types.cuh" and
the surrounding header contents so the file passes pre-commit clang-format
checks before merging.
- Line 277: The fp8 clamping call using fp8_clamp_max<QuantType> must be guarded
so it only instantiates for FP8 quant types; modify the places that call
fp8_clamp_max (e.g., inside RopeQuantize/RopeQuantizeKernel invoked by
BatchQKApplyRotaryPosIdsCosSinCache and the vec[i] clamp expression) to perform
the clamp only when QuantType is an FP8 type (use an existing is_fp8 trait or
add a constexpr helper like is_fp8_v<QuantType> and wrap the clamp in if
constexpr (is_fp8_v<QuantType>) { vec[i] = fmaxf(...
fp8_clamp_max<QuantType>::value ...); } else { /* no clamp */ } so the base
template assertion is never instantiated for non-FP8 types).
```diff
 #include <string>
 #include <type_traits>

+#include "fp8_types.cuh"
```
clang-format cleanup is still required for this file.
Pre-commit reports formatting diffs; please format this header before merging.
include/flashinfer/pos_enc.cuh (outdated)

```cuda
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
  vec[i] = vec[i] * scale;
  vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
```
Guard FP8 clamping against non-FP8 QuantType instantiations.
The generic template fp8_clamp_max<QuantType> defined in fp8_types.cuh only specializes __nv_fp8_e4m3 and __nv_fp8_e5m2; the base template asserts on any other type. At line 1391, BatchQKApplyRotaryPosIdsCosSinCache<DType, IdType> calls RopeQuantize<DType, IdType, DType> where DType is dispatched via DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 (only __half and __nv_bfloat16). This causes RopeQuantizeKernel to instantiate fp8_clamp_max<__half> or fp8_clamp_max<__nv_bfloat16> when applying clamping at lines 277, 506–507, 533–534, 911–912, 942–943, 974–975, 1006–1007, and 1037–1038, triggering a compile-time assertion failure.
Proposed fix pattern:

```diff
+template <typename QuantType>
+__device__ __forceinline__ float clamp_if_fp8(float x) {
+  if constexpr (std::is_same<QuantType, __nv_fp8_e4m3>::value ||
+                std::is_same<QuantType, __nv_fp8_e5m2>::value) {
+    return fmaxf(-fp8_clamp_max<QuantType>::value, fminf(x, fp8_clamp_max<QuantType>::value));
+  } else {
+    return x;
+  }
+}
 ...
-  vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
+  vec[i] = clamp_if_fp8<QuantType>(vec[i]);
```
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
♻️ Duplicate comments (1)

include/flashinfer/pos_enc.cuh (1)

Lines 277-279: ⚠️ Potential issue | 🔴 Critical

Guard FP8 clamping so non-FP8 QuantType paths do not instantiate FP8-only traits. At line 1392 this file instantiates RopeQuantize<DType, IdType, DType>, while the new clamp sites unconditionally use fp8_clamp_max<QuantType>. This can break non-FP8 builds (or introduce unintended clipping) unless clamping is explicitly FP8-gated.

Proposed fix pattern:

```diff
+template <typename QuantType>
+__device__ __forceinline__ float clamp_if_fp8(float x) {
+  if constexpr (std::is_same_v<QuantType, __nv_fp8_e4m3> ||
+                std::is_same_v<QuantType, __nv_fp8_e5m2>) {
+    return fmaxf(-fp8_clamp_max<QuantType>::value, fminf(x, fp8_clamp_max<QuantType>::value));
+  } else {
+    return x;
+  }
+}
 ...
-  vec[i] =
-      fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
+  vec[i] = clamp_if_fp8<QuantType>(vec[i]);
```

Also applies to: 507-508, 534-535, 912-913, 943-944, 975-976, 1007-1008, 1038-1039, 1392-1398
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/pos_enc.cuh` around lines 277 - 279, The FP8 clamp call uses fp8_clamp_max<QuantType> unconditionally and will instantiate FP8-only traits for non-FP8 QuantType; fix by gating the clamp with a compile-time FP8 trait check (e.g., use if constexpr (is_fp8<QuantType>::value) or equivalent project trait) so only FP8 instantiations call fp8_clamp_max<QuantType>::value; update the sites that set vec[i] (the fmaxf/fminf clamp) to perform the clamp only inside that if constexpr and leave vec[i] unchanged for non-FP8 types (apply the same pattern to all listed occurrences and the RopeQuantize<DType, IdType, DType> instantiation paths).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@include/flashinfer/pos_enc.cuh`:
- Around line 277-279: The FP8 clamp call uses fp8_clamp_max<QuantType>
unconditionally and will instantiate FP8-only traits for non-FP8 QuantType; fix
by gating the clamp with a compile-time FP8 trait check (e.g., use if constexpr
(is_fp8<QuantType>::value) or equivalent project trait) so only FP8
instantiations call fp8_clamp_max<QuantType>::value; update the sites that set
vec[i] (the fmaxf/fminf clamp) to perform the clamp only inside that if
constexpr and leave vec[i] unchanged for non-FP8 types (apply the same pattern
to all listed occurrences and the RopeQuantize<DType, IdType, DType>
instantiation paths).
Summary

- RMSNormQuantKernel and FusedAddRMSNormQuantKernel hardcoded fmaxf(-448.0f, fminf(x, 448.0f)), which is correct for e4m3fn (max=448) but clips ~99.2% of the representable range for e5m2 (max=57344)
- RopeQuantizeKernel and RopeQuantizeAppendPagedKVCacheKernel had no clamping at all before cast_store, risking NaN/Inf in the quantized output and paged KV cache corruption
- Adds an fp8_clamp_max<T> type trait in a shared lightweight header (fp8_types.cuh) to provide type-correct max values for both e4m3 and e5m2

Changes by file

- include/flashinfer/fp8_types.cuh: fp8_clamp_max<T> trait
- include/flashinfer/norm.cuh
- include/flashinfer/pos_enc.cuh

Test plan