
CUDA: Improve performance via fewer synchronizations between tokens #17795

Merged
ggerganov merged 13 commits into ggml-org:master from
aendk:akieslinger/reduce-per-token-syncs
Mar 5, 2026

Conversation

@aendk
Contributor

@aendk aendk commented Dec 5, 2025

See comment below


This PR suggests removing some superfluous synchronization calls between tokens to make the CUDA backend faster. I see between 1% and 2% performance gain depending on the model, GPU and settings.

Mechanism

The performance impact is best explained visually. Here are the "before" and "after" Nsight Systems traces. They are not to scale. The relevant part is the row with the green and red bubbles. Both images show the overhead between the GPU execution of two tokens. The generation of the n-th token ends on the left-hand side of the screenshot, in the green bubble titled cudaStreamSynchronize. The calculation of the next (n+1st) token starts on the right-hand side, at the green bar titled cudaGraphLaunch. In between, there is CPU orchestration overhead. This PR aims to shrink the time spent in the middle, between GPU token generations. Original:

Screenshot 2025-12-05 at 14 45 35

In the middle of the above image, we see red and green bubbles alternating. In this case, the green bubbles are synchronization steps, the red bubbles are asynchronous copy calls from host to device. If async operations are immediately followed by synchronization calls, they are executed synchronously. This is not efficient. Removing the green synchronization operations between asynchronous copy calls leads to asynchronous copies and reduced overhead between GPU token generation:

Screenshot 2025-12-05 at 14 45 10

Performance

I benchmarked on an RTX Pro 6000 Blackwell using ./llama-bench -m $models -p 0 -n 128,256,512 -fa 1.
My testing shows around a 1% improvement, with gpt-oss-20b gaining up to 1.4%. llama 3B Q4_K - Medium shows very high variance, prompting me to rerun the tests with -r 100. At -r 100, a clearer trend of improved performance for gemma3n E2B Q8_0 is also visible.

Details with default `-r 5`

Baseline:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        392.24 ± 1.07 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        392.72 ± 0.35 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        387.72 ± 0.38 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        464.85 ± 0.55 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        465.39 ± 0.59 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        461.87 ± 0.74 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        231.59 ± 0.09 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        231.47 ± 0.03 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        228.21 ± 0.46 |

build: 909072abc (7176)

PR:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        397.14 ± 1.50 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        398.36 ± 0.45 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        393.25 ± 0.65 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        472.48 ± 3.71 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        468.81 ± 0.19 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        463.62 ± 1.28 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        232.84 ± 0.18 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        232.82 ± 0.08 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        229.62 ± 0.25 |

build: f6b408d84 (7178)

Speedup:

1.01249
1.01436
1.01426
1.01641
1.00735
1.00379
1.0054
1.00583
1.00618
Details with `-r 100`

Baseline:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        393.24 ± 0.45 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        393.33 ± 2.97 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        381.93 ± 2.40 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |       446.41 ± 40.17 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |       451.55 ± 21.34 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        454.89 ± 0.33 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        231.90 ± 0.27 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        231.93 ± 0.21 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        228.47 ± 0.14 |

build: 909072abc (7176)

PR:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        398.52 ± 0.41 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        397.32 ± 5.71 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        383.53 ± 3.06 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |       441.09 ± 50.39 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |       456.69 ± 20.91 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        458.19 ± 0.32 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        233.98 ± 0.13 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        233.65 ± 0.25 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        230.18 ± 0.14 |

build: aebcdf119 (7178)

Speedup:

1.01366
1.01025
1.00408
0.982033
1.00875
1.00819
1.00893
1.00876
1.00938

Implementation Concerns

See updated comment below.

@ggerganov @JohannesGaessler

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Dec 5, 2025
@wishstudio
Contributor

How about using existing .set_tensor_async in ggml_backend_i? There is also a ggml_backend_tensor_set_async helper.

I believe the set_tensor, get_tensor etc. interfaces guarantee synchronous behavior, though I'm not sure which part depends on it. A better plan might be changing each set_tensor use to the async version one-by-one instead of changing them all at once, which would likely be an ABI change and lead to problems.

Besides, I believe the current way of synchronizing both input_backend and split_backend is still not optimal as some synchronizations can be completely skipped. For example, for a CPU split copying data from GPU, we have to do a synchronization to make sure all data is copied to RAM before use. But for a GPU split copying data from CPU, there is no synchronization needed as the cudaMemcpyAsync call will be automatically serialized with following kernel launches.

For synchronizations at the beginning of this code section (ggml_backend_synchronize(split_backend);, etc) I think the intention is to synchronize before the following steps. The behavior is changed in this patch, though I feel these synchronizations are unnecessary in many cases (especially single GPU).

@aendk
Contributor Author

aendk commented Dec 8, 2025

How about using existing .set_tensor_async in ggml_backend_i? There is also a ggml_backend_tensor_set_async helper.

While there is some amount of interop between the two interfaces (.set_tensor_async of ggml_backend_i can fall back to .set_tensor of ggml_backend_buffer_i), I chose not to mix their usage further.
They are split in two to separate two distinct concerns (data vs. execution), and to enable different object lifetimes. They also differ slightly regarding stream semantics (default stream vs. cudaStreamPerThread).
Another reason for not using .set_tensor_async of ggml_backend_i is that it requires a backend as an argument, which by design you cannot get from just the tensor. Using it would therefore require more changes in the general backend code, which I tried to avoid.

Ultimately, this is a matter of taste and what architectural trade-off seems best in this specific case.

@wishstudio
Contributor

They are split in two, to separate two distinct concerns (data vs execution), and to enable different object lifetimes.

Thank you for the explanation. I was not aware of the distinction and it totally makes sense.

@ggerganov
Member

ggerganov commented Dec 8, 2025

I believe these changes make an implicit assumption that the async copies will enter a common execution queue/stream with all previous operations. This seems to be the case for CUDA, but I don't think it is a general enough assumption to be made for all backends.

For example, removing this backend synchronization here I think is incorrect, since semantically it means that copying of the input can start immediately, even if the backend is still doing work:

// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
} else {
ggml_backend_synchronize(split_backend);
}
ggml_backend_tensor_copy(input, input_cpy);

With CUDA, it is apparently not a problem because the copying does not start immediately - it gets placed in a queue (is this correct? who owns this queue? what guarantees that the copy will wait for all previous CUDA operations to finish?). But for some other backend, a copy could start immediately - as soon as we call ggml_backend_tensor_copy. So I think the explicit ggml_backend_synchronize() call is needed there.

Another implicit assumption that I think is not very good in this proposal: ggml_backend_synchronize() is assumed that will sync all asynchronous buffer copies. Again, I believe this comes from the assumption of a shared CUDA queue or stream. But imagine that I make a dummy backend buffer that implements async copy like this:

ggml_backend_dummy_buffer_set_tensor_async(...) {
    start_thread_to_copy_data(...).detach();
}

Technically, this is an asynchronous call so it's perfectly fine from an API standpoint. But there isn't a common queue to manage this thread. The thread belongs to the buffer, not to the backend. Calling ggml_backend_synchronize(dummy_backend) would not guarantee that the thread has finished, because the two of them (the thread copying the buffer and the backend) are not associated with each other. I think this is the main reason for not implementing an async API for the backend buffers in the first place.

Contributor

@JohannesGaessler JohannesGaessler left a comment


The ultimate authority on what the ggml backend interface should or shouldn't be is Diego, but my opinion is that ggml_backend_tensor_set should be blocking and that an asynchronous equivalent should be an explicit opt-in. Otherwise, user code can accidentally introduce use-after-free bugs depending on how fast or slow the memcpy is vs. the user code.

@JohannesGaessler
Contributor

Re Georgi's first point: I think the PR as-is is already incorrect for CUDA. The reason there is a CPU synchronization for the memcpy is that the source pointer is under host control and the host could free the memory immediately after the function returns, resulting in a use-after-free bug. With this PR this could in principle still happen. It would only be safe if there is a precondition for user code that stipulates that memory may only be freed after calling ggml_backend_synchronize.

Re Georgi's second point: If I understood Diego's idea of what a ggml "backend" is supposed to be correctly it was pretty much equivalent to the concept of a CUDA stream. And especially when the ggml backend API has a function ggml_backend_synchronize my expectation going in would be that any "async" calls would be synchronized on that function.

@ggerganov
Member

And especially when the ggml backend API has a function ggml_backend_synchronize my expectation going in would be that any "async" calls would be synchronized on that function.

Yes, ggml_backend_synchronize synchronizes any async calls to the backend. My point was that if we were to introduce a new async API to the backend buffer (as in this proposal), then we cannot assume that a backend synchronization will sync all buffers. Hence my point is that an async backend buffer API would be pointless.

@aendk aendk force-pushed the akieslinger/reduce-per-token-syncs branch 2 times, most recently from 5d26313 to 1233fdd on December 19, 2025 10:32
@aendk
Contributor Author

aendk commented Dec 19, 2025

Hi @ggerganov @JohannesGaessler, I reworked the whole PR and it is ready for review again. I restarted from scratch due to the incorrect assumptions the first approach was based on.
I summarized the implications of this PR in the following Q&A format:

  • What changed, compared to master:
    • Made async copies from CPU to the CUDA backend possible (ggml-cuda.cu change)
    • Made explicit syncs before, during and after async copies optional for backends which have implicit sync mechanisms (i.e. CUDA for now)
    • Exchanged ggml_backend_tensor_copy with ggml_backend_tensor_copy_async
    • Additional optional sync to retain behavior for all other backends
  • What did not change, compared to master:
    • No changes to the API
    • No change in behavior for non-CUDA backends
    • No changes to the required libs to link, no additional #includes
  • Possible issues & points for discussion:
    • This PR essentially removes explicit syncs and relies on implicit syncs (CUDA stream) where possible. In the setting of a single GPU with a single CUDA backend, everything is straightforward. Async copies and the graph launch are passed to the same CUDA stream, which then executes each item like in a queue. No data races are possible. For this, I verified that everything is indeed passed to the same CUDA stream.
    • Other than this, I thought about the possibility of bugs around CUDA backends where a common stream is not guaranteed:
    • Multiple CUDA backends on a single GPU (e.g. a model with an unsupported operation leading to ping-pong execution of GPU->CPU_FALLBACK->GPU->CPU_FALLBACK->GPU…)
    • Multiple CUDA backends on several GPUs, for example in Pipeline parallelism.
    • Ultimately, the transition between two possibly async backends is something I analyzed in-depth to ensure correctness. I found the following:
      1. Multiple CUDA backends on the same device can copy asynchronously via cpy_tensor_async(). My PR does not change the existing behavior.
      2. If no cpy_tensor_async() is possible, the (unchanged) fallback is to sync both backends, and copy synchronously.
      3. If there are CUDA backends on multiple devices, my PR does not change the scheduling behavior either: as in 2., when the asynchronous copy fails, both backends are synced and memory is copied in a synchronous manner.
  • Misc:
    • Optional syncing can be disabled via an environment flag.
    • strncmp call to identify backend: it only compares up to 4 letters; could a SIGSEGV still happen on shorter matching strings? I could explicitly match all shorter backend strings to harden this section.
    • Side note: The section about MoE cpu offloading is quite long (L1458/1493 -> L1536/1571). If you agree, I could move this to a separate helper function so the heart of the scheduling logic is easier to understand and maintain.
  • Performance:
    • About 1%, depending on the model (see details below).
  • Future work:
    • Other implicitly synced backends can be enrolled to skip additional syncs with only a few more lines of code.

@aendk aendk marked this pull request as ready for review December 19, 2025 15:51
@aendk aendk changed the title [DRAFT] CUDA: Improve performance via less synchronizations between token CUDA: Improve performance via less synchronizations between token Dec 19, 2025
@jeffbolznv
Contributor

The current behavior of what is implicitly ordered is more or less undocumented and challenging to reverse-engineer (sounds like you had to do some of this yourself). Adding code paths that rely on backend-specific implicit ordering is IMO a step in the wrong direction.

@aendk
Contributor Author

aendk commented Dec 19, 2025

Screenshot 2025-12-19 at 17 05 28

Because of the implicit sync in the CUDA case, we now only need a single synchronization in the loop; it is required to retrieve the logits. This is also a small improvement over the first approach, in addition to adhering to scheduling correctness.

Performance on RTX PRO A6000 Blackwell Max-Q. Baseline (without this PR):
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        356.77 ± 0.89 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        356.83 ± 0.42 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        352.40 ± 0.55 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        426.12 ± 0.37 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        425.38 ± 0.11 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        421.85 ± 0.74 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        205.11 ± 0.18 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        205.44 ± 0.16 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        202.59 ± 0.28 |

build: 909072abc (7176)

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        361.16 ± 0.93 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        362.04 ± 0.34 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        358.05 ± 0.55 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        431.59 ± 0.24 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        430.03 ± 0.15 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        426.97 ± 1.02 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        206.94 ± 0.18 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        207.22 ± 0.09 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        204.42 ± 0.20 |
build: 909072abc (7176)

Speedup:

1.0123
1.0146
1.01603
1.01284
1.01093
1.01214
1.00892
1.00866
1.00903

@aendk aendk marked this pull request as draft December 20, 2025 13:09
@aendk
Contributor Author

aendk commented Dec 20, 2025

Back to draft for some double checking on scenarios like Multi-CPU or non-CUDA to CUDA with UMA.

@aendk
Contributor Author

aendk commented Dec 22, 2025

Hi,
I needed to double-check the change to allow host to device async copy operations in ggml-cuda.cu.
In ggml-cuda.cu, checking that the source buffer backend is indeed the CPU backend is crucial to avoid data races.
The CPU backend is synchronous and is instantiated only once (ggml_backend_cpu_reg_get_device_count()). It is therefore implicitly synced with the orchestration loop calling this cpy_tensor_async() function. This ensures that computation of the previous CPU split is done, and that the copy to the next CUDA-backend can take place asynchronously without the possibility of data races.
Next to this case, there are or there could be other backends which work on host buffers where this might not apply.
A check like ggml_backend_is_cpu(backend_src) is therefore essential in ggml-cuda.cu.
I removed it in the last commit to fix linking errors, but now I need suggestions on how best to reimplement this functionality.
Should I:
(1) include ggml_backend_is_cpu() and therefore ggml-cpu.h into ggml-cuda.cu again, and change the linking structure to make it work? Currently, every backend is independent, even though the ggml-cpu backend is a strict requirement in ggml_backend_sched_new().
(2) check the backend via ggml_backend_dev_type == GGML_BACKEND_DEVICE_TYPE_CPU (Is that precise enough?)
(3) duplicate the code of ggml_backend_is_cpu() in ggml-cuda.cu (guid check)

Or is there something more elegant I've overlooked?

@aendk aendk force-pushed the akieslinger/reduce-per-token-syncs branch 2 times, most recently from 402502f to 459cb40 on January 13, 2026 13:58
@aendk
Contributor Author

aendk commented Jan 13, 2026

Hi @ggerganov, @JohannesGaessler,

apart from the spurious and likely unrelated CI failures, this PR is ready for review.
It is now more general, and it improves Windows performance much more than the above measurements on Linux indicated.
In my testing, I see speedups between 3% and 9% on Windows (see below). Most importantly, there is a solid 8-9% gain for gpt-oss 20B.
For Linux, the earlier measurements with a 1-2% speedup still apply. The likely reason is that the CUDA API is slower on Windows.

For easier reviewing, please take a look at the graphic to understand my changes:
Screenshot 2026-01-08 at 13 37 23

This graphic shows the behavior of ggml_backend_compute_splits() for gpt-oss 20B.
Splits are executed one after another. Before a split is executed, the necessary data is copied to its backend.
This can happen in two ways: from the host orchestration loop which is in ggml_backend_compute_splits() (upper half), or from the previous backend (lower half).

Lower half (Backend to Backend)

For the lower half, each backend can implement asynchronous copies to itself. The fallback is fully synchronous. Here, I enabled async copies from the CPU backend to the CUDA backend, in addition to the existing CUDA->CUDA behavior.
This is possible as there is only a single, synchronous CPU backend in ggml. This backend is therefore implicitly synchronized with the host orchestration loop ([0]).
When this loop triggers the copy to the CUDA backend, it is guaranteed that the CPU backend has finished.

Upper half (Host to Backend)

For the upper half, we have several small (position ids, masks) parallel copies from the host code to the backend. My changes in ggml-backend.cpp modify these.
Currently, these async copies are always followed by costly synchronizations. Now, depending on the backend, there can also be 1) no synchronizations at all in the lead up to the backend execution, or 2) only a single synchronization between the last copy and the first read (graph execution).
The first case applies to the CUDA backend, where copies and kernel launches are implicitly synchronized with the CUDA stream. The CANN backend could also benefit from this.
The second option could be useful for backends like the Vulkan backend, which requires pipeline barriers between copies and shaders.
For now, I only enrolled the CUDA backend to leverage these changes. Other maintainers can do the same for their backend.

Regarding my previous comment, I decided to go for option 2) for checking the backend type in ggml-cuda.cu. Let me know if this is strict enough.

Performance on Windows

Details

Master (4150da9a956d0587a3e8a08e69940f9a27f88e5c):

ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        321.30 ± 4.01 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        324.02 ± 1.22 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        319.36 ± 1.66 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |       375.01 ± 54.46 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        396.44 ± 1.71 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        392.60 ± 0.70 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg128 |       291.21 ± 35.51 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg256 |        306.63 ± 1.00 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg512 |        301.51 ± 0.95 |

build: 4150da9a9 (7713)

PR:

ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        350.81 ± 1.64 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        350.33 ± 1.85 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        346.63 ± 0.57 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        411.76 ± 1.56 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        411.01 ± 2.77 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        405.78 ± 1.35 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg128 |        316.53 ± 2.41 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg256 |        317.05 ± 1.32 |
| qwen3 4B Q4_K - Medium         |   2.44 GiB |     4.02 B | CUDA       |  99 |  1 |           tg512 |        311.76 ± 0.90 |

build: 402502f46 (7722)

Speedup on Windows:

1.09185
1.0812
1.08539
1.098
1.03675
1.03357
1.08695
1.03398
1.034

@aendk aendk marked this pull request as ready for review January 13, 2026 16:31
@ggerganov
Member

On first look, this looks like a solid proposal - at least I don't see any major concerns atm.

Can you verify the following use cases are working as expected:

  • Partial GPU offload (i.e. -ngl less than total number of layers)
  • Expert offloading (e.g. -n-cpu-moe less than total number of layers)
  • Multi-GPU with pipeline parallel (e.g. -ub 256 -b 2048)

Regarding my previous comment, I decided to go for option 2) for checking the backend type in ggml-cuda.cu. Let me know if this is strict enough.

I'll try to answer this - atm not 100% sure. I'll also try to adapt this to the Metal backend and see if everything is compatible there as well.


// To not change any APIs or change what ggml-base links to, we can only detect backends by string matching
auto backend_name = ggml_backend_name(current_backend);
if (strncmp(backend_name, "CUDA", 4) == 0) {
Contributor


FWIW, I think Diego would reject any change that string compares the backend name. If we need to query a capability from the backend maybe add it to ggml_backend_dev_caps?

Contributor Author

@aendk aendk Jan 14, 2026


I fully agree that this is not optimal. Working from the feedback above, I aimed to propose no API changes, and this was the only obvious way.
The only reason this is remotely OK for me is that only up to 4 chars are checked per inference pass, that this is likely irrelevant for performance, and that, whilst a bit ugly, it is passable for readability.

If we are open to relaxing the "no API change" rule, I'll gladly rework this.

#define GGML_SCHED_MAX_COPIES 4
#endif

enum ggml_backend_sync_mode {
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My biggest concern is that it is still unclear what the synchronization requirements are for commands like get/set_tensor(_async), and this makes it even more confusing.

You've posed this as "a backend naturally has some amount of implicit synchronization, and this exposes information about that to the split code to take advantage of". But within a single backend there could be multiple amounts of synchronization it could do, how does somebody implementing a backend know how much synchronization they need to do? The backend interface needs to be a contract that clearly defines what synchronization is required of the backend.

@aendk (Contributor Author) Jan 14, 2026

I agree that the current async landscape is a bit fuzzy, but I disagree that my additions make it substantially worse.
In other projects I've worked in before, the general gist was to use synchronized functions unless you knew exactly what to do, and needed to go async to squeeze out more performance. I feel that ggml is a bit similar in that regard.

For this reason, my change works on an opt-in basis. If you need more performance, you can relax the synchronization requirements for your backend. If not, keep it simple and retain default synchronous execution.

In the future, there could be new backends which might require a different type of synchronization. I believe that figuring out what level of synchronization relaxation is best for their backend, like any other performance tuning or optimization, is part of the responsibility of these developers.

If GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY (only a single sync between writes and reads) is not suitable for the vulkan backend, let me know. I would be happy to learn more about this, and to make my proposal more applicable for vulkan and similar backends.

Member

My understanding of the ggml_backend_i contract atm is that using backend.get_tensor_async and backend.set_tensor_async without synchronization with respect to backend.graph_compute is not safe. I.e. this sequence of calls is not guaranteed to be thread-safe:

# a - async get/set
# g - compute

aaagaaag...

Although some backends such as Metal and CUDA would actually not have a problem in this case due to the implicit synchronization between the data movement and compute commands (i.e. due to sharing the same stream/queue), it is not a general assumption that can be made for all backends.

Because this is the current contract, the implementation in ggml_backend_sched_compute_splits() implements the safe path, avoiding the problematic async get/set/compute:

# c - sync get/set (i.e. data copy)
# g - async compute
# s - explicit sync

sccccsgsccccsg...

Without making any changes to the existing contract, we can relax this logic and avoid some of the extra synchronizations for the copies (as discussed in the other #17795 (comment)):

# a - async get/set
# g - async compute
# s - explicit sync

saaaasgsaaaasg...

This change should be fully compatible with the existing assumptions and any code that uses ggml-backend would not be impacted, since we make no changes to the interface.


Now, the proposal by @aendk attempts to enable the aaaagaaaag... path, which if the above is correct, is currently not safe to use by any ggml application. The bottomline is that in order for this path to work, the backend has to announce in some way whether this sequence of operations is safe depending on the backend implementation. Checking the name is obviously not ideal - likely we can extend the dev_caps to provide this information. Although I think this is rather a property/capability of the backend object itself, rather of the device, so a small concern here.

But for now lets assume that the dev_caps exposes an appropriate flag or enum. What should it be? For example, it can be bool data_and_compute_are_sequential, meaning that if the flag is true, it is guaranteed that the data movement and the compute share a common stream/queue and hence the aaagaaag... path is fine and the user does not need to do explicit synchronizations.

So far, my understanding is that a single flag like the above is all the information that is currently missing regarding this functionality. I am not convinced yet that more information (e.g. an enum as @aendk proposes) is necessary.

@aendk (Contributor Author)

A slight addition regarding my proposed path: between inference passes, there is always a synchronization point to retrieve the logits. When looking at two inference passes, my proposal keeps them separate: aaaagsaaaags.

@ggerganov (Member) Jan 16, 2026

there is always a synchronization point to retrieve the logits

Generally, that synchronization is optional. For example, when you process a large prompt of N tokens with micro-batch size of N/4, there will be 3 graph computes without synchronization:

# prompt processing      text-generation starts
# v                      v
aaaagaaaaagaaaaagaaaaagas
#                      ^^
#                      |\- this is the sync for getting the logits from the final token
#                      \-- this is the async copy DtoH for the logits

@aendk (Contributor Author)

Following up on this, I hacked in syncs before and after the copies by adding a ggml_backend_synchronize(split_backend) call right after split_backend is defined, and ran with GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY.
This leads to the desired saaasg pattern:
Screenshot 2026-01-19 at 15 14 47
A green cudaStreamSynchronize, followed by red copies, another synchronize, and lastly the cudaGraphLaunch.
The performance loss on Windows is not significant, even though measurements are a bit noisy on the RTX PRO 6000 Max-Q (not the non-Max-Q card I used before). On Windows, we lose between ~1% and ~2% compared to the earlier tests:

| Model                 | Test   |   t/s b7772 |   t/s akieslinger/reduce-per-token-syncs |   Speedup |
|:----------------------|:-------|------------:|-----------------------------------------:|----------:|
| gpt-oss 20B MXFP4 MoE | tg128  |      295.02 |                                   317.25 |      1.08 |
| gpt-oss 20B MXFP4 MoE | tg256  |      299.84 |                                   318.51 |      1.06 |
| gpt-oss 20B MXFP4 MoE | tg512  |      294.47 |                                   316.15 |      1.07 |
| llama 3B Q4_K_M       | tg128  |      347.32 |                                   357.16 |      1.03 |
| llama 3B Q4_K_M       | tg256  |      364.00 |                                   373.05 |      1.02 |
| llama 3B Q4_K_M       | tg512  |      360.34 |                                   372.52 |      1.03 |
| qwen3 4B Q4_K_M       | tg128  |      267.54 |                                   275.00 |      1.03 |
| qwen3 4B Q4_K_M       | tg256  |      276.73 |                                   286.24 |      1.03 |
| qwen3 4B Q4_K_M       | tg512  |      271.34 |                                   284.77 |      1.05 |

On Linux, I see the same 1-2% gain as before.
Note that this run still retains the proposed changes in ggml-cuda.cu. These are now synced by the added ggml_backend_synchronize(split_backend) call, since cpy_tensor_async() is called after this first sync, and because the host orchestration loop is implicitly synced with the CPU backend, as already detailed.

To sum up, the performance gain of this simplified approach is 1-2% less than before, but we still gain up to 7-8% on Windows.


aendk commented Jan 14, 2026

I double-checked partial GPU offload and MoE offloading via Nsight Systems and spot-checked coherent output via llama-cli. I did this on a single GPU. In short, I have not seen a scenario that would break with my code changes.

Detailed Analysis for Partial GPU offload

I ran gpt-oss 20B with -ngl 10.
Here are a series of screenshots out of nsight systems which I've annotated with NVTX. The grey row (NVTX on the very left) and the one below it (CUDA API) are the most relevant. I added annotations to the most relevant functions and loop iterations here.
Partial GPU offloading results in the creation of many more splits:
Screenshot 2026-01-14 at 11 12 11
Above, a single inference pass is done, in a single invocation of ggml_backend_sched_compute_splits(). Work is split over 33 splits, alternating between CPU and CUDA backends.
The screenshot below shows a typical transition: from execution on a CPU split (split=12) to a CUDA split (split=13) and back to a CPU split (split=14).
Screenshot 2026-01-14 at 11 13 10
Because of the implicit sync of the CPU backend with the host orchestration loop, the memory copies (red on the CUDA API row) to the CUDA backend are only triggered after the synchronous CPU backend is done. This is part of my code changes. On the transition to CPU, the memory copies are fully synchronous. This is untouched by this PR. The CPU backend does not have cpy_tensor_async() implemented, which leads to the synchronous fallback: the previous backend (CUDA) is synced and the copy is performed in a synchronous manner. Here, if performance-critical, we could remove the first explicit sync for the special case CUDA->CPU in a follow-up. The explicit sync after the copy is, however, required.

Lastly, I saw another interesting transition: CPU (split=3) to CUDA (split=4) to CUDA (split=5) to CPU (split=6).
The previous example already showed 3-4 and 5-6, so the 4-5 CUDA to CUDA part is most interesting here:
Screenshot 2026-01-14 at 11 21 00
No explicit syncs take place (red bubbles not guarded by green synchronization bubbles), as cpy_tensor_async() is used. This behavior, however, is not my change; slaren implemented it in #6017 at the start of 2024.

Ultimately, my code changes only touch the CPU->CUDA transitions; other transitions work the same, and are at times already optimal (CUDA->CUDA).

Partial offloading is consequently working correctly with my changes.

Detailed analysis for MoE offloading

Setting --n-cpu-moe 2 leads to an execution of gpt-oss 20B in 6 splits: input embedding on CPU (0), CUDA execution up to the first expert (1), first expert on CPU (2), CUDA (3), 2nd expert on CPU (4), and the rest on CUDA (5):
Screenshot 2026-01-14 at 16 31 58

This means that like in partial GPU offloading, only CPU to CUDA and CUDA to CPU are relevant.
Here, we see a variation of this: 2 copies from CUDA to CPU:
Screenshot 2026-01-14 at 16 37 05
Because the CPU backend is unchanged and cpy_tensor_async() is not implemented, the fully synchronous fallback is used. My code changes do not affect this.
An avenue for future work here could be to implement cpy_tensor_async() for the CPU backend for the special case of the CUDA backend as the input backend, as all but the very last synchronization (shown in green) are not required.

MoE offloading consequently did not show any new behavior which clashes with my changes.

I'll look into multi-GPU in a subsequent post.

Comment on lines +1462 to +1518
 if (input->flags & GGML_TENSOR_FLAG_INPUT) {
     // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
     if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
         ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
     } else {
-        ggml_backend_synchronize(split_backend);
+        ggml_backend_synchronize_if_required(split, split_backend);
     }
-    ggml_backend_tensor_copy(input, input_cpy);
+    ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
+    ggml_backend_synchronize_if_required(split, split_backend, last_input);
@ggerganov (Member) Jan 15, 2026

If I am following the logic correctly, I think there is a very simple alternative approach that would still eliminate most of the synchronizations, without any changes to the backend API/infra.

So on master, for multiple inputs N to the current split, we do the following:

# method a
# s - synchronize split backend
# c - sync copy (ggml_backend_tensor_copy) (each of this blocks the CPU thread)

sccc...[n times]...ccc

Let's say for starters, we want to remove just the inter-input syncs (keep it simple):

# method b
# s - synchronize split backend
# a - async copy (ggml_backend_cuda_cpy_tensor_async)

saaa...[n times]...aaas

Note that we keep the syncs at the start and at the end in order to keep this functionality the same as on master (currently, I am still not 100% sure in which cases these syncs can be safely avoided and am worried we might be missing some possible data race)

_[actually, this part is irrelevant - events are not needed]_

So to achieve this second method, we can use the existing ggml_backend_tensor_copy_async and in addition, record an event after all the copies:

# s - synchronize backend
# a - async copy
# r - record event
# w - synchronize backend on the event

saaa...[n times]...aaarw

This logic only works if the split backend has async copies and events supported. Otherwise, we fallback to method a.

I think this change would be:

  • safe as we only "merge" sequential input copies
  • not as optimal as the current proposal, but mostly there
  • simpler as it avoids introducing new concepts to the backend logic

Does this make sense?
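Sketched as pseudocode against the existing ggml-backend API (not runnable as-is; the split/input variable names are assumptions), method b would look roughly like:

```c
// s: one sync before the batch of input copies
ggml_backend_synchronize(split_backend);
// a...a: queue all input copies asynchronously on the split backend
for (int j = 0; j < split->n_inputs; j++) {
    ggml_backend_tensor_copy_async(input_backend, split_backend,
                                   split->inputs[j], input_cpys[j]);
}
// r + w: record an event after the copies and wait on it
ggml_backend_event_record(copy_event, split_backend);
ggml_backend_event_synchronize(copy_event);
// g: compute can now run with all inputs in place
ggml_backend_graph_compute_async(split_backend, &split_graph);
```

If the split backend does not support async copies and events, this would fall back to method a.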

@aendk (Contributor Author)

I would argue that your proposal would actually be a bit less safe, more complex to verify and merge, and also less performant.
I agree that your proposal is a bit simpler, but I think 2 local helper functions + 1 enum is ok for the impact it enables.

Default behavior

  • Not changing the default behavior is an explicit feature of my proposal. No changed behavior for unenrolled backends means less test overhead and overall lower risk for this PR.
  • I only add two cases next to the default behavior with explicit syncs before and after each copy:
  • GGML_SPLIT_SYNC_MODE_IMPLICIT: aaaaaaaaag (where g= graph launch). This is the case CUDA supports.
  • GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY : aaaaaaaaasg. This is the case vulkan likely supports. If vulkan does indeed support it and if it benefits from needing less syncs, we can enroll it using ~3 LoC.
  • If you think that the pattern saaaaaaaaasg could be useful, I can add this case, too.
    We could name it something like (GGML_SPLIT_SYNC_MODE_COPY_BOUNDARY).
  • Generally, having these 3-4 modes allows us to evaluate relaxing synchronizations on a case-by-case basis, and enroll compatible backends over time. This splits the risk and the verification overhead.
I think you proposed setting the "new default" for all backends to this case, which looks more complex to test and validate across all backends, at least to me.

Perf

Regarding performance, on gpt-oss we gain 8-9% by avoiding 20 explicit syncs per inference pass in total. Each original sync here costs ~0.4%-0.45%.
Dropping only the syncs between copies from the host orchestration loop to the backend would leave 5 (3 for CPU split to CUDA split + 1 start + 1 end of HtoD copies) syncs open, so 2-2.5%.
That is roughly a quarter of the current performance gains.


aendk commented Jan 15, 2026

Pipeline Parallelism Analysis

The short answer is that pipeline parallelism uses the event-based scheduling mechanism.
It is a different code path for scheduling than the one I modify here. My changes do not affect it, except for the already discussed CPU-CUDA async copy. Instead of cudaStreamSynchronize calls, I see cudaStreamWaitEvent, cudaEventSynchronize and cudaEventCreate calls.
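For reference, the event-based ordering visible in the trace corresponds to this CUDA runtime pattern (a rough sketch, not llama.cpp's actual code; stream and buffer names are placeholders):

```c
// GPU1's stream waits on GPU0's queued work without blocking the host thread.
cudaEvent_t ev;
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
// ... GPU0 work queued on stream0 ...
cudaEventRecord(ev, stream0);          // mark the end of GPU0's queued work
cudaStreamWaitEvent(stream1, ev, 0);   // stream1 waits on-device until ev fires
cudaMemcpyPeerAsync(dst, 1, src, 0, nbytes, stream1);  // then the peer copy runs
```

Unlike cudaStreamSynchronize, none of these calls stall the CPU, which is why the orchestration overhead looks different here.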

Below, I briefly analyzed the differences I found compared to the single GPU setting. These findings could be a basis for future optimizations.

Detailed Analysis

With gpt-oss 20B, computation on two GPUs (2x Ada A6000s) is conducted on 3 splits: 1) CPU split for input embedding, 2) CUDA execution on GPU0, 3) CUDA execution on GPU1. Consequently, the interesting parts are CPU->CUDA0 and CUDA0->CUDA1.

CPU->CUDA0

Screenshot 2026-01-15 at 15 34 02

Here, we see a single async copy from the CPU backend to the CUDA backend, followed by 8 host to device copies from the host orchestration loop. The behavior of the CPU backend to CUDA backend copy differs compared to previous analyses:
The new CPU->CUDA async copy is leveraged, but before it, a cudaStreamWaitEvent is called.
For the H2D copies, we can see that before each copy, a cudaEventSynchronize is called. On first look, a few of these look superfluous.

CUDA0->CUDA1

In #6017, slaren implemented async memory copies for this special case from CUDA to CUDA. The screenshot below shows this; using cudaMemcpyPeerAsync (instead of cudaMemcpyAsync as before) tensors are copied from a single device to the next.
Screenshot 2026-01-15 at 16 40 25

There is also some synchronization between the async copies. On the CPU side, these no longer matter, as a GPU works in parallel (blue row at the top), but it might be useful to analyze this further, as the actual switch from GPU0 to GPU1 looks a bit slow at around ~755us:

Screenshot 2026-01-15 at 16 46 10

Here, GPU execution is also indicated at the top of the image in blue. First, the graph is executed on GPU0, then computation is continued 1 row below on GPU1. Between this, both GPUs idle and all synchronization and copies take place.
The duration of this idle period is shown at the top of the light green vertical bar at 755,470us.
The host code which queued all these commands can be seen on the bottom left.

@aendk aendk force-pushed the akieslinger/reduce-per-token-syncs branch from d941624 to 5034414 Compare January 20, 2026 08:52

aendk commented Jan 20, 2026

Another round of testing, comparing my second-to-last commit (43f6684) with the simplified version in the last commit of this branch:
Linux:

| Model                 | Test   |   t/s 43f6684f3 |   t/s akieslinger/reduce-per-token-syncs |   Speedup |
|:----------------------|:-------|----------------:|-----------------------------------------:|----------:|
| gpt-oss 20B MXFP4 MoE | tg128  |          362.01 |                                   360.66 |      1.00 |
| gpt-oss 20B MXFP4 MoE | tg256  |          362.59 |                                   361.48 |      1.00 |
| gpt-oss 20B MXFP4 MoE | tg512  |          358.09 |                                   356.98 |      1.00 |
| llama 3B Q4_K_M       | tg128  |          429.00 |                                   428.17 |      1.00 |
| llama 3B Q4_K_M       | tg256  |          428.03 |                                   427.25 |      1.00 |
| llama 3B Q4_K_M       | tg512  |          425.29 |                                   424.03 |      1.00 |
| qwen3 4B Q4_K_M       | tg128  |          319.82 |                                   319.37 |      1.00 |
| qwen3 4B Q4_K_M       | tg256  |          319.85 |                                   319.28 |      1.00 |
| qwen3 4B Q4_K_M       | tg512  |          314.76 |                                   314.22 |      1.00 |

Windows:

| Model                 | Test   |   t/s 43f6684f3 |   t/s akieslinger/reduce-per-token-syncs |   Speedup |
|:----------------------|:-------|----------------:|-----------------------------------------:|----------:|
| gpt-oss 20B MXFP4 MoE | tg128  |          303.01 |                                   296.67 |      0.98 |
| gpt-oss 20B MXFP4 MoE | tg256  |          299.85 |                                   298.88 |      1.00 |
| gpt-oss 20B MXFP4 MoE | tg512  |          298.36 |                                   296.23 |      0.99 |
| llama 3B Q4_K_M       | tg128  |          336.01 |                                   350.15 |      1.04 |
| llama 3B Q4_K_M       | tg256  |          352.80 |                                   358.26 |      1.02 |
| llama 3B Q4_K_M       | tg512  |          349.64 |                                   352.17 |      1.01 |
| qwen3 4B Q4_K_M       | tg128  |          268.47 |                                   264.41 |      0.98 |
| qwen3 4B Q4_K_M       | tg256  |          270.69 |                                   271.52 |      1.00 |
| qwen3 4B Q4_K_M       | tg512  |          264.82 |                                   267.58 |      1.01 | 

Both runs are on the RTX PRO 6000 Max-Q. As with the hacky benchmark presented in my previous comment, this PR still yields up to 7-8% improvement on Windows.

@ggerganov ggerganov (Member) left a comment

I think the current patch is good. However, note that I myself am still grasping the logic of the scheduler and it's possible that we are missing something, so I prefer to merge this when we are more confident and potentially get more feedback from other maintainers.

I am working on some updates for the Metal backend (#18919, #18966) that will allow me to exercise and understand this logic in my development environment. So unless we get strong feedback that this change works as expected and does not have side effects, I prefer to take the time to fully understand the backend scheduler logic before merging.

@JohannesGaessler (Contributor)

The changes as they are right now are I think correct for CUDA since the copy is queued on the src backend and the dst backend then waits for that copy to complete. Unfortunately the exact synchronization behavior that a ggml backend is expected to implement is not documented. While working on backend-agnostic tensor parallelism I've also had to consider synchronization of multiple backends and run into similar issues. I think what we should do is write down the exact requirements for ggml backends, check these with ggml backend maintainers, and then reason about whether the code in this PR is correct given the requirements we wrote down.

@JohannesGaessler (Contributor)

Unfortunately the exact synchronization behavior that a ggml backend is expected to implement is not documented.

Sorry, I was looking at the wrong function. In ggml-backend.h it says this:

// asynchronous copy
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);

So based on these requirements I think the logic in this PR is correct.

@aendk aendk force-pushed the akieslinger/reduce-per-token-syncs branch from 3ecffb2 to 81c9618 Compare February 9, 2026 12:46
@ggerganov (Member)

The current behavior of what is implicitly ordered is more or less undocumented and challenging to reverse-engineer (sounds like you had to do some of this yourself). Adding code paths that rely on backend-specific implicit ordering is IMO a step in the wrong direction.

@jeffbolznv Do you have an opinion on the latest version of this PR? Are your earlier concerns still there?

I've been doing some tests on my end, getting familiar with the scheduler and I think the logic is sound. Thinking of merging this unless there are some remaining concerns. I believe @JohannesGaessler is mostly OK too.

Most likely we can merge after we resolve some of the ongoing CUDA issues (e.g. #19645 (comment), #19816) in order to avoid overlapping too many changes at the same time.

In any case, happy to hear any feedback and opinions.

@JohannesGaessler (Contributor)

From my end I think the changes in this PR are correct and would be fine to merge. If we wanted to be extra safe we could wait until we have end-to-end tests though I suspect that those would still be inconclusive for issues with synchronization.

@jeffbolznv (Contributor)

I'm ok with the current changes.

@ggerganov ggerganov merged commit 2cd20b7 into ggml-org:master Mar 5, 2026
148 of 149 checks passed
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 10, 2026
…ml-org#17795)

* Adds CPU-to-CUDA copy capability to
ggml_backend_cuda_cpy_tensor_async()

* Adds function to relax sync requirements between input copies on
supported backends (CUDA for now)

* Exchanges synchronous copy with async copy function.

* Adds macro guards to allow compilation in non-CUDA builds

* Reworked backend detection in ggml-backend.cpp to avoid linking
conflicts

* Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues

* Minor cleanup

* Makes opt-in to relax use of explicit syncs more general. Backends like
vulkan which require a synchronization between HtoD copies and graph
execution could also adopt this change now.

* Reintroduces stricter check for CPU->CUDA backend async copy via
GGML_DEVICE_TYPE_CPU.

* Corrects initialization of ggml_backend_sync_mode in
ggml_backend_sched_split initialization

* Simplifies synchronizations to adhere to `saaasg` pattern.

* Apply suggestion from @ggerganov (src->buffer to buf_src)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestion from @ggerganov (src->buffer to buf_src) v2

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
ggerganov added a commit that referenced this pull request Mar 12, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (issues specific to Nvidia GPUs)


5 participants