
cuda : enable CUDA graphs for MMID 1 <= BS <= 4 #19645

Merged
ggerganov merged 3 commits into master from gg/cuda-graphs-enable-bs-gt1 on Feb 17, 2026
Conversation

@ggerganov (Member) commented Feb 15, 2026

cont #19644
cont #18958
cont #19521

Enable CUDA graphs for ggml graphs with GGML_OP_MUL_MAT_ID at 1 < BS <= 4. This improves the performance of parallel generation for up to 4 sequences, which is useful, for example, when running multiple local agents in parallel (#19564 (reply in thread)).

Also, simplify the CUDA graph exception logic. Unless I am missing something, we no longer need to restrict GGML_OP_ADD for BS > 1 because we perform an exhaustive node properties check to decide if the graph needs an update.

Next PRs:

  • Try to extend support for BS > 4 && BS <= 8, or more?

DGX Spark:

| Model | Microbatch size | Test | t/s pr/19644 | t/s pr | Speedup |
| --- | --- | --- | --- | --- | --- |
| deepseek2 30B.A3B Q8_0 | 1 | pp512 | 52.58 | 52.23 | 0.99 |
| deepseek2 30B.A3B Q8_0 | 2 | pp512 | 68.96 | 78.45 | 1.14 |
| deepseek2 30B.A3B Q8_0 | 3 | pp512 | 83.68 | 92.96 | 1.11 |
| deepseek2 30B.A3B Q8_0 | 4 | pp512 | 95.28 | 104.23 | 1.09 |
| gpt-oss 120B MXFP4 MoE | 1 | pp512 | 70.73 | 70.47 | 1.00 |
| gpt-oss 120B MXFP4 MoE | 2 | pp512 | 86.85 | 96.82 | 1.11 |
| gpt-oss 120B MXFP4 MoE | 3 | pp512 | 104.29 | 114.42 | 1.10 |
| gpt-oss 120B MXFP4 MoE | 4 | pp512 | 115.61 | 124.86 | 1.08 |
| gpt-oss 20B MXFP4 MoE | 1 | pp512 | 106.60 | 106.66 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 2 | pp512 | 136.89 | 153.91 | 1.12 |
| gpt-oss 20B MXFP4 MoE | 3 | pp512 | 165.17 | 186.38 | 1.13 |
| gpt-oss 20B MXFP4 MoE | 4 | pp512 | 191.08 | 207.84 | 1.09 |
| qwen3moe 30B.A3B Q8_0 | 1 | pp512 | 67.24 | 66.92 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 2 | pp512 | 82.34 | 91.74 | 1.11 |
| qwen3moe 30B.A3B Q8_0 | 3 | pp512 | 98.97 | 108.10 | 1.09 |
| qwen3moe 30B.A3B Q8_0 | 4 | pp512 | 111.61 | 120.74 | 1.08 |
| qwen3next 80B.A3B Q4_0 | 1 | pp512 | 65.15 | 64.85 | 1.00 |
| qwen3next 80B.A3B Q4_0 | 2 | pp512 | 57.28 | 75.91 | 1.33 |
| qwen3next 80B.A3B Q4_0 | 3 | pp512 | 78.10 | 102.16 | 1.31 |
| qwen3next 80B.A3B Q4_0 | 4 | pp512 | 95.98 | 121.15 | 1.26 |
| qwen3next 80B.A3B Q8_0 | 1 | pp512 | 43.10 | 43.23 | 1.00 |
| qwen3next 80B.A3B Q8_0 | 2 | pp512 | 43.33 | 53.78 | 1.24 |
| qwen3next 80B.A3B Q8_0 | 3 | pp512 | 56.90 | 69.23 | 1.22 |
| qwen3next 80B.A3B Q8_0 | 4 | pp512 | 68.84 | 81.43 | 1.18 |

@github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Feb 15, 2026
@am17an (Contributor) left a comment
Speedup on 4090:

| Model | Microbatch size | Test | t/s 8bc255a | t/s gg/cuda-graphs-enable-bs-gt1 | Speedup |
| --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 1 | pp512 | 283.83 | 284.09 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 2 | pp512 | 401.81 | 443.68 | 1.10 |
| gpt-oss 20B MXFP4 MoE | 3 | pp512 | 517.59 | 561.15 | 1.08 |
| gpt-oss 20B MXFP4 MoE | 4 | pp512 | 595.86 | 635.65 | 1.07 |

// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > 4)) {
Contributor:

Perhaps we rename this 4 to MMVQ_MMID_MAX_BATCH_SIZE in case we end up optimizing for BS > 4

Member Author:

@JohannesGaessler Can you confirm that adding the MMVQ_MMID_MAX_BATCH_SIZE constant is OK for now?

Contributor:

Adding a named constant here would be fine with me.

Base automatically changed from gg/graph-fix-kq-mask-reuse to master February 16, 2026 07:21
@ggerganov ggerganov requested a review from CISC as a code owner February 16, 2026 07:21
@ggerganov ggerganov force-pushed the gg/cuda-graphs-enable-bs-gt1 branch from 15a6842 to 7d0be2c on February 16, 2026 07:22
Comment on lines +2887 to +2889
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
Collaborator:

I see two ways for this:

  1. Refactor the logic expressed on host side into a CUDA kernel
  2. Use cudaLaunchHostFunc and manage lifetime of CPU objects inside ggml_cuda_graph

Contributor:

The fundamental problem here is that we are using cuBLAS for GEMM and are orchestrating the matrix multiplications per expert via CPU logic. It would in principle be possible to pad the matrix multiplications and use batched/strided GEMM, but that would waste a lot of compute. I'm not sure whether it's possible to orchestrate cuBLAS GEMM from device code, and launching CUDA kernels from within cudaLaunchHostFunc also seems dubious. For the custom ggml kernels I made it so that during the quantization kernel the data is re-arranged to be contiguous per expert, and the kernel skips any output tiles that would be 100% padding. In principle this could be generalized to floating-point data, see #18864.

@ORippler (Collaborator):

> Also, simplify the CUDA graph exception logic. Unless I am missing something, we no longer need to restrict GGML_OP_ADD for BS > 1 because we perform an exhaustive node properties check to decide if the graph needs an update.

That's correct.

// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > 4)) {
Member Author:

Btw, note here that we currently require quantized src0. This means that with BF16 MoE models for example, the CUDA graphs will not be enabled.

@am17an commented Feb 16, 2026:

#18958 enables this for mmvf also, so bf16 should also work

Contributor:

Ah sorry, that is just for AMD.

Co-authored-by: Oliver Simons <osimons@nvidia.com>
@ggerganov (Member Author):

Merging on green

@ggerganov ggerganov merged commit ad8207a into master Feb 17, 2026
78 checks passed
@ggerganov ggerganov deleted the gg/cuda-graphs-enable-bs-gt1 branch February 17, 2026 10:31
LostRuins added a commit to LostRuins/koboldcpp that referenced this pull request Feb 20, 2026
@LostRuins (Collaborator):

Hello @ggerganov, this PR specifically seems to cause a new memory leak faced by @henk717 when used with stable-diffusion.cpp. I know it sounds strange, but bear with me.

  • The memory leak causes the allocation of 1-2 GB of additional RAM every time an image is generated, eventually resulting in OOM after many images.
  • The leak increases per step of the diffusion, so it's something to do with inference-time allocations.
  • All image generation models are affected, but text generation is unaffected.
  • The leak ONLY happens on the CUDA backend. Vulkan and CPU do not observe such a leak.
  • The leak only seems to happen when CUDA graphs are enabled, i.e. GGML_CUDA_USE_GRAPHS; otherwise there is no leak.
  • The leak happens on a 3090 (@henk717 has a dual RTX 3090 setup) but not on my own RTX 4090.
  • Reverting this PR makes the memory leak disappear. This was determined by bisecting until this commit was found.

I can create a full issue too if you want, but I figured that, since this is a very small commit, you might already know the fix.

Also tagging stable-diffusion.cpp stakeholders, although I don't think it's caused there since their code is unchanged: @leejet @stduhpf

@ORippler (Collaborator) commented Feb 20, 2026:

Seems related to #19708. Does stable-diffusion.cpp re-create a new ggml-cgraph object with different node-key for every step of the pipeline?

Also, does #19754 fix your issue?

@henk717 commented Feb 20, 2026:

Small clarification: the amount of memory increase will depend on the image model. It can be as much as 4 GB per image on Qwen Image Edit. It's very noticeable when it happens. I monitor this in regular RAM, not VRAM, although that seemed to go up as well. While I run dual 3090s, only one of them is visible to the application.

LostRuins added a commit to LostRuins/koboldcpp that referenced this pull request Feb 22, 2026
@LostRuins (Collaborator):

Sent you a test build with #19754 merged and this PR reinstated; let me know if it fixes the memory leak.

@henk717 commented Feb 22, 2026:

This did not resolve the issue.

LostRuins added a commit to LostRuins/koboldcpp that referenced this pull request Feb 22, 2026
@LostRuins (Collaborator):

Maybe gg knows what's wrong.

@ggerganov (Member Author):

Could you provide the most basic steps to repro?

@LostRuins (Collaborator):

> Could you provide the most basic steps to repro?

Sorry, I'd like to if possible, but unfortunately I don't know how to reproduce it easily here in this repo, as most of the necessary code is not inside the llama.cpp repo.

Henk717 was using it from within KoboldCpp. The issue only happens when using image generation code taken from https://github.com/leejet/stable-diffusion.cpp. However, sd.cpp itself has yet to sync to the new ggml version (which needs this commit), so it won't be reproducible there yet. Additionally, they do not build with CUDA graphs enabled, which is needed to trigger this issue, so that will also need to be enabled.

I could direct you to my own fork at https://github.com/LostRuins/koboldcpp, however I must admit my code quality is not very good, my build process is janky and might be troublesome to work with, and I cannot trigger this bug myself...

@henk717 do you have any better idea how we can help gg troubleshoot this here?

The circumstances of the leak caused by this PR are as summarized in #19645 (comment)

@henk717 commented Feb 22, 2026:

It happens during regular KoboldCpp usage on every image I generate. System memory usage is an extremely noticeable measurement, as it will rapidly increase on an affected build. For me it's consistent on Windows with driver version 581.29 on my local desktop PC with a 3090 GPU. On a 3090 Linux cloud rental I was unable to measure this issue (I assume it wasn't present, but I am less familiar with the Linux tools). LostRuins was unaffected on his 4090 laptop.

