diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 6814e892d1..59bcdc7528 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -145,6 +145,21 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   return cudaSuccess;
 }
 
+template <typename T>
+struct fp8_clamp_max {
+  static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max");
+};
+
+template <>
+struct fp8_clamp_max<__nv_fp8_e4m3> {
+  static constexpr float value = 448.0f;
+};
+
+template <>
+struct fp8_clamp_max<__nv_fp8_e5m2> {
+  static constexpr float value = 57344.0f;
+};
+
 template <uint32_t VEC_SIZE, typename T, typename O>
 __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
                                    O* __restrict__ output, const uint32_t d,
@@ -214,7 +229,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) *
                       scale_inv;
-      output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+      output_vec[j] =
+          fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
@@ -598,7 +614,8 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
-      output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+      output_vec[j] =
+          fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +