diff --git a/include/flashinfer/fp8_types.cuh b/include/flashinfer/fp8_types.cuh new file mode 100644 index 0000000000..a5093c06fe --- /dev/null +++ b/include/flashinfer/fp8_types.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_FP8_TYPES_CUH_ +#define FLASHINFER_FP8_TYPES_CUH_ + +#include <cuda_fp8.h> + +namespace flashinfer { + +/*! + * \brief Type trait providing the maximum representable value for FP8 types. + * Used for clamping before FP8 cast to avoid NaN/Inf. 
+ * + * - __nv_fp8_e4m3: 4-bit exponent, 3-bit mantissa, max = 448.0 + * - __nv_fp8_e5m2: 5-bit exponent, 2-bit mantissa, max = 57344.0 + */ +template <typename T> +struct fp8_clamp_max { + static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max"); +}; + +template <> +struct fp8_clamp_max<__nv_fp8_e4m3> { + static constexpr float value = 448.0f; +}; + +template <> +struct fp8_clamp_max<__nv_fp8_e5m2> { + static constexpr float value = 57344.0f; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_FP8_TYPES_CUH_ diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 6814e892d1..471df06a4d 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -23,6 +23,7 @@ #include "flashinfer/trtllm/common/cudaUtils.h" #include "flashinfer/trtllm/common/reduceKernelUtils.cuh" #include "flashinfer/utils.cuh" +#include "fp8_types.cuh" #include "math.cuh" #include "utils.cuh" #include "vec_dtypes.cuh" @@ -214,7 +215,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight for (uint32_t j = 0; j < VEC_SIZE; j++) { output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv; - output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f)); + output_vec[j] = + fmaxf(-fp8_clamp_max<QuantType>::value, fminf(output_vec[j], fp8_clamp_max<QuantType>::value)); } if ((i * num_threads + thread_id) * VEC_SIZE < d) { output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE + @@ -598,7 +600,8 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_ #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; j++) { output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv; - output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f)); + output_vec[j] = + fmaxf(-fp8_clamp_max<QuantType>::value, fminf(output_vec[j], fp8_clamp_max<QuantType>::value)); } if ((i * num_threads + thread_id) * VEC_SIZE < d) { output_vec.cast_store(output + bx * stride_output + i * 
num_threads * VEC_SIZE + diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 4fdd75e0a3..c919f0319a 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -22,6 +22,7 @@ #include #include +#include "fp8_types.cuh" #include "layout.cuh" #include "math.cuh" #include "page.cuh" @@ -273,6 +274,8 @@ __device__ __forceinline__ void scale_store_partial_chunk(const DType* in_ptr, Q #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { vec[i] = vec[i] * scale; + vec[i] = + fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value)); } if (lane_elem_offset + vec_size <= chunk_valid) { vec.cast_store(out_ptr + lane_elem_offset); @@ -501,6 +504,8 @@ __global__ void RopeQuantizeKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value)); } q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); @@ -526,6 +531,8 @@ __global__ void RopeQuantizeKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value)); } k_rope_vec.cast_store(k_rope_out_ptr + tx * vec_size); @@ -902,6 +909,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value)); } q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); @@ -931,6 +940,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value)); } if constexpr (IS_MLA) { @@ -961,6 +972,8 @@ __global__ void 
RopeQuantizeAppendPagedKVCacheKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; + k_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(k_nope_vec[i], fp8_clamp_max<QuantType>::value)); } if constexpr (IS_MLA) { @@ -991,6 +1004,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { v_vec[i] = v_vec[i] * quant_scale_kv; + v_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(v_vec[i], fp8_clamp_max<QuantType>::value)); } QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, v_elem_offset + tx * vec_size); @@ -1020,6 +1035,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + q_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, + fminf(q_nope_vec[i], fp8_clamp_max<QuantType>::value)); } q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); }