76 changes: 53 additions & 23 deletions ggml/src/ggml-cuda/delta-net.cu
@@ -96,9 +96,10 @@ __global__ void delta_net_recurrent_f32(
}

constexpr int HEAD_DIM_S = HEAD_DIM + 1;
__shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
constexpr int num_stored_rows = block_size >= HEAD_DIM && block_size % HEAD_DIM == 0 ? block_size/HEAD_DIM : NUM_WARPS;
__shared__ float all_sum[2*HEAD_DIM_S*num_stored_rows];
auto all_sum1 = all_sum;
auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
auto all_sum2 = all_sum1 + HEAD_DIM_S*num_stored_rows;

// Process each token sequentially
for (int64_t t = 0; t < n_tokens; t++) {
@@ -116,39 +117,68 @@
float beta_val = sigmoid_f(beta_ptr[t]);
float decay = expf(fminf(g_ptr[t], 50.0f));

for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
if constexpr (block_size >= HEAD_DIM && block_size % HEAD_DIM == 0) {
int idx = tid / HEAD_DIM;
int row_out = tid % HEAD_DIM;
float sum1 = 0, sum2 = 0;
#pragma unroll
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
for (int col = idx; col < HEAD_DIM; col += block_size/HEAD_DIM) {
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();

for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
all_sum1[idx*HEAD_DIM_S + row_out] = sum1;
all_sum2[idx*HEAD_DIM_S + row_out] = sum2;

__syncthreads();

if (idx == 0) {
#pragma unroll
for (int i = 1; i < block_size/HEAD_DIM; ++i) {
sum1 += all_sum1[i*HEAD_DIM_S + row_out];
sum2 += all_sum2[i*HEAD_DIM_S + row_out];
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
__syncthreads();
} else {
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
#pragma unroll
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();

for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
__syncthreads();
}
__syncthreads();

for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
float k_col = sK[out_dim];
#pragma unroll
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
float state_val = state_dst[row + out_dim * HEAD_DIM];
float new_state_val = decay * state_val + sVNew[row] * sK[out_dim];
float new_state_val = decay * state_val + sVNew[row] * k_col; //sK[out_dim];
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
}
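Note on the diff above: the new `if constexpr` fast path is taken when `block_size` is a multiple of `HEAD_DIM`. Each thread then owns one `(idx, row_out)` pair (`idx = tid / HEAD_DIM`, `row_out = tid % HEAD_DIM`), accumulates partial dot products of the state against `sK` and `sQ` over a strided slice of columns, writes them to shared memory laid out with `HEAD_DIM_S = HEAD_DIM + 1` floats per slice, and the `idx == 0` thread of each row folds the remaining slices and writes the output. The slow path (the `else` branch) keeps the previous per-warp scheme. The final hunk only hoists `sK[out_dim]` into `k_col` outside the inner loop. The code below is a minimal, standalone sketch of that reduction pattern for a single matrix-vector product, not the ggml kernel itself; `DIM`, `BLOCK_SIZE`, and all names in it are illustrative assumptions.

// Sketch: y = S * k for a DIM x DIM matrix S stored column-major (S[row + col*DIM]),
// one thread block, using the slice-per-thread reduction pattern from the fast path.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int DIM        = 64;               // stand-in for HEAD_DIM
constexpr int BLOCK_SIZE = 256;              // must be a multiple of DIM
constexpr int SLICES     = BLOCK_SIZE / DIM; // column slices per output row
constexpr int DIM_S      = DIM + 1;          // +1 padding, like HEAD_DIM_S

__global__ void matvec_sliced(const float * S, const float * k, float * y) {
    __shared__ float partial[SLICES * DIM_S];

    const int tid     = threadIdx.x;
    const int idx     = tid / DIM;   // which column slice this thread covers
    const int row_out = tid % DIM;   // which output row this thread covers

    // Partial dot product over a strided slice of columns.
    float sum = 0.0f;
    for (int col = idx; col < DIM; col += SLICES) {
        sum += S[row_out + col * DIM] * k[col];
    }
    partial[idx * DIM_S + row_out] = sum;

    __syncthreads();

    // One thread per row folds the remaining slices and writes the result.
    if (idx == 0) {
        for (int i = 1; i < SLICES; ++i) {
            sum += partial[i * DIM_S + row_out];
        }
        y[row_out] = sum;
    }
}

int main() {
    float *S, *k, *y;
    cudaMallocManaged(&S, DIM * DIM * sizeof(float));
    cudaMallocManaged(&k, DIM * sizeof(float));
    cudaMallocManaged(&y, DIM * sizeof(float));
    for (int i = 0; i < DIM * DIM; ++i) S[i] = 1.0f;
    for (int i = 0; i < DIM; ++i)       k[i] = 1.0f;

    matvec_sliced<<<1, BLOCK_SIZE>>>(S, k, y);
    cudaDeviceSynchronize();

    printf("y[0] = %g (expected %d)\n", y[0], DIM);
    return 0;
}

The extra float of padding per slice (DIM + 1 instead of DIM) spreads consecutive rows of a slice across different shared-memory banks; presumably that is also why the kernel defines HEAD_DIM_S = HEAD_DIM + 1.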