From a4273ef2bc12e3e1e702390b19cb29703ff48960 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 29 May 2026 16:24:21 +0800 Subject: [PATCH 1/4] cuda: reserve space for quantize kv-cache at startup --- ggml/src/ggml-cuda/fattn-common.cuh | 28 ++++++++++------ ggml/src/ggml-cuda/fattn.cu | 52 +++++++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn.cuh | 2 ++ ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++-- 4 files changed, 76 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d650b5fbd0f..f1805fd82a7 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -952,8 +952,14 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc K_f16(pool); - ggml_cuda_pool_alloc V_f16(pool); + char * extra_data = (char *) KQV->data + ggml_nbytes(KQV); + auto reserve_extra_f16 = [&](const ggml_tensor * tensor) { + extra_data = (char *) GGML_PAD((uintptr_t) extra_data, 128); + half * ptr = (half *) extra_data; + extra_data += ggml_nelements(tensor)*sizeof(half); + return ptr; + }; + ggml_cuda_pool_alloc KV_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); @@ -972,10 +978,10 @@ void launch_fattn( const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - K_f16.alloc(ggml_nelements(K)); + half * K_f16 = reserve_extra_f16(K); if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); nb11 = nb11*bs*sizeof(half)/ts; nb12 = nb12*bs*sizeof(half)/ts; @@ -986,13 +992,13 @@ void launch_fattn( const int64_t s01 = nb11 / ts; const int64_t s02 = nb12 / ts; const int64_t s03 = nb13 / ts; - to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); nb11 = K->ne[0] * sizeof(half); nb12 = K->ne[1] * nb11; nb13 = K->ne[2] * nb12; } - K_data = (char *) K_f16.ptr; + K_data = (char *) K_f16; } if (need_f16_V && V->type != GGML_TYPE_F16) { @@ -1005,11 +1011,11 @@ void launch_fattn( const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - V_f16.alloc(ggml_nelements(V)); + half * V_f16 = reserve_extra_f16(V); if (ggml_is_contiguously_allocated(V)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; + to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); + V_data = (char *) V_f16; nb21 = nb21*bs*sizeof(half)/ts; nb22 = nb22*bs*sizeof(half)/ts; @@ -1020,13 +1026,13 @@ void launch_fattn( const int64_t s01 = nb21 / ts; const int64_t s02 = nb22 / ts; const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); nb21 = V->ne[0] * sizeof(half); nb22 = V->ne[1] * nb21; nb23 = V->ne[2] * nb22; } - V_data = (char *) V_f16.ptr; + V_data = (char *) V_f16; } } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1c7777e8a71..dea0ccaa915 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -537,6 +537,58 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } +static size_t ggml_cuda_flash_attn_ext_reserve_f16_buffer(size_t size, const ggml_tensor * tensor) { + size = GGML_PAD(size, 128); + return size + ggml_nelements(tensor)*ggml_type_size(GGML_TYPE_F16); +} + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { + size_t size = ggml_nbytes(dst); + + if (dst->op != GGML_OP_FLASH_ATTN_EXT) { + return size; + } + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + if (K == nullptr || V == nullptr) { + return size; + } + + const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); + + bool need_f16_K = false; + bool need_f16_V = false; + + switch (kernel) { + case BEST_FATTN_KERNEL_TILE: + case BEST_FATTN_KERNEL_WMMA_F16: + case BEST_FATTN_KERNEL_MMA_F16: + need_f16_K = true; + need_f16_V = true; + break; + case BEST_FATTN_KERNEL_VEC: + need_f16_K = K->type == GGML_TYPE_F32; + need_f16_V = V->type == GGML_TYPE_F32; + break; + case BEST_FATTN_KERNEL_NONE: + break; + } + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + size = ggml_cuda_flash_attn_ext_reserve_f16_buffer(size, K); + } + + if (need_f16_V && V->type != GGML_TYPE_F16 && !V_is_K_view) { + size = ggml_cuda_flash_attn_ext_reserve_f16_buffer(size, V); + } + + return size; +} + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index 78705d59951..f9a7e15fbd6 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -3,3 +3,5 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 18aaa098398..f5293ad4cbb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty } static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context; + + size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT + ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor) + : ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; if (ggml_is_quantized(tensor->type)) { @@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t } return size; - - GGML_UNUSED(buft); } static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { From 9f584d3901d9c12aafea341e71655a44e8d63595 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 31 May 2026 12:56:09 +0800 Subject: [PATCH 2/4] address review comments --- ggml/src/ggml-cuda/common.cuh | 2 ++ ggml/src/ggml-cuda/fattn-common.cuh | 56 ++++++++++++++++++++++++----- ggml/src/ggml-cuda/fattn.cu | 29 ++++----------- 3 files changed, 55 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 50d7763dcdd..f3fa04a7fce 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -110,6 +110,8 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); + // PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. // However, this has been bugged in CTK < 12.3 for MSVC builds, see // https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index f1805fd82a7..90f44573d7b 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)( typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +struct ggml_cuda_flash_attn_ext_f16_extra_data { + uintptr_t K; + uintptr_t V; + uintptr_t end; +}; + +static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data( + const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + ggml_cuda_flash_attn_ext_f16_extra_data data = {}; + data.end = (uintptr_t) dst->data + ggml_nbytes(dst); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + data.end = GGML_PAD(data.end, 128); + data.K = data.end; + data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16); + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + data.V = data.K; + } else { + data.end = GGML_PAD(data.end, 128); + data.V = data.end; + data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16); + } + } + + return data; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -952,13 +992,9 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - char * extra_data = (char *) KQV->data + ggml_nbytes(KQV); - auto reserve_extra_f16 = [&](const ggml_tensor * tensor) { - extra_data = (char *) GGML_PAD((uintptr_t) extra_data, 128); - half * ptr = (half *) extra_data; - extra_data += ggml_nelements(tensor)*sizeof(half); - return ptr; - }; + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); + GGML_ASSERT(f16_extra.end <= (uintptr_t) KQV->data + ggml_cuda_flash_attn_ext_get_alloc_size(id, KQV)); ggml_cuda_pool_alloc KV_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); @@ -978,7 +1014,8 @@ void launch_fattn( const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - half * K_f16 = reserve_extra_f16(K); + GGML_ASSERT(f16_extra.K != 0); + half * K_f16 = (half *) f16_extra.K; if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); @@ -1011,7 +1048,8 @@ void launch_fattn( const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - half * V_f16 = reserve_extra_f16(V); + GGML_ASSERT(f16_extra.V != 0); + half * V_f16 = (half *) f16_extra.V; if (ggml_is_contiguously_allocated(V)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index dea0ccaa915..d6c501b1d7e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -537,24 +537,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } -static size_t ggml_cuda_flash_attn_ext_reserve_f16_buffer(size_t size, const ggml_tensor * tensor) { - size = GGML_PAD(size, 128); - return size + ggml_nelements(tensor)*ggml_type_size(GGML_TYPE_F16); -} - size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { - size_t size = ggml_nbytes(dst); - - if (dst->op != GGML_OP_FLASH_ATTN_EXT) { - return size; - } + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - if (K == nullptr || V == nullptr) { - return size; - } + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); @@ -576,17 +566,10 @@ size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * d break; } - const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); - - if (need_f16_K && K->type != GGML_TYPE_F16) { - size = ggml_cuda_flash_attn_ext_reserve_f16_buffer(size, K); - } - - if (need_f16_V && V->type != GGML_TYPE_F16 && !V_is_K_view) { - size = ggml_cuda_flash_attn_ext_reserve_f16_buffer(size, V); - } + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V); - return size; + return f16_extra.end - (uintptr_t) dst->data; } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { From 32e6898943c4e1f40ad7ca37acedad464d05740d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 31 May 2026 16:12:45 +0800 Subject: [PATCH 3/4] remove forward decl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f3fa04a7fce..50d7763dcdd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -110,8 +110,6 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 -size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); - // PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. // However, this has been bugged in CTK < 12.3 for MSVC builds, see // https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 From b3249879e5fa063708f9e420c5e58767ef48d283 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 31 May 2026 17:18:10 +0800 Subject: [PATCH 4/4] remove assert in ggml-cuda.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-common.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 90f44573d7b..064f753f7ef 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -994,7 +994,6 @@ void launch_fattn( const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); - GGML_ASSERT(f16_extra.end <= (uintptr_t) KQV->data + ggml_cuda_flash_attn_ext_get_alloc_size(id, KQV)); ggml_cuda_pool_alloc KV_max(pool); ggml_cuda_pool_alloc dst_tmp(pool);