
Improve CUDA graph capture #19754

Merged

am17an merged 4 commits into ggml-org:master from gaugarg-nv:cuda_graph_fix on Feb 21, 2026
Conversation

@gaugarg-nv
Contributor

gaugarg-nv commented Feb 20, 2026

Currently, CUDA graphs are eagerly enabled on the first call to ggml_backend_cuda_graph_compute. If the graph properties keep changing (4+ consecutive updates), the graph is permanently disabled. This is suboptimal because:

  • The first call always incurs CUDA graph capture overhead even if the graph is unstable
  • Once permanently disabled, CUDA graphs never re-enable even after the graph stabilizes (e.g., switching from prompt processing to decode)

The new approach delays CUDA graph activation until warmup completes: the same cgraph must be called at least twice with matching properties before CUDA graph capture begins. This avoids wasted capture overhead on volatile graphs and allows graphs to become eligible once they stabilize. This also fixes issues such as #19708
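As a rough model of this warmup gating (illustrative names only, not the actual ggml-cuda code), a cgraph becomes eligible for capture only after a repeat call whose properties match the previous call:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the delayed-activation heuristic described above:
// capture is attempted only once the same cgraph has been seen at least
// twice in a row with identical properties. The struct and function names
// are illustrative, not the real ggml-cuda identifiers.
struct graph_state {
    uint64_t last_props_hash     = 0;
    int      consecutive_matches = 0; // repeat calls with matching properties
};

// Returns true when CUDA graph capture/launch should be used on this call.
bool should_use_cuda_graph(graph_state & st, uint64_t props_hash) {
    if (props_hash == st.last_props_hash) {
        st.consecutive_matches++;
    } else {
        // properties changed: restart the warmup instead of capturing eagerly
        st.consecutive_matches = 0;
        st.last_props_hash     = props_hash;
    }
    // require at least one prior call with matching properties, i.e. the
    // graph has been seen twice in the same shape, before capturing
    return st.consecutive_matches >= 1;
}
```

Under this model, a volatile graph (properties changing every call) never pays capture overhead, and a graph that stabilizes later (e.g. after switching from prompt processing to decode) becomes eligible again on its second stable call.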

Perf improvement for Llama-8b-Q4_K_M on RTX 6000 Ada 300W:

| Test | Master | PR | Speed-up |
| --- | ---: | ---: | ---: |
| pp2100+tg250 | 1042.83 | 1164.66 | 1.12 |
| tg200 @ d4096 | 124.43 | 138.30 | 1.11 |
| pp4096 | 7928.06 | 7940.24 | 1.00 |

Nsight profile, Master:
[screenshot]

Nsight profile, PR:
[screenshot]

@github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Feb 20, 2026
@JohannesGaessler
Contributor

How often does a CUDA graph need to be evaluated to offset the overhead? If that number is low it may make more sense to just enable them more generally since we can now cache more than a single one.

@JohannesGaessler
Contributor

Actually, I think the problem was that due to the increasing context size CUDA graphs were not reusable for the prefill phase so it didn't make sense to use them.

@gaugarg-nv
Contributor Author

Actually, I think the problem was that due to the increasing context size CUDA graphs were not reusable for the prefill phase so it didn't make sense to use them.

Exactly. I think PR #19645 enabled it for the pre-fill phase, but it was causing CUDA graphs to get disabled altogether after 4 consecutive updates.

@ggerganov
Member

Actually, I think the problem was that due to the increasing context size CUDA graphs were not reusable for the prefill phase so it didn't make sense to use them.

Exactly. I think PR #19645 enabled it for the pre-fill phase, but it was causing CUDA graphs to get disabled altogether after 4 consecutive updates.

Ah yes, thanks for flagging that - I missed it.

@ORippler
Collaborator

ORippler commented Feb 20, 2026

Regarding the CUDA graph logic, we currently have the following states in the code:

  1. ggml_cgraph has not changed -> launch the existing cudaGraphExec_t
  2. ggml_cgraph has changed -> capture a new cudaGraph_t -> try to update the existing cudaGraphExec_t from the captured cudaGraph_t:
    • a) update succeeds -> launch the updated cudaGraphExec_t
    • b) update fails -> destroy the old cudaGraphExec_t -> build a new cudaGraphExec_t -> launch

The counter + disablement logic fires on both 2a) and 2b), whereas 2a) merely has performance parity with the cudaLaunch API; see the following numbers (tg 200 on a B6000):

| Model | Force 2b) | Force 2a) | Current logic | Graphs disabled (cudaLaunch API) |
| --- | ---: | ---: | ---: | ---: |
| gpt-oss 20B MXFP4 | 223.67 | 287.31 | 338.76 | 286.75 |
| qwen3next 80B Q4_K | 74.05 | 104.85 | 114.27 | 99.89 |
| llama 3B Q4_K | 264.47 | 348.94 | 412.15 | 336.23 |

For the "Force 2a)"/"Force 2b)" columns, I made ggml_cuda_graph_update_required always return true, as we can save CPU overhead that way.
I also evaluated on Linux, as the cudaLaunch API is more efficient there than on Windows, so it's the tougher baseline to beat.

From this data, I would recommend the following:

  1. Exclude 2a) from the update heuristic / any newly introduced warm-up behavior, and trigger only on 2b)
  2. Emit a warning when 2b) happens repeatedly, as this points to inefficiencies/bugs either in the CUDA backend or in llama-context/the orchestration loop that dispatches into the backend. For example, I've observed 2b) on Qwen3-Coder-Next, which uses llama_memory_hybrid_context and consequently yields two different cgraph topologies for the first two calls after each context update (this is an optimization opportunity inside llama-context, I'd say).
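For illustration, the 2a)/2b) distinction above can be reduced to a tiny decision model (a simplified sketch, not the actual ggml-cuda control flow; in the real backend, "update succeeded" corresponds to cudaGraphExecUpdate succeeding):

```cpp
#include <cassert>

// Simplified model of the three launch paths enumerated above.
enum class graph_action {
    launch_existing,    // 1)  cgraph unchanged: launch the cached cudaGraphExec_t
    launch_updated,     // 2a) in-place update of the cudaGraphExec_t succeeded
    rebuild_and_launch, // 2b) update failed: destroy and re-instantiate
};

graph_action choose_action(bool cgraph_changed, bool exec_update_ok) {
    if (!cgraph_changed) {
        return graph_action::launch_existing; // path 1
    }
    // cgraph changed: a new cudaGraph_t was captured; try to patch the
    // existing executable graph before paying the full instantiation cost
    return exec_update_ok ? graph_action::launch_updated      // path 2a
                          : graph_action::rebuild_and_launch; // path 2b
}
```

The recommendation above amounts to: only the rebuild_and_launch path (2b) should feed the disablement counter, since launch_updated (2a) is already roughly as cheap as launching kernels directly.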

@ORippler
Collaborator

ORippler commented Feb 20, 2026

#19757 This is a WIP based on my insights above, which would enable cudaGraphs for PP on dense models that do not use llama_memory_hybrid_context (this could actually be perf-positive on Windows due to launch-overheads, will collect some numbers)

@gaugarg-nv
Contributor Author

gaugarg-nv commented Feb 20, 2026

#19757 This is a WIP based on my insights above, which would enable cudaGraphs for PP on dense models that do not use llama_memory_hybrid_context (this could actually be perf-positive on Windows due to launch-overheads, will collect some numbers)

One of the reasons to take this approach in the PR was to fix issues like #19708. Basically, there could be cases where a ggml_cgraph is launched only once; in those cases, using a CUDA graph doesn't make much sense.

@bssrdf
Contributor

bssrdf commented Feb 20, 2026

@gaugarg-nv, @ORippler, I am just curious whether it is feasible to "freeze" the CUDA graph. In some applications, e.g. diffusion, the same cgraph runs repeatedly without topology changes. In this case, even ggml_cuda_graph_update_required can be skipped. Of course, this would require user intervention to unfreeze the graph.

@ORippler
Collaborator

I am just curious whether it is feasible to "freeze" the CUDA graph. In some applications, e.g. diffusion, the same cgraph runs repeatedly without topology changes. In this case, even ggml_cuda_graph_update_required can be skipped. Of course, this would require user intervention to unfreeze the graph.

Unfortunately ggml does not allow for this yet. We wanted to tackle it as part of the graph-plan API, which was postponed indefinitely due to slaren's hiatus. llama-context, the LLM orchestration loop of llama.cpp, has already avoided rebuilding the graph when the topology is consistent since #14482, and would simply have to forward this information to the backends somehow.

Quoting from #14482:

[ ] Make CUDA reuse CUDA graphs using this new mechanism

@bssrdf
Contributor

bssrdf commented Feb 20, 2026

I am just curious whether it is feasible to "freeze" the CUDA graph. In some applications, e.g. diffusion, the same cgraph runs repeatedly without topology changes. In this case, even ggml_cuda_graph_update_required can be skipped. Of course, this would require user intervention to unfreeze the graph.

Unfortunately ggml does not allow for this yet. We wanted to tackle it as part of the graph-plan API, which was postponed indefinitely due to slaren's hiatus. llama-context, the LLM orchestration loop of llama.cpp, has already avoided rebuilding the graph when the topology is consistent since #14482, and would simply have to forward this information to the backends somehow.

Quoting from #14482:

[ ] Make CUDA reuse CUDA graphs using this new mechanism

@ORippler, thanks for the #14482 link. I didn't know such a capability already existed.

@ORippler
Collaborator

#19757 This is a WIP based on my insights above, which would enable cudaGraphs for PP on dense models that do not use llama_memory_hybrid_context (this could actually be perf-positive on Windows due to launch-overheads, will collect some numbers)

Following up on this:

TLDR: Given the limited perf potential (up to ~2% on Windows, based on Llama-3.2-3B-Instruct-Q4_K_M.gguf), I'll shelve #19757 as it wouldn't resolve #19708, yet this PR does. We can potentially revisit once we robustify node fusion in the CUDA backend w.r.t. padded GGML_BACKEND_BUFFER_USAGE_COMPUTE buffers, as this made us hit 2b) more often than anticipated in the PP phase.

For the interested, read on. Let's start with the collected perf numbers:
➜  llama.cpp git:(osimons/update_cuda_graph_heuristics) ✗ ./build-x64-linux-gcc-reldbg/bin/llama-bench -m /mnt/share/gguf/Llama-3.2-3B-Instruct-Q4_K_M.gguf -fa 1 -mmp 0 -dio 1 -p 16000
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --: | --------------: | -------------------: |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |   1 |         pp16000 |     15474.39 ± 28.32 |

build: f39a785e4 (8091)
➜  llama.cpp git:(osimons/update_cuda_graph_heuristics) ✗ GGML_CUDA_DISABLE_GRAPHS=1 ./build-x64-linux-gcc-reldbg/bin/llama-bench -m /mnt/share/gguf/Llama-3.2-3B-Instruct-Q4_K_M.gguf -fa 1 -mmp 0 -dio 1 -p 16000
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --: | --------------: | -------------------: |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |   1 |         pp16000 |     15525.54 ± 14.94 |
(base) (base) PS C:\Users\osimons\llama.cpp> $env:GGML_CUDA_DISABLE_GRAPHS=$null; .\build-cuda\bin\RelWithDebInfo\llama-bench.exe -m "\\hhs-truenas01.nvidia.com\devtech-data\gguf\Llama-3.2-3B-Instruct-Q4_K_M.gguf" -fa 1 -dio 1 -p 16000
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --: | --------------: | -------------------: |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |   1 |         pp16000 |     15145.43 ± 49.93 |

build: f39a785e4 (8091)
(base) (base) PS C:\Users\osimons\llama.cpp> $env:GGML_CUDA_DISABLE_GRAPHS=1; .\build-cuda\bin\RelWithDebInfo\llama-bench.exe -m "\\hhs-truenas01.nvidia.com\devtech-data\gguf\Llama-3.2-3B-Instruct-Q4_K_M.gguf" -fa 1 -dio 1 -p 16000
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa | dio |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --: | --------------: | -------------------: |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |   1 |         pp16000 |     15211.40 ± 28.14 |

build: f39a785e4 (8091)

Surprisingly, we don't see perf gains on Windows, despite there being a 2% gap between Windows and Linux in cudaLaunch-API mode. Why? Let's take an Nsight Systems report to figure it out.

[Nsight Systems screenshot]

Ah, we hit 2b) way more often than expected (my expectation was to hit only 2a).
But why? Well, it's because

        gf = model.build_graph(gparams);
        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {

in llama_context::process_ubatch may yield a different allocation of compute buffers to intermediate nodes when expanding the context (I would love to know why this occurs exclusively in PP and not in TG; I did see > 512 graph reuses for TG in Llama-3.2-3B-Instruct-Q4_K_M.gguf, and AFAIK we grow our KV caches/contexts in multiples of 256. Maybe @ggerganov knows more?). This in turn affects how we fuse nodes in the CUDA backend, meaning the topology of the cudaGraph_t will change, and that is a reason for an update to the cudaGraphExec_t to fail (i.e. to walk down the 2b) path):

nodes of incoming ggml_cgraph: 874
Destroying previous CUDA graph instance for graph key 0x55555a631ec0 (num_nodes = 856)
New CUDA graph instance for graph key 0x55555a631ec0 (num_nodes = 867)

graph_key = ggml_cuda_graph_get_key(cgraph);

use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
Collaborator

Strictly speaking, this function solely checks whether the GPU supports CUDA graphs, but we can change the name in a separate PR.

Contributor

@JohannesGaessler left a comment

Your code comments contain EM dashes. Unless there is a good reason not to, please stick to ASCII characters.

gaugarg-nv and others added 2 commits February 21, 2026 09:03
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@gaugarg-nv
Contributor Author

Your code comments contain EM dashes. Unless there is a good reason not to, please stick to ASCII characters.

Fixed now.

Contributor

@am17an left a comment

nice and elegant solution!

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
@am17an am17an merged commit a0c91e8 into ggml-org:master Feb 21, 2026
77 of 78 checks passed
liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
* Improve CUDA graph capture

Currently, CUDA graphs are eagerly enabled on the first call to ggml_backend_cuda_graph_compute. If the graph properties keep changing (4+ consecutive updates), the graph is permanently disabled. This is suboptimal because:

- The first call always incurs CUDA graph capture overhead even if the graph is unstable
- Once permanently disabled, CUDA graphs never re-enable even after the graph stabilizes (e.g., switching from prompt processing to decode)

The new approach delays CUDA graph activation until warmup completes: the same cgraph must be called at least twice with matching properties before CUDA graph capture begins. This avoids wasted capture overhead on volatile graphs and allows graphs to become eligible once they stabilize.
This also fixes issues such as ggml-org#19708

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Remove EM dashes

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026
@JohannesGaessler
Contributor

@gaugarg-nv I'm seeing a performance regression for pp512 from this PR:

| GPU | Model | Microbatch size | Test | t/s 07968d5 | t/s b8121 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 5533.69 | 5158.79 | 0.93 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp4096 | 5083.82 | 5078.58 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 12939.20 | 11044.82 | 0.85 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp4096 | 12469.61 | 12419.14 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp512 | 16433.97 | 13120.55 | 0.80 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp4096 | 15003.54 | 14863.17 | 0.99 |

@gaugarg-nv
Contributor Author

@gaugarg-nv I'm seeing a performance regression for pp512 from this PR:

Thanks for reporting this. Which command-line options were used? Do you have a warm-up iteration? How many runtime iterations?

I guess that when you are testing pp512 with a micro-batch size of 512, the CUDA graph won't change across iterations, and the CUDA graph captured during warm-up will get reused. Can you try removing the warm-up loop or reducing the number of runtime iterations and see if it has any impact?

@JohannesGaessler
Contributor

JohannesGaessler commented Mar 8, 2026

I tested the performance like this:

export mn=llama_3-8b && export q=q4_0
export CUDA_VISIBLE_DEVICES=0
./bench --model models/opt/${mn}-${q}.gguf -fa 1 -r 1 -n 0 -p 512,4096 -o sql|sqlite3 llama-bench.sqlite
py scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b 07968d53e4c4421|tee bench.txt

@JohannesGaessler
Contributor

I think I know what's going on. If I use 10 repetitions rather than 1 for the benchmark, there is basically no performance difference:

| GPU | Model | Microbatch size | Test | t/s 07968d5 | t/s b8121 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 5550.80 | 5489.52 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp4096 | 5064.85 | 4980.22 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 13101.74 | 12946.40 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp4096 | 12287.83 | 12273.52 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp512 | 16313.25 | 16014.68 | 0.98 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp4096 | 14928.63 | 14817.79 | 0.99 |

On the warmup run for pp512, CUDA graphs are not used because it's the first run. On the first benchmark run the same ggml graph is run a second time, so a CUDA graph is captured, which introduces overhead. With 10 benchmark runs the overhead amortizes, so there is no difference. So I would see this not as an issue with the code but rather with how we are benchmarking it.
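The amortization argument can be made concrete with a toy model (all numbers purely illustrative): a one-time capture cost is spread across the n benchmark repetitions, so the per-run penalty shrinks as 1/n.

```cpp
#include <cassert>

// Toy model of capture-overhead amortization (illustrative numbers only):
// average time per benchmark run when a one-time CUDA graph capture cost
// is paid on the first of n repetitions.
double avg_run_time(double per_run_ms, double capture_ms, int n) {
    return per_run_ms + capture_ms / n;
}
```

With a hypothetical 10 ms run and 5 ms capture cost, one repetition averages 15 ms, while ten repetitions average 10.5 ms, which matches the observation that the pp512 gap disappears at 10 repetitions.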

@gaugarg-nv
Contributor Author

Right, I think one way to fix this is to increase the number of warmup iterations from the current 1 to 2. With that change, both implementations should show the same perf.
