Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 193 additions & 44 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
#ifndef FLASHINFER_NORM_CUH_
#define FLASHINFER_NORM_CUH_

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <numeric>

#include "flashinfer/trtllm/common/cudaTypeUtils.cuh"
Expand All @@ -33,7 +37,72 @@ namespace norm {

using namespace tensorrt_llm::common;

template <uint32_t VEC_SIZE, typename T>
// Reads FLASHINFER_RMSNORM_NUM_WARPS once per process and returns it as a
// positive int, or 0 when the variable is unset, empty, or not a valid
// positive integer that fits in int. The parse result is cached in a
// function-local static, so the environment is consulted only on first call.
inline int GetRMSNormNumWarpsOverrideFromEnv() {
  static const int cached_override = []() -> int {
    const char* raw = std::getenv("FLASHINFER_RMSNORM_NUM_WARPS");
    if (raw == nullptr || *raw == '\0') {
      return 0;  // unset or empty: no override
    }
    errno = 0;
    char* parse_end = nullptr;
    const unsigned long value = std::strtoul(raw, &parse_end, 10);
    // Reject: no digits consumed, trailing garbage, zero, overflow, or a
    // value that does not fit in int.
    const bool consumed_all = (parse_end != raw) && (*parse_end == '\0');
    const bool fits_in_int =
        (errno != ERANGE) &&
        (value <= static_cast<unsigned long>(std::numeric_limits<int>::max()));
    if (!consumed_all || !fits_in_int || value == 0) {
      return 0;
    }
    return static_cast<int>(value);
  }();
  return cached_override;
}

// Picks the number of 32-thread warps for the RMSNorm kernels.
// The result is capped by the number of whole warps needed to cover the
// d / vec_size vector chunks (clamped to [32, 1024] threads). An environment
// override (FLASHINFER_RMSNORM_NUM_WARPS), when set, wins up to that cap;
// otherwise a heuristic targets roughly 4 vector chunks per thread with a
// block of 32..256 threads.
inline uint32_t GetRMSNormNumWarps(uint32_t d, uint32_t vec_size) {
  const uint32_t vec_chunks = d / vec_size;

  // Upper bound on useful threads: one per vector chunk, clamped to [32, 1024].
  uint32_t thread_cap = std::min<uint32_t>(1024u, vec_chunks);
  thread_cap = std::max<uint32_t>(32u, thread_cap);
  const uint32_t warp_cap = ceil_div(thread_cap, 32u);

  const int env_override = GetRMSNormNumWarpsOverrideFromEnv();
  if (env_override > 0) {
    return std::min<uint32_t>(warp_cap, static_cast<uint32_t>(env_override));
  }

  // Heuristic: ~4 chunks per thread, rounded up to a whole warp, then
  // clamped to [32, 256] threads and to the per-chunk cap above.
  uint32_t threads = ceil_div(vec_chunks, 4u);
  threads = ceil_div(threads, 32u) * 32u;
  threads = std::min<uint32_t>(256u, threads);
  threads = std::max<uint32_t>(32u, threads);
  threads = std::min<uint32_t>(threads, thread_cap);
  return std::max<uint32_t>(1u, ceil_div(threads, 32u));
}

// Returns cudaDevAttrMaxSharedMemoryPerBlockOptin for the current device.
// Thread-safe per-process cache, one slot per device ordinal (first
// kMaxCachedDevices devices); concurrent first calls may both query the
// attribute but store the same value, which is benign. Falls back to the
// architecture-independent 48 KB static limit when any CUDA call fails or
// the device ordinal is out of the cached range on a failed query.
inline int GetRMSNormMaxSharedMemoryPerBlockOptin() {
  constexpr int kDefaultSmemLimit = 48 * 1024;  // conservative fallback
  constexpr int kMaxCachedDevices = 32;
  // Zero-initialized; 0 means "not yet queried" (the real attribute is > 0).
  static std::atomic<int> cache[kMaxCachedDevices]{};

  int dev = 0;
  if (cudaGetDevice(&dev) != cudaSuccess || dev < 0) {
    return kDefaultSmemLimit;
  }

  const bool cacheable = dev < kMaxCachedDevices;
  if (cacheable) {
    const int hit = cache[dev].load(std::memory_order_relaxed);
    if (hit != 0) {
      return hit;
    }
  }

  int limit = 0;
  if (cudaDeviceGetAttribute(&limit, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev) !=
      cudaSuccess) {
    return kDefaultSmemLimit;
  }
  if (cacheable) {
    cache[dev].store(limit, std::memory_order_relaxed);
  }
  return limit;
}

template <uint32_t VEC_SIZE, typename T, bool CACHE_INPUT = false>
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
const uint32_t d, const uint32_t stride_input,
const uint32_t stride_output, float weight_bias, float eps) {
Expand All @@ -46,6 +115,8 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];
const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
[[maybe_unused]] T* smem_input = reinterpret_cast<T*>(smem + smem_reduce_elems);

float sum_sq = 0.f;

Expand All @@ -56,8 +127,15 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
if (vec_offset < d) {
input_vec.load(input + bx * stride_input + vec_offset);
if constexpr (CACHE_INPUT) {
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
smem_input[vec_offset + j] = input_vec[j];
}
}
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
Expand Down Expand Up @@ -92,17 +170,21 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
vec_t<T, VEC_SIZE> output_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
if (vec_offset < d) {
if constexpr (CACHE_INPUT) {
input_vec.load(smem_input + vec_offset);
} else {
input_vec.load(input + bx * stride_input + vec_offset);
}
weight_vec.load(weight + vec_offset);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE +
thread_id * VEC_SIZE);
if (vec_offset < d) {
output_vec.store(output + bx * stride_output + vec_offset);
}
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
Expand All @@ -116,13 +198,18 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
const size_t smem_size_with_cache =
static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
const uint32_t smem_size =
cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
float weight_bias = 0.f;
void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};

cudaLaunchConfig_t config;
config.gridDim = nblks;
Expand All @@ -136,16 +223,28 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
config.attrs = attrs;

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = RMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
stride_output, weight_bias, eps));
if (cache_input) {
auto kernel = RMSNormKernel<VEC_SIZE, T, true>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, eps));
} else {
auto kernel = RMSNormKernel<VEC_SIZE, T, false>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, eps));
}
});
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T, typename O>
template <uint32_t VEC_SIZE, typename T, typename O, bool CACHE_INPUT = false>
__global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
O* __restrict__ output, const uint32_t d,
const uint32_t stride_input, const uint32_t stride_output,
Expand All @@ -160,6 +259,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
const float scale_inv = 1.0f / scale;
extern __shared__ float smem[];
const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
[[maybe_unused]] T* smem_input = reinterpret_cast<T*>(smem + smem_reduce_elems);

float sum_sq = 0.f;

Expand All @@ -170,8 +271,15 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
if (vec_offset < d) {
input_vec.load(input + bx * stride_input + vec_offset);
if constexpr (CACHE_INPUT) {
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
smem_input[vec_offset + j] = input_vec[j];
}
}
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
Expand Down Expand Up @@ -206,19 +314,23 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
vec_t<float, VEC_SIZE> output_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
if (vec_offset < d) {
if constexpr (CACHE_INPUT) {
input_vec.load(smem_input + vec_offset);
} else {
input_vec.load(input + bx * stride_input + vec_offset);
}
weight_vec.load(weight + vec_offset);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] =
float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
thread_id * VEC_SIZE);
if (vec_offset < d) {
output_vec.cast_store(output + bx * stride_output + vec_offset);
}
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
Expand All @@ -232,11 +344,17 @@ cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, ui
float eps = 1e-5, bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
const size_t smem_size_with_cache =
static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
const uint32_t smem_size =
cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
float weight_bias = 0.f;

cudaLaunchConfig_t config;
Expand All @@ -251,11 +369,25 @@ cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, ui
config.attrs = attrs;

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = RMSNormQuantKernel<VEC_SIZE, T, O>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
stride_output, weight_bias, scale, eps));
if (cache_input) {
auto kernel = RMSNormQuantKernel<VEC_SIZE, T, O, true>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, scale,
eps));
} else {
auto kernel = RMSNormQuantKernel<VEC_SIZE, T, O, false>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, scale,
eps));
}
});
return cudaSuccess;
}
Expand Down Expand Up @@ -653,13 +785,18 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
const size_t smem_size_with_cache =
static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
const uint32_t smem_size =
cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
float weight_bias = 1.f;
void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};

cudaLaunchConfig_t config;
config.gridDim = nblks;
Expand All @@ -673,11 +810,23 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
config.attrs = attrs;

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = RMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
stride_output, weight_bias, eps));
if (cache_input) {
auto kernel = RMSNormKernel<VEC_SIZE, T, true>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, eps));
} else {
auto kernel = RMSNormKernel<VEC_SIZE, T, false>;
if (smem_size > 48 * 1024) {
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
stride_input, stride_output, weight_bias, eps));
}
});
return cudaSuccess;
}
Expand Down