diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index d8e81114559..90b6b9b29d1 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -2,28 +2,29 @@ #include "ggml-cuda/common.cuh" template -__global__ void gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - int64_t H, - int64_t n_tokens, - int64_t n_seqs, - int64_t sq1, - int64_t sq2, - int64_t sq3, - int64_t sv1, - int64_t sv2, - int64_t sv3, - int64_t sb1, - int64_t sb2, - int64_t sb3, - int64_t rq1, - int64_t rq3, - float scale) { +__global__ void __launch_bounds__(S_v, 1) +gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + const int64_t H, + const int64_t n_tokens, + const int64_t n_seqs, + const int64_t sq1, + const int64_t sq2, + const int64_t sq3, + const int64_t sv1, + const int64_t sv2, + const int64_t sv3, + const int64_t sb1, + const int64_t sb2, + const int64_t sb3, + const int64_t rq1, + const int64_t rq3, + const float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column @@ -40,8 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // Load state column into registers +// CDNA devices spill registers, we use shared mem for them. 
See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 +#if (!defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + extern __shared__ float s_shared[]; + float * s = s_shared + col * S_v; +#else float s[S_v]; +#endif #pragma unroll for (int i = 0; i < S_v; i++) { s[i] = curr_state[i * S_v + col]; } @@ -130,24 +136,48 @@ static void launch_gated_delta_net( dim3 block_dims(S_v, 1, 1); switch (S_v) { - case 32: - gated_delta_net_cuda<32, KDA><<>>( + case 32: { + constexpr int sv = 32; +#if (!defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + constexpr size_t smem = sv * sv * sizeof(float); + CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda), smem); +#else + constexpr size_t smem = 0; +#endif + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 64: - gated_delta_net_cuda<64, KDA><<>>( + } + case 64: { + constexpr int sv = 64; +#if (!defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + constexpr size_t smem = sv * sv * sizeof(float); + CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda), smem); +#else + constexpr size_t smem = 0; +#endif + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 128: - gated_delta_net_cuda<128, KDA><<>>( + } + case 128: { + constexpr int sv = 128; +#if (!defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + constexpr size_t smem = sv * sv * sizeof(float); + CUDA_SET_SHARED_MEMORY_LIMIT((gated_delta_net_cuda), smem); +#else + constexpr size_t smem = 0; +#endif + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; + } default: GGML_ABORT("fatal error"); break;