Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 59 additions & 29 deletions ggml/src/ggml-cuda/gated_delta_net.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,29 @@
#include "ggml-cuda/common.cuh"

template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t rq3,
float scale) {
__global__ void __launch_bounds__(S_v, 1)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
const int64_t H,
const int64_t n_tokens,
const int64_t n_seqs,
const int64_t sq1,
const int64_t sq2,
const int64_t sq3,
const int64_t sv1,
const int64_t sv2,
const int64_t sv3,
const int64_t sb1,
const int64_t sb2,
const int64_t sb3,
const int64_t rq1,
const int64_t rq3,
const float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
Expand All @@ -40,8 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;

// Load state column into registers
// CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
#if !defined(RDNA3) || !defined(RDNA4) || defined(GGML_USE_MUSA)
extern __shared__ float s_shared[];
float * s = s_shared + col * S_v;
#else
float s[S_v];
#endif
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
Expand Down Expand Up @@ -130,24 +136,48 @@ static void launch_gated_delta_net(
dim3 block_dims(S_v, 1, 1);

switch (S_v) {
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
case 32: {
constexpr int sv = 32;
#if !defined(RDNA3) || !defined(RDNA4) || defined(GGML_USE_MUSA)
constexpr size_t smem = sv * sv * sizeof(float);
CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda<sv, KDA>), smem);
#else
constexpr size_t smem = 0;
#endif
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 64: {
constexpr int sv = 64;
#if !defined(RDNA3) || !defined(RDNA4) || defined(GGML_USE_MUSA)
constexpr size_t smem = sv * sv * sizeof(float);
CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda<sv, KDA>), smem);
#else
constexpr size_t smem = 0;
#endif
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 128: {
constexpr int sv = 128;
#if !defined(RDNA3) || !defined(RDNA4) || defined(GGML_USE_MUSA)
constexpr size_t smem = sv * sv * sizeof(float);
CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda<sv, KDA>), smem);
#else
constexpr size_t smem = 0;
#endif
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
}
default:
GGML_ABORT("fatal error");
break;
Expand Down
Loading