feat: RMSNorm/Fused RMSNorm + FP8 Quantization kernels #2243
yzh119 merged 1 commit into flashinfer-ai:main from
Conversation
Note: CodeRabbit detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in its review comments; this may lead to a less comprehensive review.

Walkthrough: Adds two quantized RMSNorm APIs, `rmsnorm_quant` and `fused_add_rmsnorm_quant`, spanning the CUDA kernels, C++ bindings, Python wrappers, and tests.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: 1 warning, 2 passed.
Summary of Changes (Gemini Code Assist): This pull request introduces specialized CUDA kernels for RMSNorm and Fused Add RMSNorm that incorporate FP8 quantization directly into the normalization step. The primary goal is to make FP8 model inference more efficient by eliminating separate, intermediate quantization passes. This is particularly beneficial for consumers such as sglang and vllm, enabling them to leverage custom PyTorch compile passes for improved performance in large language models.

Highlights
Code Review
This pull request introduces new CUDA kernels for RMSNorm and Fused RMSNorm with FP8 quantization. The changes are well-structured and follow the existing patterns in the codebase. My review focuses on improving the Python API correctness, making the CUDA kernels more generic, and increasing test coverage. I've identified a few issues in the Python bindings related to return types and documentation, suggested removing hardcoded values in the CUDA kernels, and recommended parameterizing the tests to cover all supported FP8 types.
```cuda
for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] =
      float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
}
```
The clamping values -448.0f and 448.0f are hardcoded for FP8 E4M3. As you noted in the PR description, this prevents the kernel from working correctly with other FP8 types like E5M2. Please make this generic. A good approach would be to use if constexpr on the output type O to select the appropriate numeric limits, and define these limits in a central header to avoid magic numbers.
Example:

```cuda
if constexpr (std::is_same_v<O, __nv_fp8_e4m3>) {
  // E4M3 limits
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
} else if constexpr (std::is_same_v<O, __nv_fp8_e5m2>) {
  // E5M2 limits
  output_vec[j] = fmaxf(-57344.0f, fminf(output_vec[j], 57344.0f));
}
```

The same hard-coded clamp appears in the fused kernel:

```cuda
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
}
```
```python
def llama_rms_norm_quant(x, w, scale, eps=1e-6):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x
```
This reference implementation is hardcoded for torch.float8_e4m3fn. To enable testing with other FP8 types like e5m2, please parameterize this function to accept an fp8_dtype and use its finfo for clamping. This will also make the tests more robust.
Suggested change:

```diff
-def llama_rms_norm_quant(x, w, scale, eps=1e-6):
+def llama_rms_norm_quant(x, w, scale, fp8_dtype, eps=1e-6):
     inv_scale = torch.reciprocal(torch.tensor(scale)).float()
     x = x.float()
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
     x = x * w.float()
     x = x * inv_scale
     x = torch.clamp(
-        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
+        x, torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max
     )
-    x = x.to(torch.float8_e4m3fn)
+    x = x.to(fp8_dtype)
     return x
```
```python
def fused_add_rms_norm_quant(x, residual, weight, scale, eps):
    inv_scale = torch.reciprocal(torch.tensor(scale)).float()
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * weight.float()
    x = x * inv_scale
    x = torch.clamp(
        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    )
    x = x.to(torch.float8_e4m3fn)
    return x, residual
```
Similar to llama_rms_norm_quant, this reference implementation is hardcoded for torch.float8_e4m3fn. Please parameterize it to accept an fp8_dtype to allow testing against different FP8 formats.
Suggested change:

```diff
-def fused_add_rms_norm_quant(x, residual, weight, scale, eps):
+def fused_add_rms_norm_quant(x, residual, weight, scale, fp8_dtype, eps):
     inv_scale = torch.reciprocal(torch.tensor(scale)).float()
     orig_dtype = x.dtype
     x = x.to(torch.float32)
     x = x + residual.to(torch.float32)
     residual = x.to(orig_dtype)
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
     x = x * weight.float()
     x = x * inv_scale
     x = torch.clamp(
-        x, torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
+        x, torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max
     )
-    x = x.to(torch.float8_e4m3fn)
+    x = x.to(fp8_dtype)
     return x, residual
```
```python
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_norm_quant(batch_size, hidden_size, dtype, quant_scale, enable_pdl, contiguous):
    if contiguous:
        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    else:
        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = x[:, :hidden_size]

    if enable_pdl and not device_support_pdl(x.device):
        pytest.skip("PDL is only available for Hopper and later GPUs")

    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm_quant(x, w, quant_scale)
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.rmsnorm_quant(y, x, w, quant_scale, enable_pdl=enable_pdl)

    torch.testing.assert_close(y_ref.float(), y.float(), rtol=1, atol=1)
```
This test only covers torch.float8_e4m3fn. The underlying kernel supports other FP8 types like torch.float8_e5m2. After parameterizing the reference implementation llama_rms_norm_quant, please also parameterize this test to run against different FP8 dtypes (e.g., torch.float8_e4m3fn, torch.float8_e5m2) to ensure full coverage.
```python
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_fused_add_rmsnorm_quant(
    batch_size, hidden_size, dtype, quant_scale, enable_pdl, contiguous
):
    eps = 1e-6

    if contiguous:
        x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    else:
        x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = x[:, :hidden_size]

    if enable_pdl and not device_support_pdl(x.device):
        pytest.skip("PDL is only available for Hopper and later GPUs")

    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = fused_add_rms_norm_quant(
        x.clone(), residual.clone(), weight, quant_scale, eps
    )

    x_fused = x.clone()
    residual_fused = residual.clone()
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.fused_add_rmsnorm_quant(
        y, x_fused, residual_fused, weight, quant_scale, eps, enable_pdl=enable_pdl
    )

    torch.testing.assert_close(y.float(), x_native.float(), rtol=1, atol=1)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
```
Actionable comments posted: 6
🧹 Nitpick comments (2)

tests/utils/test_norm.py (1)

- 148-152: Consider whether tolerances are appropriate for FP8. The tolerances `rtol=1, atol=1` are very loose: they allow deviations up to 100% relative or 1.0 absolute. While FP8 has limited precision, this may mask subtle bugs. For `float8_e4m3fn` (max 448; the quantization step varies with magnitude), tighter tolerances like `atol=0.1` or `atol=0.5` might still pass while catching more regressions.

flashinfer/norm.py (1)

- 207-228: Docstring is missing documentation for the `out` and `scale` parameters. The docstring doesn't document `out` (the output tensor) or `scale` (the quantization scale factor), which are important for users to understand the API.
📜 Review details: configuration defaults, review profile CHILL, plan Pro.

📒 Files selected for processing (5)

- csrc/flashinfer_norm_binding.cu (2 hunks)
- csrc/norm.cu (2 hunks)
- flashinfer/norm.py (2 hunks)
- include/flashinfer/norm.cuh (3 hunks)
- tests/utils/test_norm.py (4 hunks)
🧰 Additional context used

🧬 Code graph analysis (3)

- csrc/norm.cu → csrc/tvm_ffi_utils.h: get_stream (294-296)
- csrc/flashinfer_norm_binding.cu → csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm (117-145, 117-118), fused_add_rmsnorm_quant (147-178, 147-148); flashinfer/norm.py: rmsnorm_quant (97-132), fused_add_rmsnorm (149-180), fused_add_rmsnorm_quant (198-234)
- tests/utils/test_norm.py → flashinfer/utils.py: device_support_pdl (615-619); csrc/flashinfer_norm_binding.cu: rmsnorm_quant (20-21), fused_add_rmsnorm_quant (26-27); csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm_quant (147-178, 147-148); flashinfer/norm.py: rmsnorm_quant (97-132), fused_add_rmsnorm_quant (198-234)
🔇 Additional comments (12)

csrc/flashinfer_norm_binding.cu (2)

- 20-27: LGTM! The new function declarations for `rmsnorm_quant` and `fused_add_rmsnorm_quant` correctly mirror their implementations in `csrc/norm.cu`. Parameter types and order are consistent with the kernel wrappers.
- 38-40: LGTM! The exports for the new quantized variants follow the established pattern.

include/flashinfer/norm.cuh (4)

