// Reduction operator selector for two_stage_warp_reduce.
enum warp_reduce_method {
    WARP_REDUCE_MAX,
    WARP_REDUCE_SUM,
};

// Block-wide reduction (max or sum) in two stages:
//   1. each warp reduces its lanes with shuffle intrinsics,
//   2. lane 0 of each warp publishes its partial to shared memory and the
//      first warp-sized group of lanes reduces the per-warp partials.
//
// Preconditions:
//   - shared_vals points to a __shared__ buffer with room for one float per
//     warp (32 floats covers the maximum 1024-thread block).
//   - blockDim.x <= 1024 and blockDim.x % WARP_SIZE == 0 (asserted below).
//   - must be reached by ALL threads of the block (contains __syncthreads()).
//
// block_size_template == 0 selects the dynamic blockDim.x path; a non-zero
// value bakes the block size in at compile time. (Default of 0 assumed from
// the `== 0 ? blockDim.x :` idiom — confirm against call sites.)
template <warp_reduce_method reduce_method, int block_size_template = 0>
static __device__ float two_stage_warp_reduce(float val, float * shared_vals) {
    // Stage 1: intra-warp reduction. Branch on the template parameter instead
    // of a runtime function pointer so warp_reduce_max/sum stay inlined and
    // the dead branch is eliminated at compile time.
    if (reduce_method == WARP_REDUCE_MAX) {
        val = warp_reduce_max(val);
    } else {
        val = warp_reduce_sum(val);
    }

    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
    if (block_size > WARP_SIZE) {
        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        // Barrier BEFORE the write: callers (e.g. soft_max_f32) invoke this
        // helper back-to-back on the same shared buffer. Without this sync a
        // warp leader could overwrite shared_vals while threads of a slower
        // warp are still reading the previous reduction's partials. The
        // condition is uniform across the block, so the barrier is safe here.
        __syncthreads();
        if (lane_id == 0) {
            shared_vals[warp_id] = val;
        }
        __syncthreads();

        // Stage 2: the first block_size/WARP_SIZE lanes pick up the per-warp
        // partials; all other lanes contribute the operator's identity.
        val = reduce_method == WARP_REDUCE_MAX ? -INFINITY : 0.0f;
        if (lane_id < static_cast<int>(block_size / WARP_SIZE)) {
            val = shared_vals[lane_id];
        }
        if (reduce_method == WARP_REDUCE_MAX) {
            return warp_reduce_max(val);
        }
        return warp_reduce_sum(val);
    } else {
        return val;
    }
}
threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = sum; - } - __syncthreads(); - sum = 0.0f; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - sum = s_sum[lane_id]; - } - sum = warp_reduce_sum(sum); - } + __shared__ float shared_vals[32]; + sum = two_stage_warp_reduce(sum, shared_vals); if (col != 0) { return; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 1ae84ebf63..a646650c65 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -75,9 +75,6 @@ static __global__ void soft_max_f32( const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; @@ -102,21 +99,7 @@ static __global__ void soft_max_f32( } // find the max value in the block - max_val = warp_reduce_max(max_val); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = max_val; - } - __syncthreads(); - - max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); - } + max_val = two_stage_warp_reduce(max_val, buf_iw); float tmp = 0.0f; // partial sum @@ -134,22 +117,7 @@ static __global__ void soft_max_f32( } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __syncthreads(); - if (warp_id == 0) { - buf_iw[lane_id] = 0.0f; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = tmp; - } - __syncthreads(); - - tmp = buf_iw[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = two_stage_warp_reduce(tmp, buf_iw); if (sinks) { tmp += expf(sinks[i02] - max_val); @@ -169,50 +137,6 @@ static __global__ void soft_max_f32( } } - -// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated 
-static __device__ float two_stage_warp_reduce_max(float val) { - val = warp_reduce_max(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = -INFINITY; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_max(val); - } else { - return val; - } -} - -static __device__ float two_stage_warp_reduce_sum(float val) { - val = warp_reduce_sum(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = 0.0f; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_sum(val); - } else { - return val; - } -} - // TODO: Template to allow keeping ncols in registers if they fit static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst, @@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; float local_max = -INFINITY; const int step_size = gridDim.x * blockDim.x; + __shared__ float shared_vals[32]; // Compute thread-local max for (int col = col_start; col < p.ncols;) { @@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Compute CTA-level max - local_max = two_stage_warp_reduce_max(local_max); + local_max = two_stage_warp_reduce(local_max, shared_vals); // Store CTA-level max to GMEM if (tid == 0) { @@ -261,7 +186,7 @@ 
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { local_max = -INFINITY; } - local_max = two_stage_warp_reduce_max(local_max); + local_max = two_stage_warp_reduce(local_max, shared_vals); // Compute softmax dividends, accumulate divisor float tmp_expf = 0.0f; @@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Reduce divisor within CTA - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = two_stage_warp_reduce(tmp_expf, shared_vals); // Store CTA-level sum to GMEM if (tid == 0) { @@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { tmp_expf = 0.0f; } - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = two_stage_warp_reduce(tmp_expf, shared_vals); // Divide dividend by global sum + store data for (int col = col_start; col < p.ncols;) {