Skip to content

CUDA: add stream-based concurrency#16991

Merged
am17an merged 18 commits intoggml-org:masterfrom
am17an:fused-qkv-stream
Nov 30, 2025
Merged

CUDA: add stream-based concurrency#16991
am17an merged 18 commits intoggml-org:masterfrom
am17an:fused-qkv-stream

Conversation

@am17an
Copy link
Contributor

@am17an am17an commented Nov 4, 2025

Possibly supersede #16813.

This PR adds support to run concurrent CUDA streams on single GPU setups.
At the moment this only targets the Q, K, V branch. I feel this is the "correct" approach in case the Q, K, V tensors are of different types/not in the same place in memory. The downside is that this approach doesn't come for free and there's some complexity involved, but I'm not an expert at the ggml graph and I feel it could be simplified.

Currently this is hidden by an env variable flag. To run you can use GGML_CUDA_GRAPH_OPT=1

TG Performance gain is more than the previous PR (1-9% gain depending on the model/GPU), probably because we parallelize MUL_MAT + NORM + ROPE rather than just MUL_MAT. At the moment we leave some performance on the table where we don't fuse operations in the parallel streams themselves (e.g. MUL_MAT + BIAS, RMS_NORM + MUL etc.), I couldn't find a simple enough way to enable fusion there.

Performance details:

Details

5090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 10554.76 10572.61 1.00
gpt-oss 20B MXFP4 MoE tg128 359.67 382.52 1.06
llama 8B Q8_0 pp512 13948.06 13903.74 1.00
llama 8B Q8_0 tg128 166.44 174.08 1.05
qwen3 8B F16 pp512 14120.38 14170.43 1.00
qwen3 8B F16 tg128 97.21 100.39 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7361.89 7367.90 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.84 334.19 1.09

4090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 9551.59 9549.55 1.00
gpt-oss 20B MXFP4 MoE tg128 263.77 271.84 1.03
llama 8B Q8_0 pp512 12111.67 12105.00 1.00
llama 8B Q8_0 tg128 105.57 107.57 1.02
qwen3 8B F16 pp512 10437.09 10414.37 1.00
qwen3 8B F16 tg128 58.99 59.79 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7246.49 7246.52 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.75 268.12 1.08

And just for comparison, this is without fusing ops inside a stream

Details

5090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 10570.11 10548.35 1.00
gpt-oss 20B MXFP4 MoE tg128 359.37 373.34 1.04
llama 8B Q8_0 pp512 13905.82 13913.88 1.00
llama 8B Q8_0 tg128 166.30 174.15 1.05
qwen3 8B F16 pp512 14108.08 14116.40 1.00
qwen3 8B F16 tg128 97.20 99.69 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7358.84 7354.03 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.36 325.89 1.06

4090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 9547.67 9545.65 1.00
gpt-oss 20B MXFP4 MoE tg128 263.74 270.22 1.02
llama 8B Q8_0 pp512 12071.11 12041.15 1.00
llama 8B Q8_0 tg128 105.57 107.61 1.02
qwen3 8B F16 pp512 10409.54 10385.86 1.00
qwen3 8B F16 tg128 59.00 59.64 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7254.97 7240.68 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.53 263.10 1.06

TODO:

  • Enable fusion within a stream
  • Add tests?

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Nov 4, 2025
@JohannesGaessler
Copy link
Contributor

Sorry, I wanted to tell you this but I forgot: a long time ago I tried something similar, see #4719 . There the performance did not improve, I think the reason was the lack of CUDA graphs to reduce the overhead.

@am17an
Copy link
Contributor Author

am17an commented Nov 4, 2025

Yeah, I think CUDA graphs are essential for this to work (hence this PR only looks at batch_size=1)

@IMbackK
Copy link
Collaborator

IMbackK commented Nov 6, 2025

Minimal changes to make this work on hip:

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d3153f430..28bafb84e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -25,6 +25,7 @@
 #include <cfloat>
 #include <cstdio>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #if defined(GGML_USE_HIP)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f69b99b2a..f0df4a9a9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3514,7 +3514,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
                         for (int i = 1; i <= concurrent_event->n_streams; ++i) {
                             cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
-                            CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+                            cudaStreamWaitEvent(stream, concurrent_event->fork_event);
                         }
 
                         is_concurrent_event_active = true;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 890c10364..b7d6edf7f 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -105,7 +105,7 @@
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaStreamWaitEvent hipStreamWaitEvent
 #define cudaGraphExec_t hipGraphExec_t
 #define cudaGraphNode_t hipGraphNode_t
 #define cudaKernelNodeParams hipKernelNodeParams

If used for real, cudaStreamWaitEvent error needs to handled of course
hipEventCreateWithFlags is also nodiscard which needs to be handled

with -DGGML_HIP_GRAPHS=On and GGML_CUDA_GRAPH_OPT=1 in env this seams performance neutral on mi100:

Model Test t/s b6955 t/s pr Speedup
llama 13B Q8_0 tg128 29.07 29.06 1.00

@am17an
Copy link
Contributor Author

am17an commented Nov 6, 2025

The almost exact same numbers make me think that this change is not launching the streams. I would expect a shift in performance either for the worse or the better.

@IMbackK
Copy link
Collaborator

IMbackK commented Nov 6, 2025

yeah ill run a trace on it later.

@am17an am17an marked this pull request as ready for review November 8, 2025 06:03
@am17an am17an requested a review from slaren as a code owner November 8, 2025 06:03
@am17an
Copy link
Contributor Author

am17an commented Nov 9, 2025

It turns out there was a bug in finding the correct node to fork on because of RMS fusion, so what was happening was the Q mul-mat was still running on stream 0, and only after it would finish would we launch the rest of the streams.

Secondly, this is still without fusing nodes within a stream, I could make fusion work by re-ordering back the graph to it's original shape, I think the best solution here would be to signal to the ggml-graph that these nodes are required later and their memory should not be re-used, is there a way to do this?

Also as long as the Q, K, V stuff per layer lives on the same device, this should also work for the multi-GPU case too, but that is not addressed in this PR

On a 5090:

Details

Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes

model size params backend ngl fa test t/s
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg32 285.58 ± 2.98
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg64 266.76 ± 0.16
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg128 264.19 ± 0.09
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg32 382.39 ± 0.16
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg64 370.78 ± 0.11
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg128 367.91 ± 0.18
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg32 85.41 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg64 83.95 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg128 83.60 ± 0.05

After:

model size params backend ngl fa test t/s
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg32 304.13 ± 0.36
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg64 281.68 ± 0.06
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg128 276.83 ± 0.28
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg32 391.99 ± 0.45
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg64 380.56 ± 0.14
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg128 375.88 ± 0.05
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg32 86.32 ± 0.03
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg64 84.86 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg128 84.44 ± 0.01
Verified with nsys profile image

@am17an am17an force-pushed the fused-qkv-stream branch 3 times, most recently from 12d5f82 to d3a8d93 Compare November 10, 2025 05:17
@am17an
Copy link
Contributor Author

am17an commented Nov 10, 2025

@ggerganov would you mind testing this on your DGX spark? I want to see if the low memory bandwidth GPUs benefit from this change

@jeffbolznv
Copy link
Contributor

Secondly, this is still without fusing nodes within a stream, I could make fusion work by re-ordering back the graph to it's original shape, I think the best solution here would be to signal to the ggml-graph that these nodes are required later and their memory should not be re-used, is there a way to do this?

I'm not really clear on what the problem is here you're trying to solve. If the order is: MUL_MAT+ADD+MUL_MAT+ADD+MUL_MAT+ADD, then you have the nodes conveniently consecutive (for fusion), the intermediate MUL_MAT outputs aren't needed and the ADDs will all have different outputs. This is the order ggml-vulkan will use and it gets both fusion and concurrency.

@am17an
Copy link
Contributor Author

am17an commented Nov 10, 2025

My problem is that the buffer gets reused in this case. The graph assumes serial execution, and thinks the first mul-mats buffer is no longer required. (Assume the no fusion case for now)

@jeffbolznv
Copy link
Contributor

If you're not doing fusion, then you'd want graph_optimize to reorder these to MUL_MAT+MUL_MAT+MUL_MAT+ADD+ADD+ADD. Then the MUL_MAT results will stay live until the ADDs.