- 148-161: LGTM on kernel structure. `RMSNormQuantKernel` follows the same pattern as the non-quantized `RMSNormKernel`, with appropriate additions for `scale_inv` computation and output type `O` for quantized output.
- 229-261: LGTM! The `RMSNormQuant` host wrapper correctly follows the pattern established by `RMSNorm`, with appropriate handling of the additional `scale` parameter and output type `O`.
- 515-610: LGTM on kernel logic. `FusedAddRMSNormQuantKernel` correctly implements the fused operation: updating `residual` in place and writing the normalized and quantized result to a separate `output` tensor. The shared-memory usage for intermediate values (`smem_x`) matches the non-quantized variant. Same note as before: line 600 uses a hard-coded `448.0f` for clamping.
- 612-647: LGTM! The `FusedAddRMSNormQuant` host wrapper correctly mirrors `FusedAddRMSNorm`, with the additional `output` pointer and `stride_output` parameter for the quantized output tensor.

csrc/norm.cu (1)

- 88-115: LGTM! The 2D-only restriction is reasonable for the quantized variant. The implementation correctly dispatches FP16 input to FP8 output through the nested type dispatch macros.

tests/utils/test_norm.py (3)

- 35-46: LGTM on the reference implementation. The `llama_rms_norm_quant` reference correctly implements the quantization formula: normalize → scale by `1/scale` → clamp to the FP8 range → cast to `float8_e4m3fn`.
- 83-97: LGTM! The `fused_add_rms_norm_quant` reference correctly mirrors the kernel behavior: fused add → update residual → normalize → quantize.
- 243-255: LGTM on test structure. The test correctly validates both the quantized output (with FP8 tolerances) and the residual update (with standard tolerances). Good separation of concerns.

flashinfer/norm.py (2)

- 135-144: LGTM! The fake op is appropriately a no-op since it mutates the pre-allocated `out` tensor.
- 237-247: LGTM! The fake op correctly matches the signature of the real implementation.
```cuda
for (uint32_t j = 0; j < VEC_SIZE; j++) {
  output_vec[j] =
      float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
}
```
🧩 Analysis chain
🏁 Scripts executed:

```shell
# Find DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 macro definition
rg -n "DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8" --type cpp -B 2 -A 20

# Search for FP8 constants or fp8.h
fd -n "fp8" -e h
rg -n "448|57344" --type cpp -B 2 -A 2

# Check norm.cu to see how the kernel is called/dispatched
rg -n "FusedAddRMSNormQuantKernel|FusedRMSNormQuantKernel" csrc/norm.cu -B 5 -A 5

# Search for dispatch pattern in norm.cu
rg -n "DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8" csrc/norm.cu -B 3 -A 10

# Search for RMSNormQuant implementation to confirm it's in norm.cuh
rg -n "RMSNormQuant" include/flashinfer/norm.cuh -B 2 -A 15 | head -50

# Get the full RMSNormQuantKernel to see line 217 context
rg -n "RMSNormQuantKernel" include/flashinfer/norm.cuh -A 80 | head -90

# Search for FusedAddRMSNormQuantKernel to see if it has the same issue at line 600
rg -n "FusedAddRMSNormQuantKernel" include/flashinfer/norm.cuh -A 100 | grep -A 5 -B 5 "448"

# Check if there are any recent commits or comments addressing this 448 value
rg -n "FP8_E5M2|57344|e5m2" include/flashinfer/norm.cuh
```
Hard-coded FP8 clamping limit breaks e5m2 outputs.
The kernel templates use hard-coded 448.0f clamping at lines 217 and 600, which is correct for float8_e4m3fn but incorrect for float8_e5m2 (max ≈ 57344). Since DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 in csrc/tvm_ffi_utils.h dispatches to both _DISPATCH_CASE_FP8_E4M3 and _DISPATCH_CASE_FP8_E5M2, values up to 57344 will be incorrectly clipped to 448 when using e5m2. Use type-dependent limits (following the pattern in csrc/fmha_v2/fmha/numeric_types.h which defines MAX_E4M3 and MAX_E5M2) or a named constant.
Force-pushed 594845f to 6e26dbb.
Actionable comments posted: 0
♻️ Duplicate comments (4)

flashinfer/norm.py (2)

- 95-133: Missing return statement. The function declares `-> torch.Tensor` and documents a return value, but line 132 doesn't return the `out` tensor after the kernel call.

  🔎 Proposed fix

  ```diff
       if enable_pdl is None:
           enable_pdl = device_support_pdl(input.device)
       get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
  +    return out
  ```

