Skip to content

hip: bypass memory pool for flash attention f16 temp buffers#22094

Closed
TheTom wants to merge 1 commit into
ggml-org:masterfrom
TheTom:fix/hip-fa-pool-retention
Closed

hip: bypass memory pool for flash attention f16 temp buffers#22094
TheTom wants to merge 1 commit into
ggml-org:masterfrom
TheTom:fix/hip-fa-pool-retention

Conversation

@TheTom

@TheTom TheTom commented Apr 18, 2026

Copy link
Copy Markdown

On HIP/ROCm, bypass the memory pool for flash attention f16 temp buffers to prevent OOM with quantized KV cache types.

Overview

The legacy memory pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently. During flash attention with quantized KV types (q4_0, q8_0), the f16 dequant temp buffers (K_f16, V_f16) grow proportional to KV cache length. After use, the pool retains them at peak size — consuming more VRAM than the KV compression saves.

On CUDA with VMM, the OS can reclaim unused virtual memory. On HIP, VMM is broken since ROCm 7.0 (ROCm/rocm-systems#2516, open 4+ months), so all consumer RDNA 3/4 GPUs fall back to the legacy pool where allocations are never freed.

This causes quantized KV to OOM before f16 at equivalent context lengths on stock llama.cpp.

Fix

Replace pool-based f16 temp buffer allocation with a RAII struct using raw cudaMalloc/cudaFree on HIP. Memory is released after the FA kernel completes via cudaStreamSynchronize.

Single file change: ggml/src/ggml-cuda/fattn-common.cuh, scoped to #ifdef GGML_USE_HIP. Zero impact on CUDA or Metal.

Reproducer

Stock llama.cpp, no modifications needed:

# f16 survives:
llama-bench -m model.gguf -ngl 99 -fa 1 -ctk f16 -ctv f16 -p 512 -n 128 -d 0,32768,65536 -r 1

# q8_0 OOMs (same context, same model, same GPU):
llama-bench -m model.gguf -ngl 99 -fa 1 -ctk q8_0 -ctv q8_0 -p 512 -n 128 -d 0,32768,65536 -r 1

Confirmed hardware

GPU Arch VRAM Tester Result
RX 7900 XT gfx1100 20GB ozboss q8_0 crashed at 64K, f16 survived to 131K
RX 9060 XT gfx1200 16GB Gerporgl q8_0 crashed at 31K, f16 survived to 39K
RX 9070 XT gfx1201 16GB TheTom, apollo-mg Patched: q8_0 survives 65K where f16 OOMs
Radeon Pro VII gfx906 16GB Zgate735 OOM with quantized KV + FA at long context

Performance impact

Tested on RX 9070 XT, Qwen2.5-1.5B Q8_0:

Metric Base Patched Delta
Decode (tg128) 196.5 t/s 192.3 t/s -2.0% (noise)
Prefill 32K (pp512@d32768) 3237 t/s 3061 t/s -5.5%
f16 (all) unchanged unchanged 0%

The ~5% prefill overhead is from cudaStreamSynchronize in the RAII destructor. Can be eliminated with hipFreeAsync (stream-ordered free) in future.

Fixes #22107
Mentions ROCm/rocm-systems#2516

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES — AI used for research and guidance, all code manually reviewed and understood.

@github-actions github-actions Bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Apr 18, 2026
@ggml-gh-bot

ggml-gh-bot Bot commented Apr 18, 2026

Copy link
Copy Markdown

Hi @TheTom, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 3 open PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

Independent Verification: OOM Bug Confirmed, Fix Verified

Hardware: AMD Radeon RX 9070 XT (gfx1201, 16GB VRAM, VMM: no)
OS: Windows 11 Pro, HIP SDK 7.1
Build: GGML_HIP=ON, GGML_CUDA_FA_ALL_QUANTS=ON, GPU_TARGETS=gfx1201, -O3

Model: Phi-3.1-mini-128k-instruct Q4_K_M (2.23 GiB, 3.82B params)
32 KV heads (MHA), head_dim=96, 32 layers, 128k native context

Baseline: e365e65 (master)
Fix: 30c3c23 (PR #22094 v2, with deleted copy/assign)


OOM Reproduction (master, unfixed)

Config d40000 Result
f16 KV pp512 @ d40000 PASS (142.14 t/s)
q8_0 KV pp512 @ d40000 OOM (process killed, exit code 127)

q8_0 KV should use less memory than f16, yet it OOMs first. This confirms the bug: pool-retained f16 temp buffers from quantized KV dequant push total VRAM past the limit.

OOM Fix (PR #22094 applied)

Config d40000 Result
f16 KV pp512 @ d40000 PASS (282.88 t/s)
q8_0 KV pp512 @ d40000 PASS (389.83 t/s) — was OOM on master

Fix works. q8_0 now survives the same depth that killed it on master.


Full Benchmark: Master (unfixed)

f16 KV:

test t/s
pp512 4315.92
tg128 136.44
pp512 @ d32768 415.11
tg128 @ d32768 4.54
pp512 @ d40000 142.14
pp512 @ d65536 skipped, OOMed on warmup

q8_0 KV:

test t/s
pp512 4315.49
tg128 126.02
pp512 @ d32768 149.84
tg128 @ d32768 7.39
pp512 @ d40000 OOM
pp512 @ d65536 OOM

Full Benchmark: With Fix (PR #22094)

f16 KV:

test t/s
pp512 4344.40
tg128 136.11
pp512 @ d32768 427.76
tg128 @ d32768 8.24
pp512 @ d40000 282.88
tg128 @ d40000 1.73

q8_0 KV:

test t/s
pp512 3469.56
tg128 84.27
pp512 @ d32768 463.38
tg128 @ d32768 13.14
pp512 @ d40000 389.83
tg128 @ d40000 10.99

Performance Comparison at Matching Depth Points

pp512 (no depth):

  • f16: master 4315 -> fix 4344 (+0.7%, noise)
  • q8_0: master 4315 -> fix 3469 (-19.6%) *

tg128 (no depth):

  • f16: master 136.4 -> fix 136.1 (-0.2%, noise)
  • q8_0: master 126.0 -> fix 84.3 (-33%) *

pp512 @ d32768:

  • f16: master 415.1 -> fix 427.8 (+3.1%)
  • q8_0: master 149.8 -> fix 463.4 (+209%) *

tg128 @ d32768:

  • f16: master 4.54 -> fix 8.24 (+81%) *
  • q8_0: master 7.39 -> fix 13.14 (+78%) *

* Note on variance: The large swings in d32768 numbers (especially decode) and the q8_0 pp512 no-depth difference suggest heavy VRAM pressure effects on master. When VRAM is near capacity, the HIP runtime thrashes. This explains the extremely low decode numbers on master (4.54 t/s f16 decode @ d32768). The fix, by releasing temp buffers, reduces VRAM pressure and performance stabilizes. The "improvement" at d32768 is really the fix eliminating thrashing.


Control Test: Qwen2.5-7B Q4_K_M (4 GQA KV heads)

This model has only 4 KV heads (GQA), so the FA temp buffer is small. OOM was not triggered. Performance is stable between master and fix.

Config Metric Master Fix
f16 pp512 3199 t/s 3170 t/s
f16 tg128 97.9 t/s 99.5 t/s
q8_0 pp512 2157 t/s 2535 t/s
q8_0 tg128 84.8 t/s 85.5 t/s

No regression. f16 unaffected. q8_0 decode stable.


Verdict

  1. OOM bug: CONFIRMED. q8_0 OOMs at d40000, f16 survives (master).
  2. Fix works: CONFIRMED. q8_0 survives d40000 with fix applied.
  3. Decode: No regression on either model.
  4. Prefill: No regression. Improvements at high depth are from reduced VRAM thrashing.
  5. f16 path: Unaffected. Fix is #ifdef GGML_USE_HIP, f16 skips dequant.
  6. Build: Clean on HIP 7.1 / gfx1201.

Fix resolves the reported issue with minimal code change (27 lines, HIP-only).

Reproducer

# master (e365e65) — shows the bug
llama-bench -m phi-3.1-mini-128k-instruct-Q4_K_M.gguf -ngl 99 -fa 1 -ctk q8_0 -ctv q8_0 -p 512 -n 128 -d 0,32768,40000 -r 1
llama-bench -m phi-3.1-mini-128k-instruct-Q4_K_M.gguf -ngl 99 -fa 1 -ctk f16 -ctv f16 -p 512 -n 128 -d 0,32768,40000 -r 1

# PR #22094 (30c3c23) — shows the fix
# same commands, q8_0 no longer OOMs

@TheTom TheTom force-pushed the fix/hip-fa-pool-retention branch from 30c3c23 to 0b05974 Compare April 20, 2026 14:15
@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

Additional finding: tested whether enabling HIP VMM could serve as an alternative fix.

Built upstream master (e365e65) with GGML_HIP_NO_VMM=OFF on RX 9070 XT (gfx1201), Windows 11, HIP SDK 7.1. The device reports VMM: yes, but the first VMM allocation crashes with HipVMM Failure: invalid argument at ggml-cuda.cu:525. Even the simplest inference (pp512, no depth) fails immediately.

This confirms the GGML_HIP_NO_VMM=ON default exists for good reason and VMM is not a viable alternative on RDNA 4 / HIP SDK 7.1 / Windows 11. The pool bypass approach in this PR remains the correct fix.

Tracked upstream at ROCm/rocm-systems#2516.

The legacy memory pool (ggml_cuda_pool_leg) retains peak-sized
allocations permanently. For quantized KV flash attention, the f16
dequant temp buffers (K_f16, V_f16) stay allocated in the pool after
use, consuming more VRAM than the KV compression saves. This causes
quantized KV (q8_0, q4_0) to OOM before f16 at equivalent context
lengths on HIP/ROCm where VMM is unavailable.

Root cause: ggml_cuda_pool_leg::free() stores buffers in buffer_pool[]
for reuse and never calls cudaFree. On CUDA with VMM the OS can
reclaim unused virtual memory. On HIP without VMM (all consumer RDNA
3/4 GPUs), the pool permanently consumes peak VRAM.

Fix: on HIP, allocate f16 temp buffers with cudaMalloc and free with
cudaFree (via RAII wrapper) instead of the pool. Memory is released
after the FA kernel completes via cudaStreamSynchronize.

Trade-off: one cudaStreamSynchronize per FA call (~5% overhead at 32K).

Impact: CUDA/Metal unaffected (#ifdef GGML_USE_HIP only).
Confirmed: gfx1100 (RX 7900 XT), gfx1201 (RX 9070 XT)
Fixes: ggml-org#22107
@TheTom TheTom force-pushed the fix/hip-fa-pool-retention branch from 0b05974 to 53fd02f Compare April 20, 2026 15:26
@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

Hi @TheTom, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 3 open PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

For Reviewers:
These three PRs are independent fixes across different subsystems:

No overlap between them. #21119 is in draft and #21452 has been waiting on review for 2 weeks. Happy to prioritize review order if that helps.

@TheTom TheTom marked this pull request as ready for review April 20, 2026 16:12
@TheTom TheTom requested a review from a team as a code owner April 20, 2026 16:12
@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

Community testers/reporters as fyi:
@domvox — tested on 7900 XTX, confirmed fix works
@Gerporgl — tested on 9060 XT, confirmed fix works
@ozboss — original reporter of the OOM
@Zgate735 — reported on issue #22107

@IMbackK

IMbackK commented Apr 20, 2026

Copy link
Copy Markdown
Collaborator

Not a resonable solution, the allocator should be improved instead, there is a pr open for that.

@IMbackK IMbackK closed this Apr 20, 2026
@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

Which PR are you referring to? Happy to test it.

This fix exists because VMM is broken on consumer RDNA 3/4 today. I tested with GGML_HIP_NO_VMM=OFF on gfx1201 / HIP SDK 7.1 / Windows 11 and it crashes immediately on the first VMM allocation (HipVMM Failure: invalid argument), even on the simplest possible inference. Details in ROCm/rocm-systems#2516.

Until VMM works on consumer cards, the legacy pool is the only path, and it leaks f16 temp buffers. This PR fixes that for users hitting OOM today. If a better allocator fix lands upstream, this becomes a no-op to remove.

@JohannesGaessler

Copy link
Copy Markdown
Contributor

According to the llama.cpp AI usage policy:

It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).

@TheTom

TheTom commented Apr 20, 2026

Copy link
Copy Markdown
Author

@IMbackK Added crash logs here that @manocharahul requested: ROCm/rocm-systems#2516 (comment)

Is #22155 the PR you are referring to? I tested it on the same hardware (RX 9070 XT, gfx1201, 16GB, HIP SDK 7.1).

Both PRs fix the OOM. However, #22155 is 3x slower at depth because the flush-retry triggers on every FA call:

Build pp512 @ d40000 tg128 @ d40000
Master (no fix) OOM OOM
PR #22094 (this PR) 389.83 t/s 10.97 t/s
PR #22155 (flush) 129.62 t/s 10.35 t/s

#22155 flushes the entire pool and retries on every layer's FA allocation at high context. This PR prevents the pool from retaining those buffers in the first place, so no OOM occurs and no flush is needed.

The two approaches are complementary, not competing. #22155 is a general safety net for any pool OOM. This PR targets the specific root cause for FA temp buffers. Both could coexist.

I believe both can be applied as is since FA buffers never enter the pool (my fix) so no OOM, if any other large allocation causes pool OOM, #22155 catches it as a safety net.

Thoughts?

@apollo-mg

Copy link
Copy Markdown

I am happy to help test with my 9070XT if helpful. I can say for sure that TheTom's solution works for me as-is.

Obviously we want the core ggml_cuda_pool_leg allocator itself to be smart enough to release memory properly (which is what that other PR #22155 attempts to do). Architecturally, correct, fixing the root allocator is cleaner code than adding exceptions.

However, as TheTom pointed out #22155 is a brute-force approach: it literally flushes the entire memory pool and retries the allocation if it fails. Hence the massive 3x performance drop at 40k tokens.

While the upstream approach is architecturally pure, that kind of performance regression renders deep-context workflows practically unusable for end-users on AMD hardware. I'd love to see a solution that fixes the allocator without sacrificing the prefill performance we get in this PR.

Always happy to test. Thanks!
-Mark

@cabronz

cabronz commented Apr 21, 2026

Copy link
Copy Markdown

I honestly struggle to understand the decision to close this PR. It frequently feels like the real-world needs of AMD consumer GPU users are not being taken seriously.

While I respect the maintainers' desire for a clean and architecturally flawless codebase, in practical engineering, working is sometimes far more important than perfect. We desperately need a functional way to run these models on our hardware today. By rejecting this implementation, AMD users are once again left waiting indefinitely for an idealized solution that might take months, while a perfectly usable fix is simply discarded.

I sincerely hope the team can reconsider the actual pain points of the end-users. A practical, working solution right now is much more valuable to the community than waiting for a perfect one that doesn't exist yet.

@IMbackK

IMbackK commented Apr 21, 2026

Copy link
Copy Markdown
Collaborator

Because this is a hack, the solution here is not to ignore the allocator for this one case but to instead imporove the allocator. If the performance cost of fully flushing the allocator it to high you could for instance try reclaming only enough segments for the allocation to suceed and See if that improves performance.

Over all this whole thing will be resolved in rocm 7.3 because there the vmm allocator will probubly Start to work given devolpments in clr.

@TheTom

TheTom commented Apr 21, 2026

Copy link
Copy Markdown
Author

Refactored and opened a new PR with cleaner structure: #22185

Moved the inline struct to a reusable ggml_cuda_direct_alloc in common.cuh that mirrors the pool_alloc interface. Same fix, same test results.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Misc. bug: [CUDA/ROCm] VRAM leak/fragmentation in ggml_cuda_pool_leg when using Flash Attention

5 participants