diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 31b1e97276..6814e892d1 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -369,7 +369,8 @@ cudaError_t QKRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint3 FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - dim3 nblks(num_blocks_per_sm * num_sms); + const int needed_blocks = ceil_div(batch_size * num_heads, num_warps); + dim3 nblks(std::min(num_blocks_per_sm * num_sms, needed_blocks)); dim3 nthrs(32, num_warps); config.gridDim = nblks; config.blockDim = nthrs;