@am17an
Copy link
Contributor Author

am17an commented Nov 10, 2025

Yeah that's what the current re-order does. But that doesn't allow for fusion. I don't want these two things to be intertwined. Ideally want something that just lets me extend the lifetime for a particular output till a certain node

@jeffbolznv
Copy link
Contributor

IMO they are fundamentally intertwined. The code that detects fusions looks for specific sequences of operations, and graph_optimize should generate or preserve those sequences. If the backend supports fusing something, then graph_optimize should make them consecutive both to make the fusion logic simpler and to shorten the lifetime of transient allocations.

@am17an
Copy link
Contributor Author

am17an commented Nov 10, 2025

I think it will be good to isolate these two behaviours. If you see the graph above it can launch a concurrent graph from the mul-mat till set rows. We don't have fusion for that entire sequence, and reasoning about which output stays alive would involve inspecting the graph in any case.

Secondly fusion is a common source of bugs in the cuda backend, I don't want to add another layer of complexity on top of it.

@ORippler

This comment was marked as outdated.

@am17an
Copy link
Contributor Author

am17an commented Nov 11, 2025

Did you pass the env flag?

@ORippler
Copy link
Collaborator

ORippler commented Nov 11, 2025

Did you pass the env flag?

Why of course not 🫠

image

Updated numbers

+ ./scripts/compare-llama-bench.py -b b8595b16e69e3029e06be3b8f6635f9812b2bc3f -c d3a8d93a04d4cb8cc8daa62d12e6763a2f4c0221 --tool llama-bench -i llama-bench.sqlite
Model Test t/s b8595b1 t/s d3a8d93 Speedup
gpt-oss 120B MXFP4 MoE tg32 43.42 43.86 1.01
gpt-oss 120B MXFP4 MoE tg64 49.96 50.20 1.00
gpt-oss 120B MXFP4 MoE tg128 50.95 51.13 1.00
gpt-oss 120B MXFP4 MoE tg256 50.86 51.10 1.00
gpt-oss 120B MXFP4 MoE tg512 50.44 50.86 1.01
gpt-oss 20B MXFP4 MoE tg32 80.20 81.24 1.01
gpt-oss 20B MXFP4 MoE tg64 80.81 81.45 1.01
gpt-oss 20B MXFP4 MoE tg128 80.97 81.52 1.01
gpt-oss 20B MXFP4 MoE tg256 80.98 81.55 1.01
gpt-oss 20B MXFP4 MoE tg512 80.57 81.18 1.01
llama 3B Q4_0 tg32 103.85 105.21 1.01
llama 3B Q4_0 tg64 104.79 106.39 1.02
llama 3B Q4_0 tg128 104.77 106.43 1.02
llama 3B Q4_0 tg256 104.79 106.50 1.02
llama 3B Q4_0 tg512 103.83 105.49 1.02
qwen3moe 30B.A3B Q3_K_S tg32 94.83 96.21 1.01
qwen3moe 30B.A3B Q3_K_S tg64 95.72 96.97 1.01
qwen3moe 30B.A3B Q3_K_S tg128 96.05 97.27 1.01
qwen3moe 30B.A3B Q3_K_S tg256 96.12 97.40 1.01
qwen3moe 30B.A3B Q3_K_S tg512 95.19 96.42 1.01

@ORippler
Copy link
Collaborator

My problem is that the buffer gets reused in this case.

If my current understanding of ggml is correct, we should be able to get the same behavior (fusion + concurrency) on both vulkan + cuda, as Q should go out of life after flash-attention (and K + V go out of life after being inserted into the KV-cache). Have we root-caused this?

@CISC CISC removed the request for review from lhez November 27, 2025 15:02
@am17an
Copy link
Contributor Author

am17an commented Nov 27, 2025

What just happened?

My git just tweaked in the middle of auto-compaction

@ggerganov
Copy link
Member

ggerganov commented Nov 27, 2025

Consider the following implementation:

  • When creating a ggml graph, optionally allocate an array of integers like node_stream_ids.
  • When adding new nodes to the graph, check node_stream_ids and create forks for concurrent streams if possible by setting node_stream_ids for the new node.
  • When allocating the graph, use node_stream_ids to avoid recycling nodes from other streams that are not joined.
  • In the backends, use node_stream_ids as the basis for constructing streams that can be safely executed in parallel.

@ggerganov @am17an would a design like that be acceptable to you as the long-term goal for how to handle out-of-order execution in ggml?

Not sure if you have a specific logic in mind for assigning the node_stream_ids, but it would be trivial to assign those ids during ggml_visit_parents - for every next iteration of the GGML_MAX_SRC loop, simply increase the current stream id before calling the ggml_visit_parents recursively.

Atm, I don't have a good feeling how difficult would be to modify the allocator to respect the stream ids.

@jeffbolznv
Copy link
Contributor

IMO ggml_visit_parents should just assign everything to stream 0 and then graph_optimize can reassign to streams based on what the backend wants.

@JohannesGaessler
Copy link
Contributor

Atm, I don't have a good feeling how difficult would be to modify the allocator to respect the stream ids.

We need to ensure that nodes are only recycled after they've been used for the last time and after their results have "become visible", meaning that we need to ensure that we don't recycle nodes where the execution can be in the future. I would suggest allocating an array like int becomes_visible_at[n_nodes][n_streams][n_streams] that describes for a given node index when a result from some first stream will become visible for some other stream.

  • First initialize the entry for the last node with n_nodes because we trivially know (for all node positions) that the results from all streams will be visible after the graph has finished executing.
  • For the last node, get its stream id. Set mapping from its own stream id as well as all of its sources to n_nodes - 1. We know that even under out-of-order execution those results must be visible by now due to ordering within the stream as well as data dependencies.
  • Iterate over the nodes back to front. Copy data from one index higher as an upper bound for when nodes become visible, repeat the same steps as for the last node.
  • During the allocation, recycle a node only once the maximum node index of the pre-existing logic for last use and becomes_visible_at has been passed.

IMO ggml_visit_parents should just assign everything to stream 0 and then graph_optimize can reassign to streams based on what the backend wants.

I would also be fine with determining stream ids in the graph optimization step.

@JohannesGaessler
Copy link
Contributor

When I tested performance:

GPU Model Microbatch size Test t/s b7152 t/s c57b750 Speedup
RTX 3090 llama 1B Q4_0 512 tg128 588.26 622.51 1.06
RTX 3090 llama 1B Q4_0 512 tg128@d32768 251.12 256.14 1.02
RTX 3090 llama 8B Q4_0 512 tg128 153.14 156.16 1.02
RTX 3090 llama 8B Q4_0 512 tg128@d32768 80.02 80.87 1.01
RTX 3090 Storied 15m f32 512 tg128 21.45 14.77 0.69
RTX 3090 Storied 15m f32 512 tg128@d32768 8.92 5.77 0.65
RTX 4090 llama 1B Q4_0 512 tg128 741.87 779.35 1.05
RTX 4090 llama 1B Q4_0 512 tg128@d32768 296.65 302.51 1.02
RTX 4090 llama 8B Q4_0 512 tg128 177.80 182.63 1.03
RTX 4090 llama 8B Q4_0 512 tg128@d32768 88.84 90.03 1.01
RTX 4090 Storied 15m f32 512 tg128 1150.35 1144.23 0.99
RTX 4090 Storied 15m f32 512 tg128@d32768 3.23 3.93 1.22
RTX 5090 llama 1B Q4_0 512 tg128 794.24 865.38 1.09
RTX 5090 llama 1B Q4_0 512 tg128@d32768 322.90 339.30 1.05
RTX 5090 llama 8B Q4_0 512 tg128 217.98 228.93 1.05
RTX 5090 llama 8B Q4_0 512 tg128@d32768 112.55 115.34 1.02
RTX 5090 Storied 15m f32 512 tg128 732.16 494.51 0.68
RTX 5090 Storied 15m f32 512 tg128@d32768 5.23 4.58 0.88

As expected, for typical model sizes there is more speedup for fast GPUs running small models on an empty context. For very small models this trend does not seem to monotonically continue, however.