- 194-206: Return type should be `torch.Tensor` and the docstring needs correction. The function signature indicates `-> None`, but it should return the `out` tensor for consistency with `rmsnorm_quant` and because `out` is a separate output buffer (not in-place to `input`). Additionally, the docstring incorrectly describes the operation as writing to `input[i]` when it actually writes to `out`.

  🔎 Proposed fixes

  ```diff
   @flashinfer_api
   @register_custom_op(
       "flashinfer::fused_add_rmsnorm_quant", mutates_args=("out", "residual")
   )
   def fused_add_rmsnorm_quant(
       out: torch.Tensor,
       input: torch.Tensor,
       residual: torch.Tensor,
       weight: torch.Tensor,
       scale: float,
       eps: float = 1e-6,
       enable_pdl: Optional[bool] = None,
  -) -> None:
  +) -> torch.Tensor:
       r"""Fused add root mean square normalization.

       Step 1: ``residual[i] += input[i]``

       Step 2:
  -    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
  +    ``out[i] = (residual[i] / RMS(residual)) * weight[i] * (1/scale)`` (quantized to out's dtype)

       Parameters
       ----------
  +    out: torch.Tensor
  +        The output tensor, will quantize the output to the dtype of this tensor.
       input: torch.Tensor
           Input tensor, shape (batch_size, hidden_size).
       residual: torch.Tensor
           Residual tensor, shape (batch_size, hidden_size).
       weight: torch.Tensor
           Weight tensor, shape (hidden_size,).
  +    scale: float
  +        Scale factor for quantization.
       eps: float
           Epsilon for numerical stability.
       enable_pdl: bool
           Whether to enable `programmatic dependent launch
           <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
  +
  +    Returns
  +    -------
  +    output: torch.Tensor
  +        Quantized normalized tensor, shape (batch_size, hidden_size).
       """
       if enable_pdl is None:
           enable_pdl = device_support_pdl(input.device)
       get_norm_module().fused_add_rmsnorm_quant(
           out, input, residual, weight, scale, eps, enable_pdl
       )
  +    return out
  ```

include/flashinfer/norm.cuh (2)

- 214-218: Hard-coded FP8 E4M3 clamping breaks E5M2 support. The clamping values `-448.0f` and `448.0f` at line 217 are hard-coded for `__nv_fp8_e4m3` but will incorrectly clip `__nv_fp8_e5m2` values (max ≈ 57344). Since the kernel is templated on output type `O` and dispatched via `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8` (which includes both E4M3 and E5M2), the clamping should be type-dependent. Consider using `if constexpr` on type `O` or defining numeric limits in a central header (e.g., `include/flashinfer/fp8.h` as mentioned in the PR description).

- 598-601: Hard-coded FP8 E4M3 clamping breaks E5M2 support. Same issue as in `RMSNormQuantKernel`: line 600 hard-codes E4M3 limits (±448.0f), which will incorrectly clip E5M2 values. The clamping should be type-dependent based on output type `O`.

🧹 Nitpick comments (1)

csrc/flashinfer_norm_binding.cu (1)

- 20-21: Parameter name inconsistency with implementation. The declaration uses `TensorView out` while the implementation in `csrc/norm.cu` (lines 79-114) uses `TensorView output`. For consistency, the declaration should match the implementation parameter name.

  🔎 Proposed fix

  ```diff
  -void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
  +void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
                      bool enable_pdl);
  ```
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/norm.cu
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/norm.cuh
🧬 Code graph analysis (2)

- csrc/flashinfer_norm_binding.cu → csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm (117-145, 117-118), fused_add_rmsnorm_quant (147-178, 147-148)
- flashinfer/norm.py → flashinfer/api_logging.py: flashinfer_api (464-565); flashinfer/utils.py: register_custom_op (319-328, 338-357); csrc/flashinfer_norm_binding.cu: rmsnorm_quant (20-21), fused_add_rmsnorm_quant (26-27); csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm_quant (147-178, 147-148); include/flashinfer/trtllm/common.h: device (83-90)
🔇 Additional comments (5)

csrc/flashinfer_norm_binding.cu (1)

- 38-40: LGTM! The FFI exports for the new quantized norm functions follow the established pattern and are correctly defined.

tests/utils/test_norm.py (2)

- 152-152: Verify the very loose test tolerances. The test uses `rtol=1, atol=1`, meaning 100% relative tolerance and 1.0 absolute tolerance. This is extremely permissive and may hide accuracy issues; for reference, the non-quantized tests use `rtol=1e-3, atol=1e-3` (line 125). While FP8's limited precision may require looser tolerances than FP16, the current values seem excessive. Please verify this is intentional and necessary for FP8 e4m3fn quantization.
- 254-254: Verify the very loose test tolerances for quantized output. Similar to `test_norm_quant`, the quantized output comparison uses `rtol=1, atol=1` (100% relative / 1.0 absolute tolerance), which is extremely permissive. The residual comparison (line 255) correctly uses tighter tolerances since it's not quantized. Please verify this tolerance is necessary for FP8 e4m3fn quantization and not masking potential accuracy issues.

include/flashinfer/norm.cuh (2)

- 229-261: LGTM! The `RMSNormQuant` host wrapper follows established patterns and correctly sets up the kernel launch configuration, including PDL support and dynamic shared memory.
- 612-647: LGTM! The `FusedAddRMSNormQuant` host wrapper correctly extends the fused add RMSNorm pattern with quantization support, properly handling the separate output buffer and its stride.
Force-pushed 6e26dbb to fe46655.

Signed-off-by: Devashish Lal <laldevashish@gmail.com>

Force-pushed fe46655 to 1f39bcc.
Actionable comments posted: 1
♻️ Duplicate comments (9)

flashinfer/norm.py (2)

- 194-237: Verify decorator parameter name and return type annotation. Past review rounds flagged two issues: (1) the `mutates_args` decorator referenced parameter name `"output"` while the actual parameter is named `"out"` (this appears fixed; the current code shows `"out"` at line 196), and (2) the return type annotation is `-> None`, but for consistency with `rmsnorm_quant` (which returns the output tensor) this function should also return `out`. Please verify whether the function should return the `out` tensor for consistency with the non-fused quantized variant.

- 95-133: Missing return statement. The function signature declares `-> torch.Tensor` and the docstring documents a return value, but line 132 doesn't return the `out` tensor after calling the kernel.

  🔎 Proposed fix

  ```diff
       if enable_pdl is None:
           enable_pdl = device_support_pdl(input.device)
       get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
  +    return out
  ```

tests/utils/test_norm.py (4)

- 35-46: Consider parameterizing the FP8 dtype for broader test coverage. The reference implementation hardcodes `torch.float8_e4m3fn` for clamping and casting (lines 43, 45). To enable testing with other FP8 formats like `torch.float8_e5m2`, consider adding an `fp8_dtype` parameter and using `torch.finfo(fp8_dtype).min/max` for clamping. This would allow the test suite to verify correctness across different FP8 formats as the kernels add e5m2 support.
- 83-97: Consider parameterizing the FP8 dtype. Similar to `llama_rms_norm_quant`, this reference implementation hardcodes `torch.float8_e4m3fn`. Parameterizing the FP8 dtype would enable testing with `torch.float8_e5m2` and other formats.
- 128-152: Consider expanding test coverage to other FP8 dtypes. The test currently only validates `torch.float8_e4m3fn` output (line 149). After parameterizing the reference implementation, consider adding test cases for `torch.float8_e5m2` to ensure full kernel coverage. Note: the relaxed tolerances (`rtol=1, atol=1` at line 152) are expected for FP8 quantization due to limited precision.
- 220-255: Consider expanding test coverage to other FP8 dtypes. Similar to `test_norm_quant`, this test only validates `torch.float8_e4m3fn` (line 249). Adding test coverage for `torch.float8_e5m2` would ensure the fused kernel works correctly with different FP8 formats.

csrc/norm.cu (1)

- 80-115: Add a device check for the output tensor. The function validates that `input` and `weight` are on the same device (line 85), but doesn't verify that `output` is also on the same device. Since `output` has a different dtype (FP8), there's a risk it could be on a different device.

  🔎 Suggested fix

  ```diff
     CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
     CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
     CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
     CHECK_DEVICE(input, weight);
  +  CHECK_DEVICE(input, output);
     CHECK_DIM(1, weight);  // weight: (hidden_size)
  ```

include/flashinfer/norm.cuh (2)

- 214-218: Hardcoded FP8 clamping breaks e5m2 support. Line 217 hardcodes clamping to `±448.0f` (the FP8 e4m3 range). Since `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8` dispatches to both e4m3 and e5m2 types (`csrc/tvm_ffi_utils.h`), e5m2 values (max ≈ 57344) will be incorrectly clipped, causing silent data corruption. The PR description acknowledges this limitation and proposes adding `include/flashinfer/fp8.h` to centralize FP8 limits. Consider using `if constexpr` with type-dependent constants to fix this before merge.

  🔎 Example fix using type-dependent limits

  ```diff
   #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; j++) {
     output_vec[j] =
         float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
  -  output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
  +  // Clamp based on output type O
  +  constexpr float fp8_max = std::is_same_v<O, __nv_fp8_e4m3> ? 448.0f : 57344.0f;
  +  output_vec[j] = fmaxf(-fp8_max, fminf(output_vec[j], fp8_max));
   }
  ```

- 598-601: Hardcoded FP8 clamping breaks e5m2 support. Line 600 has the same hardcoded `±448.0f` clamping issue as `RMSNormQuantKernel` and will incorrectly clip e5m2 values when the output type is `torch.float8_e5m2`. Apply the same type-dependent clamping fix as suggested for `RMSNormQuantKernel` at line 217.
🧹 Nitpick comments (1)

csrc/flashinfer_norm_binding.cu (1)

- 20-27: Inconsistent parameter naming between declarations. The first parameter is named `out` in `rmsnorm_quant` (line 20) but `output` in `fused_add_rmsnorm_quant` (line 26). The implementations in `csrc/norm.cu` use `output` for both. While this doesn't affect functionality, consistent naming improves maintainability.

  🔎 Suggested fix for consistency

  ```diff
  -void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
  +void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
                      bool enable_pdl);
  ```
🧰 Additional context used
🧬 Code graph analysis (4)

- flashinfer/norm.py → flashinfer/utils.py: register_custom_op (319-328, 338-357), device_support_pdl (615-619), register_fake_op (330-334, 359-364); csrc/flashinfer_norm_binding.cu: rmsnorm_quant (20-21), fused_add_rmsnorm_quant (26-27); csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm_quant (147-181, 147-148); include/flashinfer/trtllm/common.h: device (83-90)
- csrc/flashinfer_norm_binding.cu → csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm (117-145, 117-118), fused_add_rmsnorm_quant (147-181, 147-148); flashinfer/norm.py: rmsnorm_quant (97-132), fused_add_rmsnorm (149-180), fused_add_rmsnorm_quant (198-237)
- tests/utils/test_norm.py → include/flashinfer/trtllm/fused_moe/runner.h: hidden_size (265-265); flashinfer/utils.py: device_support_pdl (615-619); csrc/flashinfer_norm_binding.cu: rmsnorm_quant (20-21), fused_add_rmsnorm_quant (26-27); csrc/norm.cu: rmsnorm_quant (80-115, 80-81), fused_add_rmsnorm_quant (147-181, 147-148); flashinfer/norm.py: rmsnorm_quant (97-132), fused_add_rmsnorm_quant (198-237)
- csrc/norm.cu → csrc/tvm_ffi_utils.h: get_stream (294-296)
🔇 Additional comments (6)
flashinfer/norm.py (2)
135-144: LGTM! The fake op stub is correctly implemented as a no-op for testing/compilation paths.
240-250: LGTM! The fake op stub is correctly implemented for testing/compilation.
csrc/flashinfer_norm_binding.cu (1)
38-40: LGTM! The FFI exports correctly expose the quantized norm functions.
include/flashinfer/norm.cuh (3)
229-261: LGTM! The host wrapper correctly sets up the kernel launch configuration with proper PDL support, shared memory allocation, and error checking.
612-647: LGTM! The host wrapper correctly handles the fused kernel's larger shared memory requirements (for intermediate storage) and properly configures PDL support.
19-19: Good practice: explicit include for used types. Adding `#include <cstdint>` explicitly documents the dependency on `uint32_t` and related types used throughout the kernels.
```cpp
void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
                             TensorView weight, double scale, double eps, bool enable_pdl) {
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
  CHECK_DEVICE(input, residual);
  CHECK_DEVICE(input, weight);
  CHECK_DEVICE(input, output);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_DIM(2, output);
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);
  TVM_FFI_ICHECK_EQ(residual.size(0), batch_size);
  TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size);
  TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size);
  ffi::CUDADeviceGuard device_guard(input.device().device_id);
  const cudaStream_t stream = get_stream(input.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] {
    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(output.dtype(), o_type, [&] {
      cudaError_t status = norm::FusedAddRMSNormQuant(
          static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
          static_cast<c_type*>(weight.data_ptr()), static_cast<o_type*>(output.data_ptr()),
          batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale,
          eps, enable_pdl, stream);

      TVM_FFI_ICHECK(status == cudaSuccess)
          << "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status);
      return true;
    });
  });
}
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Search for output size checks in csrc/norm.cu around the fused_add_rmsnorm_quant function
rg -n "TVM_FFI_ICHECK_EQ.*output\.size" csrc/norm.cu -A 2 -B 2 || echo "No matches found"
```

Repository: flashinfer-ai/flashinfer
Length of output: 1745
🏁 Script executed:

```bash
#!/bin/bash
# Check lines around 160-180 in csrc/norm.cu to see all validation checks
sed -n '146,180p' csrc/norm.cu | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 2013
Add output tensor size validation to fused_add_rmsnorm_quant.
The function validates tensor dimensions and sizes for the input, residual, and weight tensors, but lacks validation of `output.size(0)` and `output.size(1)` against `batch_size` and `hidden_size`. This is inconsistent with similar functions in the codebase (e.g., the other norm functions) and leaves the function vulnerable to size mismatches. Add:

```cpp
TVM_FFI_ICHECK_EQ(output.size(0), batch_size);
TVM_FFI_ICHECK_EQ(output.size(1), hidden_size);
```

after line 19 (following the `weight.size(0)` check).
🤖 Prompt for AI Agents
In csrc/norm.cu around lines 147-181, the function validates input, residual and
weight sizes but misses checks that output.size(0) == batch_size and
output.size(1) == hidden_size; add two TVM_FFI_ICHECK_EQ assertions immediately
after the existing TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size) line to
compare output.size(0) with batch_size and output.size(1) with hidden_size so
the output tensor size is validated before proceeding.
|
/bot run |
At the least, I can benchmark the code and see whether there is a need to write (and whether I can write) a cute-dsl based implementation
follow up on flashinfer-ai#2243 quant_scale being a float causes cuda graph capture to fail even with workaround, by making it a tensor it fixes cuda graph capture for fusion passes in sglang. also added docs for the fused kernels. Signed-off-by: Devashish Lal <laldevashish@gmail.com>
📌 Description
FP8 model inference requires multiple intermediate quantization kernels, which can be avoided by fusing the norm and quantization kernels. Consumers like sglang and vLLM can lower to these norm + quant fusion kernels using custom torch.compile passes.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Reference
I have been working on adding custom fusion passes to sglang as part of the following RFC, and would like to use FlashInfer's norm kernels for the norm + quant fusions instead of migrating vLLM kernels to sglang as part of the following MR.
Implementation
I realise that existing kernels (at least for rmsnorm) can be modified to add the scale parameter as an optional parameter, thereby avoiding most code duplication. However, as an initial implementation, I have opted for a separate implementation route. This can be refactored if required.
For `fused_add_rmsnorm_quant`, I don't think an in-place update would be possible, since the dtypes for input and output differ.

Currently, the FP8 E4M3 numeric limit (448) has been hard-coded, as I am not aware of a way to get this value at compile time without including c10 headers from torch, and I am not sure whether that is acceptable post TVM-FFI migration.
Following is a snippet from vLLM, and I have seen similar code for getting the FP8 numeric limits:

```cpp
#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};
```

The best option in my mind is to introduce `include/flashinfer/fp8.h` containing something similar to the above snippet, and also support e5m2.

Tests
The atol and rtol for the FP8 assertions had to be high due to the low-precision nature of the data; even so, with tolerances of 1e-2, just a few tests fail, each with a single-element mismatch.
Summary by CodeRabbit
New Features

- Added quantized RMSNorm and fused quantized RMSNorm (residual-add) with configurable scale, eps, and PDL toggle.
- Supports FP16/FP8 paths and optional per-token or per-tensor scaling; outputs are clamped for quantized formats.

Tests

- Added tests validating quantized normalization and fused-residual flows across dtypes, batch sizes, scaling modes, and PDL configurations.