Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
return cudaSuccess;
}

template <typename T>
struct fp8_clamp_max {
static_assert(sizeof(T) == 0, "Unsupported FP8 type for fp8_clamp_max");
};

template <>
struct fp8_clamp_max<__nv_fp8_e4m3> {
static constexpr float value = 448.0f;
};

template <>
struct fp8_clamp_max<__nv_fp8_e5m2> {
static constexpr float value = 57344.0f;
};

template <uint32_t VEC_SIZE, typename T, typename O>
__global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
O* __restrict__ output, const uint32_t d,
Expand Down Expand Up @@ -214,7 +229,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] =
float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
output_vec[j] =
fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
Expand Down Expand Up @@ -598,7 +614,8 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
output_vec[j] =
fmaxf(-fp8_clamp_max<O>::value, fminf(output_vec[j], fp8_clamp_max<O>::value));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
Expand Down
Loading