@am17an am17an merged commit c7af376 into ggml-org:master Nov 30, 2025
71 of 74 checks passed
@am17an am17an deleted the fused-qkv-stream branch November 30, 2025 00:18
@ggerganov
Copy link
Member

For my understanding, why do we gate the functionality behind GGML_CUDA_GRAPH_OPT - is the plan after enough validation to make it the default? Or are there going to be use cases that are incompatible with it and it would always remain opt-in?

@ORippler
Copy link
Collaborator

ORippler commented Dec 1, 2025

For my understanding, why do we gate the functionality behind GGML_CUDA_GRAPH_OPT - is the plan after enough validation to make it the default? Or are there going to be use cases that are incompatible with it and it would always remain opt-in?

My understanding is the former. We currently still have issues with BS > 1 for example, where it's disabled #16991 (comment).

@am17an
Copy link
Contributor Author

am17an commented Dec 1, 2025

Yes this is the initial experimental support, wherever we use CUDA graphs it should be possible to default to this being true (without CUDA graphs there is a performance drop). So at the moment it can be true for BS=1 universally for fully offloaded models, however I would like to give it sometime to make this flag something like -fa auto

@ggerganov
Copy link
Member

ggerganov commented Dec 1, 2025

I would like to give it sometime to make this flag something like -fa auto

I don't think it should become a CLI argument. Ideally, it should be an env variable that is enabled by default and can optionally be disabled (mainly for debugging purposes). Similar to how we do it with GGML_METAL_GRAPH_OPTIMIZE_DISABLE and GGML_VK_DISABLE_GRAPH_OPTIMIZE.

am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 30, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
am17an added a commit to am17an/llama.cpp that referenced this pull request Dec 31, 2025
This PR enables concurrent streams introduced in ggml-org#16991 by default. To disable a new env flag `GGML_CUDA_DISABLE_GRAPH_OPT` is introduced
Anico2 added a commit to Anico2/llama.cpp that referenced this pull request Jan 15, 2026
* CUDA: add stream-based concurrency

* HIP: fix hipStreamWaitEvent define and nodiscard warnings

* ggml-cuda: fix fusion inside stream

* ggml-cuda: fix bug w.r.t first stream launch

* ggml-cuda: format

* ggml-cuda: improve assert message

* ggml-cuda: use lambda instead of duplicating code

* ggml-cuda: add some more comments

* ggml-cuda: add more detailed comments about concurrency

* ggml-cuda: rename + remove unused var

* ggml-cuda: fix condition for stream launch

* ggml-cuda: address review comments, add destructor

* common.cuh: add is_valid for concurrent events

* common.cuh: make comment better

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* common.cuh: fix lower_bound condition + remove join_node data from write_ranges

* ggml-cuda: fix overlap condition + shadowing parameter

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@bssrdf
Copy link
Contributor

bssrdf commented Jan 22, 2026

@am17an, thanks for this PR, impressive speed up.

I am wondering if this graph optimization can be adopted by image generation models. Both UNET and Dit have attention (self and cross) so should benefit from the concurrency. Could you provide some guidance? Thanks.

@am17an
Copy link
Contributor Author

am17an commented Jan 22, 2026

@bssrdf happy to help. Perhaps you can email me and we can discuss. My email is on my profile page

blime4 referenced this pull request in blime4/llama.cpp Feb 5, 2026
* CUDA: add stream-based concurrency

* HIP: fix hipStreamWaitEvent define and nodiscard warnings

* ggml-cuda: fix fusion inside stream

* ggml-cuda: fix bug w.r.t first stream launch

* ggml-cuda: format

* ggml-cuda: improve assert message

* ggml-cuda: use lambda instead of duplicating code

* ggml-cuda: add some more comments

* ggml-cuda: add more detailed comments about concurrency

* ggml-cuda: rename + remove unused var

* ggml-cuda: fix condition for stream launch

* ggml-cuda: address review comments, add destructor

* common.cuh: add is_valid for concurrent events

* common.cuh: make comment better

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* common.cuh: fix lower_bound condition + remove join_node data from write_ranges

* ggml-cuda: fix overlap condition + shadowing parameter

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants