Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,50 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#endif // FP16_AVAILABLE
}

enum warp_reduce_method {
    WARP_REDUCE_MAX,
    WARP_REDUCE_SUM,
};

// Two-stage block-wide reduction (max or sum) of one float per thread.
//
// Stage 1 reduces within each warp via shuffle; stage 2 funnels the per-warp
// results through shared memory and reduces them inside a single warp.
//
// Preconditions:
//  - shared_vals points to shared memory with room for one float per warp
//    (32 floats covers every legal block size up to 1024 threads);
//  - the block size (blockDim.x, or block_size_template when non-zero) is a
//    multiple of WARP_SIZE and at most 1024;
//  - must be reached by ALL threads of the block (contains __syncthreads()).
//
// block_size_template == 0 means "read blockDim.x at runtime"; a non-zero
// value lets the compiler prune the second stage for single-warp blocks.
template <warp_reduce_method reduce_method, const unsigned int block_size_template = 0>
static __device__ __forceinline__ float two_stage_warp_reduce(float val, float * shared_vals) {
    // reduce_method is a compile-time constant, so these selects fold away.
    // Calling the reduce overloads directly (instead of through a function
    // pointer) keeps them inlinable.
    const float neutral = reduce_method == WARP_REDUCE_MAX ? -INFINITY : 0.0f;

    // Stage 1: intra-warp reduction.
    val = reduce_method == WARP_REDUCE_MAX ? warp_reduce_max(val) : warp_reduce_sum(val);

    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
    if (block_size <= WARP_SIZE) {
        return val; // single warp: stage 1 already produced the block-wide result
    }

    assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // Barrier BEFORE the store: consecutive calls may reuse the same shared
    // buffer, so make sure no thread is still reading the previous call's
    // values when the warp leaders overwrite them.
    __syncthreads();
    if (lane_id == 0) {
        shared_vals[warp_id] = val;
    }
    __syncthreads();

    // Stage 2: the first block_size/WARP_SIZE lanes pick up the per-warp
    // results; the remaining lanes contribute the reduction's neutral element.
    val = lane_id < (static_cast<int>(block_size) / WARP_SIZE) ? shared_vals[lane_id] : neutral;
    return reduce_method == WARP_REDUCE_MAX ? warp_reduce_max(val) : warp_reduce_sum(val);
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

Expand Down
18 changes: 2 additions & 16 deletions ggml/src/ggml-cuda/reduce_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
}

// sum up partial sums
sum = warp_reduce_sum(sum);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
__shared__ float shared_vals[32];
sum = two_stage_warp_reduce<WARP_REDUCE_SUM>(sum, shared_vals);

if (col != 0) {
return;
Expand Down
89 changes: 7 additions & 82 deletions ggml/src/ggml-cuda/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ static __global__ void soft_max_f32(

const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;

const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);

extern __shared__ float data_soft_max_f32[];
Expand All @@ -102,21 +99,7 @@ static __global__ void soft_max_f32(
}

// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();

max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
max_val = two_stage_warp_reduce<WARP_REDUCE_MAX, block_size_template>(max_val, buf_iw);

float tmp = 0.0f; // partial sum

Expand All @@ -134,22 +117,7 @@ static __global__ void soft_max_f32(
}

// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();

tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
tmp = two_stage_warp_reduce<WARP_REDUCE_SUM, block_size_template>(tmp, buf_iw);

if (sinks) {
tmp += expf(sinks[i02] - max_val);
Expand All @@ -169,50 +137,6 @@ static __global__ void soft_max_f32(
}
}


// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
// Block-wide max reduction of one float per thread: warp-level shuffle reduce,
// then per-warp partials combined through shared memory.
// Preconditions: blockDim.x <= 1024 and a multiple of WARP_SIZE; must be
// reached by all threads of the block (contains __syncthreads()).
static __device__ float two_stage_warp_reduce_max(float val) {
    val = warp_reduce_max(val);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        // Function-scope __shared__: one buffer per block, REUSED across calls.
        __shared__ float local_vals[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        // Barrier before the store so back-to-back calls cannot overwrite
        // values another warp is still reading from the previous reduction.
        __syncthreads();
        if (lane_id == 0) {
            local_vals[warp_id] = val;
        }
        __syncthreads();
        // Lanes beyond the warp count contribute -inf, the neutral element of max.
        val = -INFINITY;
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            val = local_vals[lane_id];
        }
        return warp_reduce_max(val);
    } else {
        return val;
    }
}

// Block-wide sum reduction of one float per thread: warp-level shuffle reduce,
// then per-warp partials combined through shared memory.
// Preconditions: blockDim.x <= 1024 and a multiple of WARP_SIZE; must be
// reached by all threads of the block (contains __syncthreads()).
static __device__ float two_stage_warp_reduce_sum(float val) {
    val = warp_reduce_sum(val);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        // Function-scope __shared__: one buffer per block, REUSED across calls.
        __shared__ float local_vals[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        // Barrier before the store so back-to-back calls cannot overwrite
        // values another warp is still reading from the previous reduction.
        __syncthreads();
        if (lane_id == 0) {
            local_vals[warp_id] = val;
        }
        __syncthreads();
        // Lanes beyond the warp count contribute 0, the neutral element of sum.
        val = 0.0f;
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            val = local_vals[lane_id];
        }
        return warp_reduce_sum(val);
    } else {
        return val;
    }
}

// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
Expand All @@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
__shared__ float shared_vals[32];

// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
Expand All @@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
}

// Compute CTA-level max
local_max = two_stage_warp_reduce_max(local_max);
local_max = two_stage_warp_reduce<WARP_REDUCE_MAX>(local_max, shared_vals);

// Store CTA-level max to GMEM
if (tid == 0) {
Expand All @@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
} else {
local_max = -INFINITY;
}
local_max = two_stage_warp_reduce_max(local_max);
local_max = two_stage_warp_reduce<WARP_REDUCE_MAX>(local_max, shared_vals);

// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
Expand All @@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
}

// Reduce divisor within CTA
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
tmp_expf = two_stage_warp_reduce<WARP_REDUCE_SUM>(tmp_expf, shared_vals);

// Store CTA-level sum to GMEM
if (tid == 0) {
Expand All @@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
} else {
tmp_expf = 0.0f;
}
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
tmp_expf = two_stage_warp_reduce<WARP_REDUCE_SUM>(tmp_expf, shared_vals);

// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
Expand Down