diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 6814e892d1..c061891300 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -16,7 +16,11 @@
 #ifndef FLASHINFER_NORM_CUH_
 #define FLASHINFER_NORM_CUH_
 
+#include <atomic>
+#include <cerrno>
 #include <cmath>
+#include <cstdlib>
+#include <limits>
 #include <numeric>
 
 #include "flashinfer/trtllm/common/cudaTypeUtils.cuh"
@@ -33,7 +37,72 @@ namespace norm {
 
 using namespace tensorrt_llm::common;
 
-template <typename T, uint32_t VEC_SIZE>
+inline int GetRMSNormNumWarpsOverrideFromEnv() {
+  static int num_warps_override = []() -> int {
+    const char* env = std::getenv("FLASHINFER_RMSNORM_NUM_WARPS");
+    if (env == nullptr || env[0] == 0) {
+      return 0;
+    }
+    char* end = nullptr;
+    errno = 0;
+    unsigned long parsed = std::strtoul(env, &end, 10);
+    if (end == env || *end != 0 || parsed == 0 || errno == ERANGE ||
+        parsed > static_cast<unsigned long>(std::numeric_limits<int>::max())) {
+      return 0;
+    }
+    return static_cast<int>(parsed);
+  }();
+  return num_warps_override;
+}
+
+inline uint32_t GetRMSNormNumWarps(uint32_t d, uint32_t vec_size) {
+  const uint32_t max_threads = std::max(32u, std::min(1024u, d / vec_size));
+  const uint32_t max_num_warps = ceil_div(max_threads, 32u);
+
+  const int override = GetRMSNormNumWarpsOverrideFromEnv();
+  if (override > 0) {
+    return std::min(max_num_warps, static_cast<uint32_t>(override));
+  }
+
+  const uint32_t vec_chunks = d / vec_size;
+  uint32_t target_threads = ceil_div(vec_chunks, 4u);
+  target_threads = ceil_div(target_threads, 32u) * 32u;
+  target_threads = std::max(32u, std::min(256u, target_threads));
+  target_threads = std::min(target_threads, max_threads);
+  return std::max(1u, ceil_div(target_threads, 32u));
+}
+
+// Thread-safe per-process cache. Supports multiple CUDA devices in one process.
+inline int GetRMSNormMaxSharedMemoryPerBlockOptin() {
+  constexpr int kDefaultSmemLimit = 48 * 1024;
+  int device = 0;
+  if (cudaGetDevice(&device) != cudaSuccess || device < 0) {
+    return kDefaultSmemLimit;
+  }
+
+  constexpr int kMaxCachedDevices = 32;
+  static std::atomic<int> max_smem_per_block[kMaxCachedDevices]{};
+
+  int cached = 0;
+  if (device < kMaxCachedDevices) {
+    cached = max_smem_per_block[device].load(std::memory_order_relaxed);
+  }
+
+  if (cached == 0) {
+    int queried = 0;
+    if (cudaDeviceGetAttribute(&queried, cudaDevAttrMaxSharedMemoryPerBlockOptin, device) !=
+        cudaSuccess) {
+      return kDefaultSmemLimit;
+    }
+    cached = queried;
+    if (device < kMaxCachedDevices) {
+      max_smem_per_block[device].store(cached, std::memory_order_relaxed);
+    }
+  }
+  return cached;
+}
+
+template <typename T, uint32_t VEC_SIZE, bool CACHE_INPUT>
 __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
                               const uint32_t d, const uint32_t stride_input,
                               const uint32_t stride_output, float weight_bias, float eps) {
@@ -46,6 +115,8 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
   const uint32_t num_threads = num_warps * warp_size;
   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
   extern __shared__ float smem[];
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  [[maybe_unused]] T* smem_input = reinterpret_cast<T*>(smem + smem_reduce_elems);
 
   float sum_sq = 0.f;
 
@@ -56,8 +127,15 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<float, VEC_SIZE> input_vec;
     input_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
+    if (vec_offset < d) {
+      input_vec.load(input + bx * stride_input + vec_offset);
+      if constexpr (CACHE_INPUT) {
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; j++) {
+          smem_input[vec_offset + j] = input_vec[j];
+        }
+      }
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -92,17 +170,21 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     vec_t<float, VEC_SIZE> output_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
+    if (vec_offset < d) {
+      if constexpr (CACHE_INPUT) {
+        input_vec.load(smem_input + vec_offset);
+      } else {
+        input_vec.load(input + bx * stride_input + vec_offset);
+      }
+      weight_vec.load(weight + vec_offset);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
       output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE +
-                       thread_id * VEC_SIZE);
+    if (vec_offset < d) {
+      output_vec.store(output + bx * stride_output + vec_offset);
     }
   }
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -116,13 +198,18 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
                     bool enable_pdl = false, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
-  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
-  const uint32_t num_warps = ceil_div(block_size, 32);
+  const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
+  const size_t smem_size_with_cache =
+      static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
+  const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
+  const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
+  const uint32_t smem_size =
+      cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
   float weight_bias = 0.f;
-  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};
 
   cudaLaunchConfig_t config;
   config.gridDim = nblks;
@@ -136,16 +223,28 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   config.attrs = attrs;
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = RMSNormKernel<T, VEC_SIZE>;
-    FLASHINFER_CUDA_CALL(
-        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
-                                            stride_output, weight_bias, eps));
+    if (cache_input) {
+      auto kernel = RMSNormKernel<T, VEC_SIZE, true>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, eps));
+    } else {
+      auto kernel = RMSNormKernel<T, VEC_SIZE, false>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, eps));
+    }
   });
 
   return cudaSuccess;
 }
 
-template <typename T, typename O, uint32_t VEC_SIZE>
+template <typename T, typename O, uint32_t VEC_SIZE, bool CACHE_INPUT>
 __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
                                    O* __restrict__ output, const uint32_t d,
                                    const uint32_t stride_input, const uint32_t stride_output,
@@ -160,6 +259,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
   const float scale_inv = 1.0f / scale;
   extern __shared__ float smem[];
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  [[maybe_unused]] T* smem_input = reinterpret_cast<T*>(smem + smem_reduce_elems);
 
   float sum_sq = 0.f;
 
@@ -170,8 +271,15 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<float, VEC_SIZE> input_vec;
     input_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
+    if (vec_offset < d) {
+      input_vec.load(input + bx * stride_input + vec_offset);
+      if constexpr (CACHE_INPUT) {
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; j++) {
+          smem_input[vec_offset + j] = input_vec[j];
+        }
+      }
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -206,9 +314,14 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
     vec_t<float, VEC_SIZE> output_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    const uint32_t vec_offset = (i * num_threads + thread_id) * VEC_SIZE;
+    if (vec_offset < d) {
+      if constexpr (CACHE_INPUT) {
+        input_vec.load(smem_input + vec_offset);
+      } else {
+        input_vec.load(input + bx * stride_input + vec_offset);
+      }
+      weight_vec.load(weight + vec_offset);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -216,9 +329,8 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
           float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])) * scale_inv;
       output_vec[j] = fmaxf(-448.0f, fminf(output_vec[j], 448.0f));
     }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.cast_store(output + bx * stride_output + i * num_threads * VEC_SIZE +
-                            thread_id * VEC_SIZE);
+    if (vec_offset < d) {
+      output_vec.cast_store(output + bx * stride_output + vec_offset);
     }
   }
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -232,11 +344,17 @@ cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, ui
                          float eps = 1e-5, bool enable_pdl = false, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
-  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
-  const uint32_t num_warps = ceil_div(block_size, 32);
+  const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
+  const size_t smem_size_with_cache =
+      static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
+  const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
+  const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
+  const uint32_t smem_size =
+      cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
   float weight_bias = 0.f;
 
   cudaLaunchConfig_t config;
@@ -251,11 +369,25 @@ cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, ui
   config.attrs = attrs;
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = RMSNormQuantKernel<T, O, VEC_SIZE>;
-    FLASHINFER_CUDA_CALL(
-        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
-                                            stride_output, weight_bias, scale, eps));
+    if (cache_input) {
+      auto kernel = RMSNormQuantKernel<T, O, VEC_SIZE, true>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, scale,
+                                              eps));
+    } else {
+      auto kernel = RMSNormQuantKernel<T, O, VEC_SIZE, false>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, scale,
+                                              eps));
+    }
   });
 
   return cudaSuccess;
 }
@@ -653,13 +785,18 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
                          bool enable_pdl = false, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
-  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
-  const uint32_t num_warps = ceil_div(block_size, 32);
+  const uint32_t num_warps = GetRMSNormNumWarps(d, vec_size);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  const uint32_t smem_size_no_cache = smem_reduce_elems * sizeof(float);
+  const size_t smem_size_with_cache =
+      static_cast<size_t>(smem_size_no_cache) + static_cast<size_t>(d) * sizeof(T);
+  const int max_smem_per_block = GetRMSNormMaxSharedMemoryPerBlockOptin();
+  const bool cache_input = smem_size_with_cache <= static_cast<size_t>(max_smem_per_block);
+  const uint32_t smem_size =
+      cache_input ? static_cast<uint32_t>(smem_size_with_cache) : smem_size_no_cache;
   float weight_bias = 1.f;
-  void* args[] = {&input, &weight, &output, &d, &stride_input, &stride_output, &weight_bias, &eps};
 
   cudaLaunchConfig_t config;
   config.gridDim = nblks;
@@ -673,11 +810,23 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
   config.attrs = attrs;
 
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = RMSNormKernel<T, VEC_SIZE>;
-    FLASHINFER_CUDA_CALL(
-        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
-                                            stride_output, weight_bias, eps));
+    if (cache_input) {
+      auto kernel = RMSNormKernel<T, VEC_SIZE, true>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, eps));
+    } else {
+      auto kernel = RMSNormKernel<T, VEC_SIZE, false>;
+      if (smem_size > 48 * 1024) {
+        FLASHINFER_CUDA_CALL(
+            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      }
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d,
+                                              stride_input, stride_output, weight_bias, eps));
+    }
   });
 
   return cudaSuccess;
 }