47 changes: 47 additions & 0 deletions include/flashinfer/fp8_types.cuh
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_FP8_TYPES_CUH_
#define FLASHINFER_FP8_TYPES_CUH_

#include <cuda_fp8.h>

namespace flashinfer {

/*!
 * \brief Type trait providing the maximum representable value for an FP8 type.
 * Used for clamping values before the FP8 cast so that out-of-range inputs
 * saturate to the maximum instead of producing NaN/Inf.
 *
 * - __nv_fp8_e4m3: 4-bit exponent, 3-bit mantissa, max = 448.0   (= 2^8 * 1.75)
 * - __nv_fp8_e5m2: 5-bit exponent, 2-bit mantissa, max = 57344.0 (= 2^15 * 1.75)
 */
template <typename T>
struct fp8_clamp_max {
static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max");
};

template <>
struct fp8_clamp_max<__nv_fp8_e4m3> {
static constexpr float value = 448.0f;
};

template <>
struct fp8_clamp_max<__nv_fp8_e5m2> {
static constexpr float value = 57344.0f;
};
Contributor review comment (medium):

To improve code readability and reduce repetition across norm.cuh and pos_enc.cuh, consider adding a helper function for clamping values before FP8 conversion. This pattern is used in multiple places.

You could add the following function to this file, which would make the call sites much cleaner (e.g., output_vec[j] = clamp<O>(output_vec[j]);).

template <typename T>
__device__ __forceinline__ float clamp(float val) {
  constexpr float max_val = fp8_clamp_max<T>::value;
  return fmaxf(-max_val, fminf(val, max_val));
}


} // namespace flashinfer

#endif // FLASHINFER_FP8_TYPES_CUH_
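
For reference, a minimal sketch of how the trait is intended to be used at a call site. The helper name quantize_e4m3 and the assumption that include/ is on the include path are illustrative, not part of the diff; the clamp-then-cast pattern mirrors the kernel changes below.

#include <cuda_fp8.h>

#include "flashinfer/fp8_types.cuh"

// Hypothetical device helper (illustrative, not from the diff): quantize a
// float to e4m3 after clamping it into the representable range.
__device__ __forceinline__ __nv_fp8_e4m3 quantize_e4m3(float x) {
  constexpr float max_val = flashinfer::fp8_clamp_max<__nv_fp8_e4m3>::value;  // 448.0f
  x = fmaxf(-max_val, fminf(x, max_val));  // saturate into [-448.0f, 448.0f]
  return __nv_fp8_e4m3(x);                 // the narrowing cast can no longer overflow
}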
7 changes: 5 additions & 2 deletions include/flashinfer/norm.cuh
@@ -23,6 +23,7 @@
#include "flashinfer/trtllm/common/cudaUtils.h"
#include "flashinfer/trtllm/common/reduceKernelUtils.cuh"
#include "flashinfer/utils.cuh"
#include "fp8_types.cuh"
#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"
@@ -214,7 +215,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] =
float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
- output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+ output_vec[j] =
+     fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
@@ -598,7 +600,8 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
- output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
+ output_vec[j] =
+     fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
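To see what the templated bound buys, here is a small host-side sketch, not part of the diff: the test value 5000.0f is an illustrative assumption, and it assumes compilation with nvcc so the CUDA FP8 header and the flashinfer include path are available. With the previous hard-coded 448.0f, an output destined for __nv_fp8_e5m2 was clamped far below that format's range, whereas the bound taken from fp8_clamp_max<O> preserves it.

#include <cmath>
#include <cstdio>

#include <cuda_fp8.h>

#include "flashinfer/fp8_types.cuh"

int main() {
  float v = 5000.0f;  // hypothetical pre-quantization value for an e5m2 output
  // Old behaviour: clamp to the e4m3 bound regardless of the output type.
  float old_clamped = std::fmaxf(-448.0f, std::fminf(v, 448.0f));  // -> 448.0
  // New behaviour: the bound follows the kernel's output type O.
  constexpr float e5m2_max = flashinfer::fp8_clamp_max<__nv_fp8_e5m2>::value;  // 57344.0
  float new_clamped = std::fmaxf(-e5m2_max, std::fminf(v, e5m2_max));  // -> 5000.0
  std::printf("old=%.1f new=%.1f\n", old_clamped, new_clamped);
  return 0;
}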
17 changes: 17 additions & 0 deletions include/flashinfer/pos_enc.cuh
@@ -22,6 +22,7 @@
#include <string>
#include <type_traits>

#include "fp8_types.cuh"
Contributor review comment (⚠️ Potential issue | 🟡 Minor):

clang-format cleanup is still required for this file.

Pre-commit reports formatting diffs; please format this header before merging.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/pos_enc.cuh` at line 25, the file fails clang-format
checks; run the repository's clang-format (or pre-commit) on pos_enc.cuh to fix
whitespace, include ordering, and style issues. In particular, reformat the
include line referencing "fp8_types.cuh" and the surrounding header contents so
the file passes the pre-commit clang-format checks before merging.

#include "layout.cuh"
#include "math.cuh"
#include "page.cuh"
@@ -273,6 +274,8 @@ __device__ __forceinline__ void scale_store_partial_chunk(const DType* in_ptr, Q
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
vec[i] = vec[i] * scale;
+ vec[i] =
+     fmaxf(-fp8_clamp_max<QuantType>::value, fminf(vec[i], fp8_clamp_max<QuantType>::value));
}
if (lane_elem_offset + vec_size <= chunk_valid) {
vec.cast_store(out_ptr + lane_elem_offset);
@@ -501,6 +504,8 @@ __global__ void RopeQuantizeKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
+ q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value));
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);

@@ -526,6 +531,8 @@ __global__ void RopeQuantizeKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
+ k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value));
}
k_rope_vec.cast_store(k_rope_out_ptr + tx * vec_size);

@@ -902,6 +909,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
+ q_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(q_rope_vec[i], fp8_clamp_max<QuantType>::value));
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);

@@ -931,6 +940,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
+ k_rope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(k_rope_vec[i], fp8_clamp_max<QuantType>::value));
}

if constexpr (IS_MLA) {
@@ -961,6 +972,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+ k_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(k_nope_vec[i], fp8_clamp_max<QuantType>::value));
}

if constexpr (IS_MLA) {
@@ -991,6 +1004,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
v_vec[i] = v_vec[i] * quant_scale_kv;
+ v_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                  fminf(v_vec[i], fp8_clamp_max<QuantType>::value));
Contributor review comment on lines +1007 to +1008 (medium):

This clamping logic can be simplified using the proposed clamp helper function for better readability.

              v_vec[i] = clamp<QuantType>(v_vec[i]);

}
QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
v_elem_offset + tx * vec_size);
@@ -1020,6 +1035,8 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+ q_nope_vec[i] = fmaxf(-fp8_clamp_max<QuantType>::value,
+                       fminf(q_nope_vec[i], fp8_clamp_max<QuantType>::value));
}
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
}
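
If the clamp helper proposed in the review were adopted, each quantize-and-clamp loop above would shrink to one line per element. A sketch of one such loop under that assumption; the helper comes from the reviewer's suggestion and is not yet in the tree, while the surrounding names are taken from the existing RopeQuantizeKernel code shown above.

// Hypothetical rewrite of one loop in RopeQuantizeKernel, assuming the
// clamp<T>() helper from the review suggestion is added to fp8_types.cuh.
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
  q_rope_vec[i] = clamp<QuantType>(q_rope_vec[i] * quant_scale_q);
}
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);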