Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,37 @@ void launch_fattn(
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

#ifdef GGML_USE_HIP
// HIP/ROCm: bypass the memory pool for f16 temp buffers.
// The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently
// because free() stores buffers for reuse rather than releasing them.
// On HIP without VMM support (RDNA 3/4), this means the f16 dequant temp buffers
// for quantized KV stay allocated after use, consuming more VRAM than the KV
// compression saves — causing OOM before f16 at equivalent context lengths.
// Using raw cudaMalloc/cudaFree ensures memory is released after the kernel completes.
// Ref: https://github.com/ggml-org/llama.cpp/issues/22107
struct hip_f16_alloc {
half * ptr = nullptr;
cudaStream_t stream;
hip_f16_alloc(cudaStream_t s) : stream(s) {}
hip_f16_alloc(const hip_f16_alloc &) = delete;
hip_f16_alloc & operator=(const hip_f16_alloc &) = delete;
~hip_f16_alloc() {
if (ptr) {
cudaStreamSynchronize(stream);
cudaFree(ptr);
}
}
void alloc(size_t nelements) {
CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
}
};
hip_f16_alloc K_f16(main_stream);
hip_f16_alloc V_f16(main_stream);
#else
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
#endif
ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
Expand Down