From 11088ee5d13e769ce59aa99f2aa499f94e01a532 Mon Sep 17 00:00:00 2001
From: Bias92
Date: Sun, 22 Feb 2026 01:41:58 +0900
Subject: [PATCH 1/2] fix: use type-specific FP8 max value for clamping in
 RMSNorm quantization kernels

---
 include/flashinfer/norm.cuh | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 6814e892d1..f540185f2c 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -145,6 +145,21 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   return cudaSuccess;
 }
 
+template <typename T>
+struct fp8_clamp_max {
+  static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max");
+};
+
+template <>
+struct fp8_clamp_max<__nv_fp8_e4m3> {
+  static constexpr float value = 448.0f;
+};
+
+template <>
+struct fp8_clamp_max<__nv_fp8_e5m2> {
+  static constexpr float value = 57344.0f;
+};
+
 template <uint32_t VEC_SIZE, typename T, typename O>
 __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
                                    O* __restrict__ output, const uint32_t d,
@@ -214,7 +229,7 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] =
           float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+      output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
@@ -598,7 +613,7 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+      output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +

From a45c497bdbe2b4360b9ba458f690f1585ca257d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=9E=AC=EC=9A=B0?=
Date: Tue, 24 Feb 2026 03:53:22 +0900
Subject: [PATCH 2/2] style: fix pre-commit issues

---
 include/flashinfer/norm.cuh | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index f540185f2c..59bcdc7528 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -229,7 +229,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] =
           float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
+      output_vec[j] =
+          fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
@@ -613,7 +614,8 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
+      output_vec[j] =
+          fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +