-
Notifications
You must be signed in to change notification settings - Fork 861
fix: use type-specific FP8 max value for clamping in RMSNorm and RoPE quantization kernels #2639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| /* | ||
| * Copyright (c) 2025 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #ifndef FLASHINFER_FP8_TYPES_CUH_ | ||
| #define FLASHINFER_FP8_TYPES_CUH_ | ||
|
|
||
| #include <cuda_fp8.h> | ||
|
|
||
| namespace flashinfer { | ||
|
|
||
| /*! | ||
| * \brief Type trait providing the maximum representable value for FP8 types. | ||
| * Used for clamping before FP8 cast to avoid NaN/Inf. | ||
| * | ||
| * - __nv_fp8_e4m3: 4-bit exponent, 3-bit mantissa, max = 448.0 | ||
| * - __nv_fp8_e5m2: 5-bit exponent, 2-bit mantissa, max = 57344.0 | ||
| */ | ||
| template <typename T> | ||
| struct fp8_clamp_max { | ||
| static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max"); | ||
| }; | ||
|
|
||
| template <> | ||
| struct fp8_clamp_max<__nv_fp8_e4m3> { | ||
| static constexpr float value = 448.0f; | ||
| }; | ||
|
|
||
| template <> | ||
| struct fp8_clamp_max<__nv_fp8_e5m2> { | ||
| static constexpr float value = 57344.0f; | ||
| }; | ||
|
|
||
| } // namespace flashinfer | ||
|
|
||
| #endif // FLASHINFER_FP8_TYPES_CUH_ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| #include <string> | ||
| #include <type_traits> | ||
|
|
||
| #include "fp8_types.cuh" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Pre-commit reports formatting diffs; please format this header before merging. 🤖 Prompt for AI Agents |
||
| #include "layout.cuh" | ||
| #include "math.cuh" | ||
| #include "page.cuh" | ||
|
|
@@ -273,6 +274,8 @@ __device__ __forceinline__ void scale_store_partial_chunk(const DType* in_ptr, Q | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| vec[i] = vec[i] * scale; | ||
| vec[i] = | ||
| fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
| if (lane_elem_offset + vec_size <= chunk_valid) { | ||
| vec.cast_store(out_ptr + lane_elem_offset); | ||
|
|
@@ -501,6 +504,8 @@ __global__ void RopeQuantizeKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; | ||
| q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
| q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); | ||
|
|
||
|
|
@@ -526,6 +531,8 @@ __global__ void RopeQuantizeKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; | ||
| k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
| k_rope_vec.cast_store(k_rope_out_ptr + tx * vec_size); | ||
|
|
||
|
|
@@ -902,6 +909,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; | ||
| q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
| q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); | ||
|
|
||
|
|
@@ -931,6 +940,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; | ||
| k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
|
|
||
| if constexpr (IS_MLA) { | ||
|
|
@@ -961,6 +972,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; | ||
| k_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(k_nope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
|
|
||
| if constexpr (IS_MLA) { | ||
|
|
@@ -991,6 +1004,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| v_vec[i] = v_vec[i] * quant_scale_kv; | ||
| v_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(v_vec[i], fp8_clamp_max<QuantType>::value)); | ||
|
Comment on lines
+1007
to
+1008
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| } | ||
| QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, | ||
| v_elem_offset + tx * vec_size); | ||
|
|
@@ -1020,6 +1035,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( | |
| #pragma unroll | ||
| for (uint32_t i = 0; i < vec_size; ++i) { | ||
| q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; | ||
| q_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value, | ||
| fminf(q_nope_vec[i], fp8_clamp_max<QuantType>::value)); | ||
| } | ||
| q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve code readability and reduce repetition across
norm.cuhandpos_enc.cuh, consider adding a helper function for clamping values before FP8 conversion. This pattern is used in multiple places.You could add the following function to this file, which would make the call sites much cleaner (e.g.,
output_vec[j] = clamp<O>(output_vec[j]);).