
fix: use type-specific FP8 max value for clamping in RMSNorm and RoPE quantization kernels #2639

Open

Bias92 wants to merge 2 commits into flashinfer-ai:main from Bias92:fix/fp8-e5m2-clamp-range

Conversation

Contributor

@Bias92 Bias92 commented Feb 26, 2026

Summary

  • norm.cuh: RMSNormQuantKernel and FusedAddRMSNormQuantKernel hardcoded
    fmaxf(-448.0f, fminf(x, 448.0f)) — correct for e4m3fn (max=448) but clips
    ~99.2% of representable range for e5m2 (max=57344)
  • pos_enc.cuh: RopeQuantizeKernel and RopeQuantizeAppendPagedKVCacheKernel
    had no clamping at all before cast_store, risking NaN/Inf in quantized output
    and paged KV cache corruption
  • Introduces fp8_clamp_max<T> type trait in a shared lightweight header
    (fp8_types.cuh) to provide type-correct max values for both e4m3 and e5m2

Changes by file

File Change
include/flashinfer/fp8_types.cuh New — shared fp8_clamp_max<T> trait
include/flashinfer/norm.cuh Replace hardcoded 448.0f (2 sites)
include/flashinfer/pos_enc.cuh Add FP8 clamping in all 8 quantize paths

Test plan

  • Existing tests pass (norm, rope, attention with FP8 KV cache)
  • e5m2: values in (448, 57344] no longer clipped
  • e4m3: behavior unchanged (448.0f clamp preserved)
  • fmaxf/fminf compile to single PTX instructions (no perf regression)

Summary by CodeRabbit

  • Bug Fixes
    • Improved numerical stability of FP8 quantization and rotary/normalization kernels by enforcing proper clamping to each FP8 format's representable range before conversion, preventing potential NaN/Inf results.
  • Chores
    • Centralized clamp limits for FP8 formats so quantization kernels use the correct bounds consistently.

… quantization kernels

The hardcoded 448.0f clamp range is only correct for FP8 e4m3fn (max=448).
When the output type is e5m2 (max=57344), this silently clips ~99.2% of the
representable range, destroying information.

This commit:
- Introduces fp8_clamp_max<T> trait in a shared header (fp8_types.cuh)
- Fixes norm.cuh: RMSNormQuantKernel and FusedAddRMSNormQuantKernel
- Fixes pos_enc.cuh: RopeQuantizeKernel (q_rope, k_rope, k_nope paths)
  and RopeQuantizeAppendPagedKVCacheKernel (all 5 paths including V cache)
- Adds clamping before cast_store in all FP8 quantization paths that
  previously had no overflow protection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist

Summary of Changes

Hello @Bias92, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical issues in FP8 quantization by ensuring that values are correctly clamped according to their specific FP8 type. It introduces a mechanism to retrieve type-aware maximum values for clamping, thereby preventing data loss from incorrect clipping in RMSNorm kernels and mitigating the risk of NaN/Inf propagation and KV cache corruption in RoPE quantization kernels. The changes enhance the robustness and accuracy of FP8 operations across the codebase.

Highlights

  • Type-Specific FP8 Clamping: Introduced a new fp8_clamp_max<T> type trait to dynamically provide the correct maximum representable value for different FP8 types (e.g., e4m3 and e5m2). This ensures that quantization clamping is accurate for the specific FP8 format being used.
  • RMSNorm Quantization Fix: Corrected the RMSNormQuantKernel and FusedAddRMSNormQuantKernel by replacing a hardcoded 448.0f clamping value with the new type-specific fp8_clamp_max<O>::value. This resolves an issue where e5m2 values were incorrectly clipped, losing approximately 99.2% of their representable range.
  • RoPE Quantization Safety: Added essential FP8 clamping to RopeQuantizeKernel and RopeQuantizeAppendPagedKVCacheKernel functions. Previously, these kernels lacked clamping before cast_store, which could lead to NaN/Inf values in quantized output and potential corruption of the paged KV cache.


Changelog
  • include/flashinfer/fp8_types.cuh
    • Introduced fp8_clamp_max<T> type trait to provide type-correct maximum values for FP8 types (__nv_fp8_e4m3 and __nv_fp8_e5m2).
  • include/flashinfer/norm.cuh
    • Included fp8_types.cuh header.
    • Replaced hardcoded 448.0f clamping with fp8_clamp_max<O>::value in RMSNormQuantKernel.
    • Replaced hardcoded 448.0f clamping with fp8_clamp_max<O>::value in FusedAddRMSNormQuantKernel.
  • include/flashinfer/pos_enc.cuh
    • Included fp8_types.cuh header.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value in scale_store_partial_chunk.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for q_rope_vec in RopeQuantizeKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for k_rope_vec in RopeQuantizeKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for q_rope_vec in RopeQuantizeAppendPagedKVCacheKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for k_rope_vec in RopeQuantizeAppendPagedKVCacheKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for k_nope_vec in RopeQuantizeAppendPagedKVCacheKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for v_vec in RopeQuantizeAppendPagedKVCacheKernel.
    • Added FP8 clamping using fp8_clamp_max<QuantType>::value for q_nope_vec in RopeQuantizeAppendPagedKVCacheKernel.


coderabbitai bot commented Feb 26, 2026

📝 Walkthrough

Walkthrough

Adds a new fp8_clamp_max type trait with specializations for __nv_fp8_e4m3 and __nv_fp8_e5m2, and replaces hard-coded FP8 clamping with dynamic clamps using this trait across RMSNorm and RoPE quantization kernels to avoid NaN/Inf before FP8 casting.

Changes

Cohort / File(s) Summary
FP8 Type Trait
include/flashinfer/fp8_types.cuh
New CUDA header defining template <typename T> struct fp8_clamp_max and specializations: fp8_clamp_max<__nv_fp8_e4m3>::value = 448.0f and fp8_clamp_max<__nv_fp8_e5m2>::value = 57344.0f.
Quantization Kernels
include/flashinfer/norm.cuh, include/flashinfer/pos_enc.cuh
Replaced hard-coded clamp literals with fp8_clamp_max<...>::value and added runtime clamping (symmetric min/max) after scaling in multiple RMSNorm and RoPE quantization code paths; added include for the new header.


Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • bkryu
  • yzh119
  • wenscarl
  • aleozlx
  • djmmoss
  • nvmbreughe


🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: introducing type-specific FP8 max values for clamping across RMSNorm and RoPE quantization kernels.
Description check ✅ Passed The description includes a comprehensive summary, changes by file, and test plan, though the PR description template sections are not explicitly formatted per the template structure.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses issues with FP8 value clamping in RMSNorm and RoPE quantization kernels. It introduces a new fp8_types.cuh header for type-specific maximum values, which is a good approach. My review focuses on improving code readability and maintainability by reducing code duplication. I've suggested creating a helper function for the clamping logic, which is repeated in many places.

template <>
struct fp8_clamp_max<__nv_fp8_e5m2> {
  static constexpr float value = 57344.0f;
};
medium

To improve code readability and reduce repetition across norm.cuh and pos_enc.cuh, consider adding a helper function for clamping values before FP8 conversion. This pattern is used in multiple places.

You could add the following function to this file, which would make the call sites much cleaner (e.g., output_vec[j] = clamp<O>(output_vec[j]);).

};

template <typename T>
__device__ __forceinline__ float clamp(float val) {
  constexpr float max_val = fp8_clamp_max<T>::value;
  return fmaxf(-max_val, fminf(val, max_val));
}

  output_vec[j] =
      float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
- output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+ output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
medium

With the proposed clamp helper function in fp8_types.cuh, this line can be simplified for better readability.

      output_vec[j] = clamp<O>(output_vec[j]);

  for (uint32_t j = 0; j < VEC_SIZE; j++) {
    output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-   output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+   output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
medium

This can also be simplified using the proposed clamp helper function for better readability.

      output_vec[j] = clamp<O>(output_vec[j]);

#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
vec[i] = vec[i] * scale;
vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
medium

With the proposed clamp helper function in fp8_types.cuh, this line can be simplified for better readability.

    vec[i] = clamp<QuantType>(vec[i]);

Comment on lines +506 to +507
q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

        q_rope_vec[i] = clamp<QuantType>(q_rope_vec[i]);

Comment on lines +911 to +912
q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

        q_rope_vec[i] = clamp<QuantType>(q_rope_vec[i]);

Comment on lines +942 to +943
k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

        k_rope_vec[i] = clamp<QuantType>(k_rope_vec[i]);

Comment on lines +974 to +975
k_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(k_nope_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

        k_nope_vec[i] = clamp<QuantType>(k_nope_vec[i]);

Comment on lines +1006 to +1007
v_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(v_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

              v_vec[i] = clamp<QuantType>(v_vec[i]);

Comment on lines +1037 to +1038
q_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
fminf(q_nope_vec[i], fp8_clamp_max<QuantType>::value));
medium

This clamping logic can be simplified using the proposed clamp helper function for better readability.

        q_nope_vec[i] = clamp<QuantType>(q_nope_vec[i]);

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/norm.cuh`:
- Around lines 218-219: Run clang-format on the header so the edited regions (both output_vec assignments that use fp8_clamp_max<O>::value) pass the repository's pre-commit clang-format check, then commit the normalized file.

In `@include/flashinfer/pos_enc.cuh`:
- Line 25: The file fails clang-format checks; run the repository's clang-format (or pre-commit) on pos_enc.cuh, including the new "fp8_types.cuh" include line and the surrounding header contents.
- Line 277: Guard every fp8_clamp_max<QuantType> clamp site so it only instantiates for FP8 quant types; use an existing is_fp8 trait (or add a constexpr helper such as is_fp8_v<QuantType>) and wrap the clamp in if constexpr, so the base template assertion is never instantiated for non-FP8 types.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f852eb6 and 178350c.

📒 Files selected for processing (3)
  • include/flashinfer/fp8_types.cuh
  • include/flashinfer/norm.cuh
  • include/flashinfer/pos_enc.cuh

#include <string>
#include <type_traits>

#include "fp8_types.cuh"

⚠️ Potential issue | 🟡 Minor

clang-format cleanup is still required for this file.

Pre-commit reports formatting diffs; please format this header before merging.


#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
vec[i] = vec[i] * scale;
vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
⚠️ Potential issue | 🔴 Critical


Guard FP8 clamping against non-FP8 QuantType instantiations.

The generic template fp8_clamp_max<QuantType> defined in fp8_types.cuh only specializes __nv_fp8_e4m3 and __nv_fp8_e5m2; the base template asserts on any other type. At line 1391, BatchQKApplyRotaryPosIdsCosSinCache<DType, IdType> calls RopeQuantize<DType, IdType, DType> where DType is dispatched via DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 (only __half and __nv_bfloat16). This causes RopeQuantizeKernel to instantiate fp8_clamp_max<__half> or fp8_clamp_max<__nv_bfloat16> when applying clamping at lines 277, 506–507, 533–534, 911–912, 942–943, 974–975, 1006–1007, and 1037–1038, triggering a compile-time assertion failure.

Proposed fix pattern
+template <typename QuantType>
+__device__ __forceinline__ float clamp_if_fp8(float x) {
+  if constexpr (std::is_same<QuantType, __nv_fp8_e4m3>::value ||
+                std::is_same<QuantType, __nv_fp8_e5m2>::value) {
+    return fmaxf(-fp8_clamp_max<QuantType>::value, fminf(x, fp8_clamp_max<QuantType>::value));
+  } else {
+    return x;
+  }
+}
...
-        vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
+        vec[i] = clamp_if_fp8<QuantType>(vec[i]);

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
include/flashinfer/pos_enc.cuh (1)

277-279: ⚠️ Potential issue | 🔴 Critical

Guard FP8 clamping so non-FP8 QuantType paths do not instantiate FP8-only traits.

At Line 1392 this file instantiates RopeQuantize<DType, IdType, DType>, while the new clamp sites unconditionally use fp8_clamp_max<QuantType>. This can break non-FP8 builds (or introduce unintended clipping) unless clamping is explicitly FP8-gated.

Proposed fix pattern
+template <typename QuantType>
+__device__ __forceinline__ float clamp_if_fp8(float x) {
+  if constexpr (std::is_same_v<QuantType, __nv_fp8_e4m3> ||
+                std::is_same_v<QuantType, __nv_fp8_e5m2>) {
+    return fmaxf(-fp8_clamp_max<QuantType>::value, fminf(x, fp8_clamp_max<QuantType>::value));
+  } else {
+    return x;
+  }
+}
...
-    vec[i] =
-        fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
+    vec[i] = clamp_if_fp8<QuantType>(vec[i]);

Use this read-only check to confirm whether non-FP8 instantiations exist and whether fp8_clamp_max is FP8-only:

#!/bin/bash
set -euo pipefail

echo "=== fp8_clamp_max definition / specializations ==="
fd -t f "fp8_types.cuh" include --exec sed -n '1,220p' {}

echo
echo "=== RopeQuantize instantiations in pos_enc.cuh ==="
rg -n "RopeQuantize<" include/flashinfer/pos_enc.cuh -C2
sed -n '1388,1398p' include/flashinfer/pos_enc.cuh

echo
echo "=== Unconditional fp8_clamp_max uses in pos_enc.cuh ==="
rg -n "fp8_clamp_max<QuantType>" include/flashinfer/pos_enc.cuh -C1

Also applies to: 507-508, 534-535, 912-913, 943-944, 975-976, 1007-1008, 1038-1039, 1392-1398


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 178350c and f119a27.

📒 Files selected for processing (2)
  • include/flashinfer/norm.cuh
  • include/flashinfer/pos_enc.cuh
