feat: RMSNorm/Fused RMSNorm + FP8 Quantization kernels #2243

Merged
yzh119 merged 1 commit into flashinfer-ai:main from BLaZeKiLL:dev/dlal/norm-fp8-fusion on Dec 19, 2025

Conversation

@BLaZeKiLL BLaZeKiLL commented Dec 19, 2025

📌 Description

FP8 model inference requires multiple intermediate quantization kernels, which can be avoided by fusing the norm and quantization kernels. Consumers like sglang and vllm can lower to these norm + quant fusion kernels using custom torch.compile passes.
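The fused semantics — normalize, apply the weight, divide by the quantization scale, clamp to the FP8 range — can be sketched per-row in plain Python (a scalar reference only; the helper name and pure-Python style are illustrative, the real kernels operate on CUDA tensors):

```python
import math

def rmsnorm_quant_ref(x, w, scale, eps=1e-6, fp8_max=448.0):
    # RMSNorm over one row, then scale and clamp to the FP8 E4M3 range.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    inv_scale = 1.0 / scale
    return [max(-fp8_max, min(v / rms * wi * inv_scale, fp8_max))
            for v, wi in zip(x, w)]
```

A final cast to the FP8 dtype would follow the clamp; fusing these steps saves one full read/write of the activation tensor compared to running a norm kernel followed by a separate quantization kernel.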

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Reference

I have been working on adding custom fusion passes to sglang as part of the following RFC, and I would like to use flashinfer's norm kernels for the norm + quant fusions instead of migrating vllm kernels to sglang as part of the following MR.

Implementation

I realise that the existing kernels (at least for rmsnorm) could be modified to take the scale as an optional parameter, thereby avoiding most of the code duplication. However, as an initial implementation, I have opted for a separate implementation route; this can be refactored if required.

For fused_add_rmsnorm_quant, I don't think an in-place update is possible, since the input and output dtypes differ.

Currently, the FP8 E4M3 numeric limit (448) is hard-coded, as I am not aware of a way to get this value at compile time without including c10 headers from torch, and I am not sure that is acceptable post the tvm-ffi migration.

The following snippet is from vLLM; I have seen similar code elsewhere for getting the FP8 numeric limits:

#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};

The best option in my mind is to introduce include/flashinfer/fp8.h containing something similar to the above snippet, and to also support e5m2.
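As a sanity check on those limits, both maxima follow directly from the bit layouts: e4m3fn has no infinities and reserves only the all-ones exponent+mantissa pattern for NaN, while e5m2 reserves its entire top exponent for inf/NaN. A quick torch-free Python check (hypothetical helper, for illustration):

```python
def fp8_finite_max(exp_bits, man_bits, finite_only):
    # Largest finite value of a mini-float format.
    bias = 2 ** (exp_bits - 1) - 1
    if finite_only:
        # e4m3fn-style: top exponent usable, all-ones mantissa reserved for NaN.
        max_exp = (2 ** exp_bits - 1) - bias
        frac = 1 + (2 ** man_bits - 2) / 2 ** man_bits
    else:
        # IEEE-style (e5m2): top exponent reserved entirely for inf/NaN.
        max_exp = (2 ** exp_bits - 2) - bias
        frac = 1 + (2 ** man_bits - 1) / 2 ** man_bits
    return frac * 2 ** max_exp
```

`fp8_finite_max(4, 3, True)` gives 448.0 and `fp8_finite_max(5, 2, False)` gives 57344.0, matching what `std::numeric_limits` reports for the c10 FP8 types.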

Tests

atol and rtol for the FP8 assertions had to be high due to the low-precision nature of the data; with tolerances of 1e-2, only a few tests fail, each with a single-element mismatch.
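The loose tolerances are consistent with the e4m3 grid itself: with 3 mantissa bits, the gap between adjacent representable values near |x| in the normal range is 2^(floor(log2|x|) - 3). A small illustrative helper (assumption: normal range only, ignoring subnormals):

```python
import math

def e4m3_gap(x):
    # Gap between adjacent float8_e4m3fn values near |x| (normal range only).
    return 2.0 ** (math.floor(math.log2(abs(x))) - 3)
```

`e4m3_gap(1.0)` is 0.125 and `e4m3_gap(400.0)` is 32.0, so round-to-nearest alone can introduce up to roughly 6% relative error on a single element — which is why an rtol of 1e-2 still leaves occasional single-element mismatches.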

Summary by CodeRabbit

  • New Features

    • Added quantized RMSNorm and fused quantized RMSNorm (residual-add) with configurable scale, eps, and PDL toggle.
    • Supports FP16/FP8 paths and optional per-token or per-tensor scaling; outputs are clamped for quantized formats.
  • Tests

    • Added tests validating quantized normalization and fused-residual flows across dtypes, batch sizes, scaling modes, and PDL configurations.


coderabbitai bot commented Dec 19, 2025


Walkthrough

Adds two quantized RMSNorm APIs—rmsnorm_quant and fused_add_rmsnorm_quant—across CUDA kernels, C++ runtime/FFI bindings, Python wrappers (with fake/test stubs), and unit tests. New params: scale, eps, and enable_pdl; existing non-quantized paths remain unchanged.

Changes

Cohort / File(s) Change Summary
CUDA kernels & host wrappers
include/flashinfer/norm.cuh
Add RMSNormQuantKernel and FusedAddRMSNormQuantKernel templates and host wrappers RMSNormQuant / FusedAddRMSNormQuant to compute RMSNorm with quantized outputs (per-token/per-tensor scaling, clamping, shared-memory reductions, launch config, enable_pdl).
C++ runtime implementations
csrc/norm.cu
Add rmsnorm_quant(...) and fused_add_rmsnorm_quant(...) entry points that validate inputs, dispatch FP16/FP8 paths, invoke quantized kernels on a CUDA stream, and report CUDA errors.
FFI / public exports
csrc/flashinfer_norm_binding.cu
Declare and export rmsnorm_quant and fused_add_rmsnorm_quant (signatures include double scale, double eps, bool enable_pdl).
Python API & fakes
flashinfer/norm.py
Add rmsnorm_quant and fused_add_rmsnorm_quant wrappers delegating to native module with enable_pdl defaults; add _rmsnorm_quant_fake and _fused_add_rmsnorm_quant_fake no-op test stubs.
Tests / reference implementations
tests/utils/test_norm.py
Add reference functions llama_rms_norm_quant and fused_add_rms_norm_quant and parameterized tests test_norm_quant and test_fused_add_rmsnorm_quant (batch/hidden sizes, float16/bfloat16, quant_scale, PDL handling, clamping to float8_e4m3fn, residual checks).

Sequence Diagram(s)

mermaid
sequenceDiagram
    participant Python
    participant FFI as C++_FFI
    participant Runtime as C++_Runtime
    participant GPU as CUDA_Kernel
    Python->>FFI: call rmsnorm_quant(out,input,weight,scale,eps,enable_pdl)
    FFI->>Runtime: forward TensorView + params
    Runtime->>GPU: launch RMSNormQuantKernel (dtype dispatch, stream)
    GPU-->>Runtime: kernel completes (quantized output written)
    Runtime-->>FFI: return status
    FFI-->>Python: return / None

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • include/flashinfer/norm.cuh: vectorization, shared-memory layout, warp/cross-warp reductions, quantization/clamping math.
    • csrc/norm.cu: dtype dispatch (FP16/FP8), stream and error handling.
    • csrc/flashinfer_norm_binding.cu: ABI/FFI signatures and exports.
    • tests/utils/test_norm.py: coverage for edge quant_scale values, float8 clamping, and PDL-unavailable branches.

Suggested reviewers

  • yzh119

Poem

🐇 I hopped through kernels, light and keen,

scaled each token, clipped the scene,
weights in paw, residuals spun,
float8 dreams when work was done,
a rabbit's patch of code—fast, small, and clean.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 18.18% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely describes the main changes: adding RMSNorm/Fused RMSNorm kernels with FP8 quantization support.
Description check ✅ Passed The description provides a clear rationale (FP8 model inference optimization), implementation notes, known issues (hard-coded FP8 limits), and test coverage details. All major sections of the template are addressed with meaningful content.

@gemini-code-assist

Summary of Changes

Hello @BLaZeKiLL, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces specialized CUDA kernels for RMSNorm and Fused Add RMSNorm that incorporate FP8 quantization directly into the normalization process. The primary goal is to enhance the efficiency of FP8 model inference by eliminating the need for separate, intermediate quantization steps. This optimization is particularly beneficial for consumers like sglang and vllm, enabling them to leverage custom PyTorch compile passes for improved performance in large language models.

Highlights

  • FP8 Quantization for RMSNorm: Introduced a new rmsnorm_quant kernel that performs Root Mean Square Normalization and quantizes the output to FP8, designed to streamline FP8 model inference by fusing norm and quantization steps.
  • FP8 Quantization for Fused Add RMSNorm: Added a fused_add_rmsnorm_quant kernel which combines residual addition, RMS normalization, and FP8 quantization into a single fused operation, further optimizing performance for FP8 inference.
  • Python Bindings and Tests: Implemented Python bindings for the new quantized kernels and added comprehensive pytest tests to verify their correctness and performance across various batch sizes, hidden sizes, and data types, including support for programmatic dependent launch (PDL).



@gemini-code-assist bot left a comment

Code Review

This pull request introduces new CUDA kernels for RMSNorm and Fused RMSNorm with FP8 quantization. The changes are well-structured and follow the existing patterns in the codebase. My review focuses on improving the Python API correctness, making the CUDA kernels more generic, and increasing test coverage. I've identified a few issues in the Python bindings related to return types and documentation, suggested removing hardcoded values in the CUDA kernels, and recommended parameterizing the tests to cover all supported FP8 types.

for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] =
      float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));

medium

The clamping values -448.0f and 448.0f are hardcoded for FP8 E4M3. As you noted in the PR description, this prevents the kernel from working correctly with other FP8 types like E5M2. Please make this generic. A good approach would be to use if constexpr on the output type O to select the appropriate numeric limits, and define these limits in a central header to avoid magic numbers.

Example:

if constexpr (std::is_same_v<O, __nv_fp8_e4m3>) {
    // E4M3 limits
    output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
} else if constexpr (std::is_same_v<O, __nv_fp8_e5m2>) {
    // E5M2 limits
    output_vec[j] = fmaxf(-57344.0f, fminf(output_vec[j], 57344.0f));
}

#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));

medium

Similar to RMSNormQuantKernel, the clamping values here are hardcoded for FP8 E4M3. This should be generalized to support other FP8 formats. Please use a generic approach, for instance with if constexpr, to handle different FP8 types like E4M3 and E5M2.

Comment on lines +35 to +46
def llama_rms_norm_quant(x, w, scale, eps=1e-6):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x

medium

This reference implementation is hardcoded for torch.float8_e4m3fn. To enable testing with other FP8 types like e5m2, please parameterize this function to accept an fp8_dtype and use its finfo for clamping. This will also make the tests more robust.

Suggested change
def llama_rms_norm_quant(x, w, scale, eps=1e-6):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x
def llama_rms_norm_quant(x, w, scale, fp8_dtype, eps=1e-6):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max
    )
    x = x.to(fp8_dtype)
    return x

Comment on lines +83 to +97
def fused_add_rms_norm_quant(x, residual, weight, scale, eps):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * weight.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x, residual

medium

Similar to llama_rms_norm_quant, this reference implementation is hardcoded for torch.float8_e4m3fn. Please parameterize it to accept an fp8_dtype to allow testing against different FP8 formats.

Suggested change
def fused_add_rms_norm_quant(x, residual, weight, scale, eps):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * weight.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x, residual
def fused_add_rms_norm_quant(x, residual, weight, scale, fp8_dtype, eps):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * weight.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max
    )
    x = x.to(fp8_dtype)
    return x, residual

Comment on lines +128 to +152
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_norm_quant(
    batch_size, hidden_size, dtype, quant_scale, enable_pdl, contiguous
):
    if contiguous:
        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    else:
        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = x[:, :hidden_size]

    if enable_pdl and not device_support_pdl(x.device):
        pytest.skip("PDL is only available for Hopper and later GPUs")

    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm_quant(x, w, quant_scale)
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.rmsnorm_quant(y, x, w, quant_scale, enable_pdl=enable_pdl)

    torch.testing.assert_close(y_ref.float(), y.float(), rtol=1, atol=1)

medium

This test only covers torch.float8_e4m3fn. The underlying kernel supports other FP8 types like torch.float8_e5m2. After parameterizing the reference implementation llama_rms_norm_quant, please also parameterize this test to run against different FP8 dtypes (e.g., torch.float8_e4m3fn, torch.float8_e5m2) to ensure full coverage.

Comment on lines +220 to +255
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_fused_add_rmsnorm_quant(
    batch_size, hidden_size, dtype, quant_scale, enable_pdl, contiguous
):
    eps = 1e-6

    if contiguous:
        x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    else:
        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = x[:, :hidden_size]

    if enable_pdl and not device_support_pdl(x.device):
        pytest.skip("PDL is only available for Hopper and later GPUs")

    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = fused_add_rms_norm_quant(
        x.clone(), residual.clone(), weight, quant_scale, eps
    )

    x_fused = x.clone()
    residual_fused = residual.clone()
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.fused_add_rmsnorm_quant(
        y, x_fused, residual_fused, weight, quant_scale, eps, enable_pdl=enable_pdl
    )

    torch.testing.assert_close(y.float(), x_native.float(), rtol=1, atol=1)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)

medium

This test only covers torch.float8_e4m3fn. Please parameterize it to cover other FP8 types like torch.float8_e5m2 to ensure the fused kernel is fully tested. This would follow the parameterization of the fused_add_rms_norm_quant reference function.

@coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (2)
tests/utils/test_norm.py (1)

148-152: Consider whether tolerances are appropriate for FP8.

The tolerances rtol=1, atol=1 are very loose—they allow deviations up to 100% relative or 1.0 absolute. While FP8 has limited precision, this may mask subtle bugs. For float8_e4m3fn (max 448, min granularity varies by magnitude), tighter tolerances like atol=0.1 or atol=0.5 might still pass while catching more regressions.

flashinfer/norm.py (1)

207-228: Docstring is missing documentation for out and scale parameters.

The docstring doesn't document the out (output tensor) and scale (quantization scale factor) parameters, which are important for users to understand the API.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 454e7b2 and 0aaf5e8.

📒 Files selected for processing (5)
  • csrc/flashinfer_norm_binding.cu (2 hunks)
  • csrc/norm.cu (2 hunks)
  • flashinfer/norm.py (2 hunks)
  • include/flashinfer/norm.cuh (3 hunks)
  • tests/utils/test_norm.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/norm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (294-296)
csrc/flashinfer_norm_binding.cu (2)
csrc/norm.cu (6)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm (117-145)
  • fused_add_rmsnorm (117-118)
  • fused_add_rmsnorm_quant (147-178)
  • fused_add_rmsnorm_quant (147-148)
flashinfer/norm.py (3)
  • rmsnorm_quant (97-132)
  • fused_add_rmsnorm (149-180)
  • fused_add_rmsnorm_quant (198-234)
tests/utils/test_norm.py (4)
flashinfer/utils.py (1)
  • device_support_pdl (615-619)
csrc/flashinfer_norm_binding.cu (2)
  • rmsnorm_quant (20-21)
  • fused_add_rmsnorm_quant (26-27)
csrc/norm.cu (4)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm_quant (147-178)
  • fused_add_rmsnorm_quant (147-148)
flashinfer/norm.py (2)
  • rmsnorm_quant (97-132)
  • fused_add_rmsnorm_quant (198-234)
🔇 Additional comments (12)
csrc/flashinfer_norm_binding.cu (2)

20-27: LGTM!

The new function declarations for rmsnorm_quant and fused_add_rmsnorm_quant correctly mirror their implementations in csrc/norm.cu. Parameter types and order are consistent with the kernel wrappers.


38-40: LGTM!

The exports for the new quantized variants follow the established pattern.

include/flashinfer/norm.cuh (4)

148-161: LGTM on kernel structure.

The RMSNormQuantKernel follows the same pattern as the non-quantized RMSNormKernel, with appropriate additions for scale_inv computation and output type O for quantized output.


229-261: LGTM!

The RMSNormQuant host wrapper correctly follows the pattern established by RMSNorm, with appropriate handling of the additional scale parameter and output type O.


515-610: LGTM on kernel logic.

The FusedAddRMSNormQuantKernel correctly implements the fused operation: updating residual in-place and writing the normalized+quantized result to a separate output tensor. The shared memory usage for intermediate values (smem_x) matches the non-quantized variant.

Same note as before: line 600 uses hard-coded 448.0f for clamping.


612-647: LGTM!

The FusedAddRMSNormQuant host wrapper correctly mirrors FusedAddRMSNorm, with the additional output pointer and stride_output parameter for the quantized output tensor.

csrc/norm.cu (1)

88-115: LGTM!

The 2D-only restriction is reasonable for the quantized variant. The implementation correctly dispatches FP16 input to FP8 output through the nested type dispatch macros.

tests/utils/test_norm.py (3)

35-46: LGTM on the reference implementation.

The llama_rms_norm_quant reference correctly implements the quantization formula: normalize → scale by 1/scale → clamp to FP8 range → cast to float8_e4m3fn.


83-97: LGTM!

The fused_add_rms_norm_quant reference correctly mirrors the kernel behavior: fused add → update residual → normalize → quantize.


243-255: LGTM on test structure.

The test correctly validates both the quantized output (with FP8 tolerances) and the residual update (with standard tolerances). Good separation of concerns.

flashinfer/norm.py (2)

135-144: LGTM!

The fake op is appropriately a no-op since it mutates the pre-allocated out tensor.


237-247: LGTM!

The fake op correctly matches the signature of the real implementation.

Comment on lines +214 to +218
for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] =
      float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
}

⚠️ Potential issue | 🔴 Critical



Hard-coded FP8 clamping limit breaks e5m2 outputs.

The kernel templates use hard-coded 448.0f clamping at lines 217 and 600, which is correct for float8_e4m3fn but incorrect for float8_e5m2 (max ≈ 57344). Since DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 in csrc/tvm_ffi_utils.h dispatches to both _DISPATCH_CASE_FP8_E4M3 and _DISPATCH_CASE_FP8_E5M2, values up to 57344 will be incorrectly clipped to 448 when using e5m2. Use type-dependent limits (following the pattern in csrc/fmha_v2/fmha/numeric_types.h which defines MAX_E4M3 and MAX_E5M2) or a named constant.

@BLaZeKiLL force-pushed the dev/dlal/norm-fp8-fusion branch 2 times, most recently from 594845f to 6e26dbb (Dec 19, 2025 00:55)
@coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
flashinfer/norm.py (2)

95-133: Missing return statement.

The function declares -> torch.Tensor and documents a return value, but line 132 doesn't return the out tensor after the kernel call.

🔎 Proposed fix
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
     get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
+    return out

194-206: Return type should be torch.Tensor and docstring needs correction.

The function signature indicates -> None but it should return the out tensor for consistency with rmsnorm_quant and because out is a separate output buffer (not in-place to input). Additionally, the docstring incorrectly describes the operation as writing to input[i] when it actually writes to out.

🔎 Proposed fixes
 @flashinfer_api
 @register_custom_op(
     "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
 )
 def fused_add_rmsnorm_quant(
     out: torch.Tensor,
     input: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     scale: float,
     eps: float = 1e-6,
     enable_pdl: Optional[bool] = None,
-) -> None:
+) -> torch.Tensor:
     r"""Fused add root mean square normalization.

     Step 1:
     ``residual[i] += input[i]``

     Step 2:
-    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+    ``out[i] = (residual[i] / RMS(residual)) * weight[i] * (1/scale)`` (quantized to out's dtype)

     Parameters
     ----------
+    out: torch.Tensor
+        The output tensor, will quantize the output to the dtype of this tensor.
     input: torch.Tensor
         Input tensor, shape (batch_size, hidden_size).
     residual: torch.Tensor
         Residual tensor, shape (batch_size, hidden_size).
     weight: torch.Tensor
         Weight tensor, shape (hidden_size,).
+    scale: float
+        Scale factor for quantization.
     eps: float
         Epsilon for numerical stability.
     enable_pdl: bool
         Whether to enable `programmatic dependent launch
         <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Quantized normalized tensor, shape (batch_size, hidden_size).
     """
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
     get_norm_module().fused_add_rmsnorm_quant(
         out, input, residual, weight, scale, eps, enable_pdl
     )
+    return out
include/flashinfer/norm.cuh (2)

214-218: Hard-coded FP8 E4M3 clamping breaks E5M2 support.

The clamping values -448.0f and 448.0f at line 217 are hard-coded for __nv_fp8_e4m3 but will incorrectly clip __nv_fp8_e5m2 values (max ≈ 57344). Since the kernel is templated on output type O and dispatched via DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 (which includes both E4M3 and E5M2), the clamping should be type-dependent.

Consider using if constexpr on type O or defining numeric limits in a central header (e.g., include/flashinfer/fp8.h as mentioned in the PR description).


598-601: Hard-coded FP8 E4M3 clamping breaks E5M2 support.

Same issue as in RMSNormQuantKernel: line 600 hard-codes E4M3 limits (±448.0f), which will incorrectly clip E5M2 values. The clamping should be type-dependent based on output type O.

🧹 Nitpick comments (1)
csrc/flashinfer_norm_binding.cu (1)

20-21: Parameter name inconsistency with implementation.

The declaration uses TensorView out while the implementation in csrc/norm.cu (lines 79-114) uses TensorView output. For consistency, the declaration should match the implementation parameter name.

🔎 Proposed fix
-void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
+void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
                    bool enable_pdl);
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0aaf5e8 and 594845f.

📒 Files selected for processing (5)
  • csrc/flashinfer_norm_binding.cu (2 hunks)
  • csrc/norm.cu (2 hunks)
  • flashinfer/norm.py (2 hunks)
  • include/flashinfer/norm.cuh (3 hunks)
  • tests/utils/test_norm.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/norm.cu
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/norm.cuh
🧬 Code graph analysis (2)
csrc/flashinfer_norm_binding.cu (1)
csrc/norm.cu (6)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm (117-145)
  • fused_add_rmsnorm (117-118)
  • fused_add_rmsnorm_quant (147-178)
  • fused_add_rmsnorm_quant (147-148)
flashinfer/norm.py (5)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/utils.py (2)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
csrc/flashinfer_norm_binding.cu (2)
  • rmsnorm_quant (20-21)
  • fused_add_rmsnorm_quant (26-27)
csrc/norm.cu (4)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm_quant (147-178)
  • fused_add_rmsnorm_quant (147-148)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
🔇 Additional comments (5)
csrc/flashinfer_norm_binding.cu (1)

38-40: LGTM!

The FFI exports for the new quantized norm functions follow the established pattern and are correctly defined.

tests/utils/test_norm.py (2)

152-152: Verify the very loose test tolerances.

The test uses rtol=1, atol=1, meaning 100% relative tolerance and 1.0 absolute tolerance. This is extremely permissive and may hide accuracy issues. For reference, the non-quantized tests use rtol=1e-3, atol=1e-3 (line 125).

While FP8's limited precision may require looser tolerances than FP16, the current values seem excessive. Please verify this is intentional and necessary for FP8 e4m3fn quantization.


254-254: Verify the very loose test tolerances for quantized output.

Similar to test_norm_quant, the quantized output comparison uses rtol=1, atol=1 (100% relative/1.0 absolute tolerance), which is extremely permissive. The residual comparison (line 255) correctly uses tighter tolerances since it's not quantized.

Please verify this tolerance is necessary for FP8 e4m3fn quantization and not masking potential accuracy issues.
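For context on why FP8 tolerances end up loose: near the e4m3 maximum of 448, adjacent representable values are 32 apart, so even a correctly rounded result can differ from the float32 reference by up to 16. A rough pure-Python model of e4m3 rounding illustrates the spacing (a sketch for normals only, saturating at ±448; subnormals and NaN encoding are ignored):

```python
import math

E4M3_MAX = 448.0    # max finite value of float8_e4m3fn
E4M3_MANT_BITS = 3  # e4m3 has 3 explicit mantissa bits

def e4m3_ulp(x):
    """Spacing between adjacent e4m3 values around |x| (normal range only)."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - E4M3_MANT_BITS)

def round_to_e4m3(x):
    """Round x to the nearest representable e4m3 value, saturating at +-448
    (normals only; subnormals and NaN encoding are not modeled)."""
    if x == 0.0:
        return 0.0
    sign = math.copysign(1.0, x)
    mag = min(abs(x), E4M3_MAX)
    ulp = e4m3_ulp(mag)
    return sign * min(round(mag / ulp) * ulp, E4M3_MAX)
```

For example, 447.0 rounds up to 448.0 while 430.0 rounds down to 416.0, a jump of 32 between neighbors, which is why an absolute tolerance of 1.0 is not as permissive for large magnitudes as it first appears (though it remains very loose for values near zero).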

include/flashinfer/norm.cuh (2)

229-261: LGTM!

The RMSNormQuant host wrapper follows established patterns and correctly sets up the kernel launch configuration, including PDL support and dynamic shared memory.


612-647: LGTM!

The FusedAddRMSNormQuant host wrapper correctly extends the fused add RMSNorm pattern with quantization support, properly handling the separate output buffer and its stride.
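As a reading aid for the fused wrapper's contract, the semantics can be sketched in pure Python. This is a hedged reference only: the residual buffer is updated in place with `input + residual`, while the quantized result goes to a separate FP8 output buffer because the dtypes differ; FP8 rounding is omitted and the kernel's `weight_bias` term is assumed zero.

```python
import math

E4M3_MAX = 448.0  # max finite value of float8_e4m3fn

def fused_add_rmsnorm_quant_ref(x, residual, weight, scale, eps=1e-6):
    """Sketch of the fused semantics: residual += input, then the updated
    residual is RMS-normalized, weighted, scaled, and saturated into a
    separate FP8-range output (FP8 rounding itself is omitted)."""
    added = [a + b for a, b in zip(x, residual)]  # in-place residual update
    rms = math.sqrt(sum(v * v for v in added) / len(added) + eps)
    scale_inv = 1.0 / scale  # assumption: scale enters as a reciprocal
    out = [max(-E4M3_MAX, min(v / rms * w * scale_inv, E4M3_MAX))
           for v, w in zip(added, weight)]
    return out, added  # quantized output and updated residual
```

This also makes concrete the point from the PR description that an in-place update of the quantized output is impossible: `added` keeps the input dtype while `out` is destined for FP8.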

@BLaZeKiLL BLaZeKiLL force-pushed the dev/dlal/norm-fp8-fusion branch from 6e26dbb to fe46655 on December 19, 2025 at 00:59
Signed-off-by: Devashish Lal <laldevashish@gmail.com>
@BLaZeKiLL BLaZeKiLL force-pushed the dev/dlal/norm-fp8-fusion branch from fe46655 to 1f39bcc on December 19, 2025 at 02:02
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (9)
flashinfer/norm.py (2)

194-237: Verify decorator parameter name and return type annotation.

Based on past review comments, there were two issues identified:

  1. The mutates_args decorator references parameter name "output" but the actual parameter is named "out" (appears fixed in current code showing "out" at line 196)
  2. The return type annotation is -> None but for consistency with rmsnorm_quant (which returns the output tensor), this function should also return out

Please verify whether the function should return the out tensor for consistency with the non-fused quantized variant.


95-133: Missing return statement.

The function signature declares -> torch.Tensor and the docstring documents a return value, but line 132 doesn't return the out tensor after calling the kernel.

🔎 Proposed fix
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
     get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
+    return out
tests/utils/test_norm.py (4)

35-46: Consider parameterizing FP8 dtype for broader test coverage.

The reference implementation hardcodes torch.float8_e4m3fn for clamping and casting (lines 43, 45). To enable testing with other FP8 formats like torch.float8_e5m2, consider adding an fp8_dtype parameter and using torch.finfo(fp8_dtype).min/max for clamping.

This would allow the test suite to verify correctness across different FP8 formats as the kernels add e5m2 support.


83-97: Consider parameterizing FP8 dtype.

Similar to llama_rms_norm_quant, this reference implementation hardcodes torch.float8_e4m3fn. Parameterizing the FP8 dtype would enable testing with torch.float8_e5m2 and other formats.


128-152: Consider expanding test coverage to other FP8 dtypes.

The test currently only validates torch.float8_e4m3fn output (line 149). After parameterizing the reference implementation, consider adding test cases for torch.float8_e5m2 to ensure full kernel coverage.

Note: The relaxed tolerances (rtol=1, atol=1 at line 152) are expected for FP8 quantization due to limited precision.


220-255: Consider expanding test coverage to other FP8 dtypes.

Similar to test_norm_quant, this test only validates torch.float8_e4m3fn (line 249). Adding test coverage for torch.float8_e5m2 would ensure the fused kernel works correctly with different FP8 formats.

csrc/norm.cu (1)

80-115: Add device check for output tensor.

The function validates that input and weight are on the same device (line 85), but doesn't verify that output is also on the same device. Since output has a different dtype (FP8), there's a risk it could be on a different device.

🔎 Suggested fix
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
   CHECK_DEVICE(input, weight);
+  CHECK_DEVICE(input, output);
   CHECK_DIM(1, weight);  // weight: (hidden_size)
include/flashinfer/norm.cuh (2)

214-218: Hardcoded FP8 clamping breaks e5m2 support.

Line 217 hardcodes clamping to ±448.0f (FP8 e4m3 range). Since DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 dispatches to both e4m3 and e5m2 types (csrc/tvm_ffi_utils.h), e5m2 values (max ≈ 57344) will be incorrectly clipped, causing silent data corruption.

The PR description acknowledges this limitation and proposes adding include/flashinfer/fp8.h to centralize FP8 limits. Consider using if constexpr with type-dependent constants to fix this before merge.

🔎 Example fix using type-dependent limits
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] =
           float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+      // Clamp based on output type O
+      constexpr float fp8_max = std::is_same_v<O, __nv_fp8_e4m3> ? 448.0f : 57344.0f;
+      output_vec[j] = fmaxf(-fp8_max, fminf(output_vec[j], fp8_max));
     }

598-601: Hardcoded FP8 clamping breaks e5m2 support.

Line 600 has the same hardcoded ±448.0f clamping issue as RMSNormQuantKernel. This will incorrectly clip e5m2 values when the output type is torch.float8_e5m2.

Apply the same type-dependent clamping fix as suggested for RMSNormQuantKernel at line 217.

🧹 Nitpick comments (1)
csrc/flashinfer_norm_binding.cu (1)

20-27: Inconsistent parameter naming between declarations.

The first parameter is named out in rmsnorm_quant (line 20) but output in fused_add_rmsnorm_quant (line 26). The implementations in csrc/norm.cu use output for both. While this doesn't affect functionality, consistent naming improves code maintainability.

🔎 Suggested fix for consistency
-void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
+void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
                    bool enable_pdl);
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 594845f and 1f39bcc.

📒 Files selected for processing (5)
  • csrc/flashinfer_norm_binding.cu (2 hunks)
  • csrc/norm.cu (2 hunks)
  • flashinfer/norm.py (2 hunks)
  • include/flashinfer/norm.cuh (3 hunks)
  • tests/utils/test_norm.py (4 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/norm.cuh
🧬 Code graph analysis (4)
flashinfer/norm.py (4)
flashinfer/utils.py (5)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • device_support_pdl (615-619)
  • register_fake_op (330-334)
  • register_fake_op (359-364)
csrc/flashinfer_norm_binding.cu (2)
  • rmsnorm_quant (20-21)
  • fused_add_rmsnorm_quant (26-27)
csrc/norm.cu (4)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm_quant (147-181)
  • fused_add_rmsnorm_quant (147-148)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
csrc/flashinfer_norm_binding.cu (2)
csrc/norm.cu (6)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm (117-145)
  • fused_add_rmsnorm (117-118)
  • fused_add_rmsnorm_quant (147-181)
  • fused_add_rmsnorm_quant (147-148)
flashinfer/norm.py (3)
  • rmsnorm_quant (97-132)
  • fused_add_rmsnorm (149-180)
  • fused_add_rmsnorm_quant (198-237)
tests/utils/test_norm.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
flashinfer/utils.py (1)
  • device_support_pdl (615-619)
csrc/flashinfer_norm_binding.cu (2)
  • rmsnorm_quant (20-21)
  • fused_add_rmsnorm_quant (26-27)
csrc/norm.cu (4)
  • rmsnorm_quant (80-115)
  • rmsnorm_quant (80-81)
  • fused_add_rmsnorm_quant (147-181)
  • fused_add_rmsnorm_quant (147-148)
flashinfer/norm.py (2)
  • rmsnorm_quant (97-132)
  • fused_add_rmsnorm_quant (198-237)
csrc/norm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (294-296)
🔇 Additional comments (6)
flashinfer/norm.py (2)

135-144: LGTM!

The fake op stub is correctly implemented as a no-op for testing/compilation paths.


240-250: LGTM!

The fake op stub is correctly implemented for testing/compilation.

csrc/flashinfer_norm_binding.cu (1)

38-40: LGTM!

The FFI exports correctly expose the quantized norm functions.

include/flashinfer/norm.cuh (3)

229-261: LGTM!

The host wrapper correctly sets up kernel launch configuration with proper PDL support, shared memory allocation, and error checking.


612-647: LGTM!

The host wrapper correctly handles the fused kernel's larger shared memory requirements (for intermediate storage) and properly configures PDL support.


19-19: Good practice: explicit include for used types.

Adding #include <cstdint> explicitly documents the dependency on uint32_t and related types used throughout the kernels.

Comment on lines +147 to +181
void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
                             TensorView weight, double scale, double eps, bool enable_pdl) {
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
  CHECK_DEVICE(input, residual);
  CHECK_DEVICE(input, weight);
  CHECK_DEVICE(input, output);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_DIM(2, output);
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);
  TVM_FFI_ICHECK_EQ(residual.size(0), batch_size);
  TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size);
  TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size);
  ffi::CUDADeviceGuard device_guard(input.device().device_id);
  const cudaStream_t stream = get_stream(input.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] {
    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(output.dtype(), o_type, [&] {
      cudaError_t status = norm::FusedAddRMSNormQuant(
          static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
          static_cast<c_type*>(weight.data_ptr()), static_cast<o_type*>(output.data_ptr()),
          batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale,
          eps, enable_pdl, stream);

      TVM_FFI_ICHECK(status == cudaSuccess)
          << "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status);
      return true;
    });
  });
}
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for output size checks in csrc/norm.cu around the fused_add_rmsnorm_quant function
rg -n "TVM_FFI_ICHECK_EQ.*output\.size" csrc/norm.cu -A 2 -B 2 || echo "No matches found"

Repository: flashinfer-ai/flashinfer

Length of output: 1745


🏁 Script executed:

#!/bin/bash
# Check lines around 160-180 in csrc/norm.cu to see all validation checks
sed -n '146,180p' csrc/norm.cu | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 2013


Add output tensor size validation to fused_add_rmsnorm_quant.

The function validates tensor dimensions and sizes for the input, residual, and weight tensors, but lacks validation of output.size(0) and output.size(1) against batch_size and hidden_size. This is inconsistent with similar norm functions in the codebase and leaves the function vulnerable to size mismatches. Add:

TVM_FFI_ICHECK_EQ(output.size(0), batch_size);
TVM_FFI_ICHECK_EQ(output.size(1), hidden_size);

immediately after the existing TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size) check.

🤖 Prompt for AI Agents
In csrc/norm.cu around lines 147-181, the function validates input, residual and
weight sizes but misses checks that output.size(0) == batch_size and
output.size(1) == hidden_size; add two TVM_FFI_ICHECK_EQ assertions immediately
after the existing TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size) line to
compare output.size(0) with batch_size and output.size(1) with hidden_size so
the output tensor size is validated before proceeding.

@yzh119
Collaborator

yzh119 commented Dec 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !203 has been created, and the CI pipeline #40475223 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment

I'm good with the PR and implementation in general.

Please also note that @bkryu is working on another PR #2233 using cute-dsl and we might shift to cute-dsl for development of these kernels in the future.

@yzh119 yzh119 enabled auto-merge (squash) December 19, 2025 04:12
@yzh119 yzh119 merged commit 49d66ee into flashinfer-ai:main Dec 19, 2025
4 checks passed
@bkryu
Collaborator

bkryu commented Dec 19, 2025

> I'm good with the PR and implementation in general.
>
> Please also note that @bkryu is working on another PR #2233 using cute-dsl and we might shift to cute-dsl for development of these kernels in the future.

At the least, I can benchmark the code and see whether there is a need for (and whether I can write) a cute-dsl based implementation

BLaZeKiLL added a commit to BLaZeKiLL/flashinfer that referenced this pull request Feb 1, 2026
follow up on flashinfer-ai#2243

quant_scale being a float causes cuda graph capture to fail
even with workaround, by making it a tensor it fixes cuda
graph capture for fusion passes in sglang.

also added docs for the fused kernels.

Signed-off-by: Devashish Lal <laldevashish@gmail.com>
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…#2243)


## 📌 Description

FP8 model inference requires multiple intermediate quantization kernels,
which can be avoided by fusing norm and quantization kernels. Consumers
like sglang and vllm can lower to these norm + quant fusion kernels
using custom torch compile passes

## 🔍 Related Issues


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

### Reference
I have been working on adding custom fusion passes to sglang as part of
the following [RFC](sgl-project/sglang#10118)
and would like to use flashinfer's norm kernels for the norm quant
fusions instead of migrating vllm kernels to sglang as part of the
following [MR](sgl-project/sglang#10549)

### Implementation
I realise that existing kernels (at least for rmsnorm) can be modified
to add the scale parameter as an optional parameter, thereby avoiding
most code duplication. However, as an initial implementation, I have
opted for a separate implementation route. This can be refactored if
required.

For fused_add_rmsnorm_quant, I don't think an in-place update would be
possible since dtypes for input and output differ

Currently, the FP8_E4M3 numeric limit (448) has been hard-coded, as I am
not aware of a way to get this value at compile time without including
c10 headers from torch, and I am not sure whether that is acceptable
post tvm-ffi migration

Following is a snippet from vLLM; I have seen similar code elsewhere for
getting the FP8 numeric limits
```cpp
#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};
```

The best option in my mind is to introduce `include/flashinfer/fp8.h`
containing something similar to the above snippet, and also support e5m2

### Tests
atol and rtol for the fp8 assertions had to be loose due to the
low-precision nature of the data; with tolerances of 1e-2, only a few
tests fail, each with a single-element mismatch

## Summary by CodeRabbit

* **New Features**
* Added quantized RMSNorm and fused quantized RMSNorm (residual-add)
with configurable scale, eps, and PDL toggle.
* Supports FP16/FP8 paths and optional per-token or per-tensor scaling;
outputs are clamped for quantized formats.

* **Tests**
* Added tests validating quantized normalization and fused-residual
flows across dtypes, batch sizes, scaling modes, and PDL configurations.


Signed-off-by: Devashish Lal <laldevashish@gmail.com>