The tricky part is that it manifests inside stable-diffusion.cpp because KoboldCpp uses the newer ggml from llama.cpp. At the moment the upstream sd.cpp project is still on a 3-week-old ggml version and will not be impacted yet.

One way to test it on KoboldCpp is by using https://huggingface.co/MaxedOut/ComfyUI-Starter-Packs/resolve/main/SDXL/checkpoints/sd_xl_base_1.0.safetensors?download=true as the image model (any model should produce it; this is a relatively fast one that I know for certain is easy to observe with). This can be done in the Image Model tab by selecting it as an Image Gen Model. No further files are required for this model. The command-line equivalent would be --imgmodel.

Once the UI launches, accept that you wish to be taken to sdui or use the /sdui link for easy image generation. Type any prompt (in my tests I use the word Kobold) and you will see the process's memory increase by several hundred megabytes with every image you generate.

If you prefer working with the upstream sdcpp project you may be able to link it to a newer ggml version to introduce the same issue.

KoboldCpp can be compiled with CMake for the Windows CUDA build (Linux uses make or koboldcpp.sh, but on Linux I haven't been able to reproduce the issue so far, possibly due to my lack of familiarity with Linux memory management tools).

@LostRuins would it be difficult for you to provide a stable-diffusion.cpp build (not KoboldCpp) with the new ggml applied?

@am17an (Contributor) commented Feb 23, 2026:

Maybe add this check back again and see if it resolves the issue:

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 7e6d33035..f841d8ac9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2893,6 +2893,10 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 #endif
         }
 
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            use_cuda_graph = false;
+        }
+
         if (!use_cuda_graph) {
             break;
         }

@ggerganov (Member Author):

Yes, let's start with @am17an's suggestion first to try to narrow down the possible root causes. @henk717 Let us know when you give it a try.

@henk717 commented Feb 23, 2026:

Bullseye: @LostRuins made a test build for me with the fix applied, and on this test build I do not have the issue.
Update: As a sanity check I opened the older test build where it happened; it still happens there. Went back to this test build and it doesn't happen. So I also confirmed I can still reproduce the issue without the fix, which eliminates the potential placebo factor.

@am17an (Contributor) commented Feb 23, 2026:

So I think we could be adding a large number of CUDA graphs which are never reclaimed. Namely, this code is being called several times. It's probably because each time graph_compute is called with a new pointer when nrows > 1.

ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
    auto it = cuda_graphs.find(first_node_ptr);
    if (it == cuda_graphs.end()) {
        cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
        return cuda_graphs[first_node_ptr].get();
    }
    return it->second.get();
}

You can confirm this by running valgrind --tool=massif ./binary and then ms_print massif.out.<pid>.

@henk717 commented Feb 23, 2026:

It's happening on my Windows system, so if that instruction is for me I'd need to know a Windows equivalent.

@am17an (Contributor) commented Feb 23, 2026:

I don't know about Windows; from a quick search it looks like Dr. Memory is useful.

liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
* cuda : enable CUDA graphs for MMID BS <= 4

* cont : add stream capture check

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* cont : add MMVQ_MMID_MAX_BATCH_SIZE

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
@henk717 commented Feb 23, 2026:

This is not going to be an easy task for me; the tool you suggested didn't work. I found another one which reduces the generation speed immensely, so it takes a long time, and the end result was without any names, so I don't think any of us will be able to decipher it.

Between @LostRuins integrating fixes and me testing, this would probably take significantly more time compared to queuing up some ideas. Normally when I am testing he queues up multiple test builds so I can test them all in one go; this will likely be faster than the hours it will take for a single test, assuming we get more useful output to begin with, as I don't know if the tool I found can handle debug symbols (the exe I have isn't a debug-symbol build but one from our CI).

My suggestion is to continue on the more speculative path, ideally in one go. Then he can queue it up for me in his timezone and I can test it in mine if that is needed.

@am17an (Contributor) commented Feb 23, 2026:

Okay, you can try this patch to confirm:

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 36d8a3aaa..d394fb475 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -1335,6 +1335,9 @@ struct ggml_backend_cuda_context {
         auto it = cuda_graphs.find(first_node_ptr);
         if (it == cuda_graphs.end()) {
             cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
+            if (cuda_graphs.size() % 100 == 0) {
+                GGML_LOG_INFO("cuda graphs current size: %zu\n", cuda_graphs.size());
+            }
             return cuda_graphs[first_node_ptr].get();
         }
         return it->second.get();

This should keep on increasing until OOM, and you should see the size of the map logged.

@henk717 commented Feb 24, 2026:

The patch was added: LostRuins@c2e06d8
On this build I could reproduce the leak, but I did not see what I'd expect, as this particular log does not appear for me.
I think the if statements aren't being hit.

Instead it looks like this:
generating image: 1/1 - seed 4394834
|==> | 1/20 - 2.26it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|=====> | 2/20 - 3.20it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|=======> | 3/20 - 3.83it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete

ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|=================> | 7/20 - 5.10it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|======================> | 9/20 - 5.39it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|==========================================> | 17/20 - 6.14it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|===============================================> | 19/20 - 6.20it/s
ggml_backend_cuda_graph_compute: CUDA graph warmup complete
|==================================================| 20/20 - 6.21it/s

@henk717 commented Feb 24, 2026:

Update: I noticed it was kind of limited on SDXL, so I went back to testing Qwen; a few images in, I saw the new debug output.

@LostRuins (Collaborator):

But what did it output (the values)?

@henk717 commented Feb 24, 2026:

The check was an equals check, so I only got it when it hit 100.

@am17an (Contributor) commented Feb 25, 2026:

It's a modulo check, so you should get the print at 100, 200, 300, etc.

@henk717 commented Feb 25, 2026:

By the time it hit 100 I already had multiple gigabytes of additional RAM.

bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026