Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 66 additions & 35 deletions ggml/src/ggml-cuda/gated_delta_net.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"

template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
Expand All @@ -21,15 +20,17 @@ __global__ void gated_delta_net_cuda(const float * q,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t neqk1,
int64_t rq3,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = threadIdx.x;
const int col = blockIdx.z * blockDim.y + threadIdx.y;

const int64_t iq1 = h_idx % neqk1;
const int64_t iq3 = sequence / rq3;
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);

const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
Expand All @@ -40,11 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;

// Load state column into registers
float s[S_v];
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}

for (int t = 0; t < n_tokens; t++) {
Expand All @@ -62,55 +66,71 @@ __global__ void gated_delta_net_cuda(const float * q,
const float g_val = expf(*g_t);

// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);

// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;

// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = g_val * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}

attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += expf(g_t[i]) * s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}

float kv_col = warp_reduce_sum<warp_size>(kv_shard);

// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;

// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}

attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}

attn_data += S_v * H;
}

// Write state back to global memory
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = s[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}
}

Expand All @@ -125,28 +145,39 @@ static void launch_gated_delta_net(
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);

dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);

switch (S_v) {
case 16:
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
default:
GGML_ABORT("fatal error");
Expand Down
4 changes: 1 addition & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4999,9 +4999,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifdef GGML_USE_MUSA
return false;
#else
// KDA is faster using the AR kernel even when n_tokens >= 512
//TODO: Add chunked kernel
return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0];
return true;
#endif // GGML_USE_MUSA
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
Expand Down
5 changes: 5 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8451,6 +8451,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}

test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
Expand All @@ -8460,10 +8463,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// KDA (vector gate)
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true));

#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
Expand Down
Loading