
Overlap CUDA graph building and processing to minimize GPU idle time and improve tokens-per-second performance #11867

Open · wants to merge 4 commits into master
Conversation

@aendk (Contributor) commented Feb 14, 2025

Hi all,

This PR takes the ideas applied to the Vulkan backend (#9118 and #10499) and implements them for CUDA, resulting in improved tokens-per-second performance.

Performance

I tested on two systems using an example query in llama-cli and the phi3-mini-4k-instruct model.
Prompt eval tokens per second improved by between 2.5% and 7%.
Context print tokens per second improved by between 2.8% and 3.57%.
Note that this is a PR to reduce CPU overhead, and that these numbers were generated using top-end CPUs.
On less powerful consumer CPUs, the performance increase will be more significant.
[Image: perf_results]

Explanation

Currently, before every forward pass, a CUDA graph is built on the CPU and then executed on the GPU. This introduces a delay: the GPU has to wait for the CPU to finish building the CUDA graph.
Our proposed change splits the CPU workload into smaller pieces, so that after the first graph has been built, the CPU and GPU can work in parallel on different CUDA graphs. A minimal sketch of this pipelining idea follows the screenshot below.
The before/after is shown in the images below from Nsight Systems. Top is master, bottom is this changeset.
We measure the time between the start of the forward pass (red/green timeline of the CUDA API) and GPU graph execution (orange), and highlighted the time taken (256 µs vs. 56 µs) with a red circle. This seems small, but since it is incurred before each forward pass / token-generation step, it adds up quickly.
Note that the two screenshots use different time scales, so the width of the items themselves is misleading. Only the measured time and the pattern of the red/green CUDA API operations are relevant.

[Screenshot: Nsight Systems timelines, master (top) vs. this PR (bottom)]
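
To make the pipelining concrete, here is a minimal, self-contained sketch of the idea. This is not the PR's code: the kernel, the number of chunks, and the per-chunk sizes are placeholders, and error checking is omitted.

// Sketch of the pipelining idea: the work is split into several CUDA graphs so
// that, once the first graph is ready, the CPU can capture and instantiate
// graph i+1 while the GPU is still executing graph i.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void dummy_kernel(float * data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] = data[i] * 1.0001f + 1.0f;
}

int main() {
    const int n_chunks = 4;   // number of sub-graphs the work is split into (placeholder)
    const int n = 1 << 20;
    float * d;
    cudaMalloc(&d, n * sizeof(float));

    cudaStream_t capture_stream, exec_stream;
    cudaStreamCreate(&capture_stream);
    cudaStreamCreate(&exec_stream);

    std::vector<cudaGraphExec_t> execs(n_chunks);

    // capture + instantiate one chunk of work as a CUDA graph (CPU-side work)
    auto build_chunk = [&](int chunk) {
        cudaGraph_t graph;
        cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal);
        for (int k = 0; k < 8; ++k) { // a few kernels per chunk, standing in for graph nodes
            dummy_kernel<<<(n + 255) / 256, 256, 0, capture_stream>>>(d, n);
        }
        cudaStreamEndCapture(capture_stream, &graph);
        cudaGraphInstantiateWithFlags(&execs[chunk], graph, 0); // CUDA 11.4+
        cudaGraphDestroy(graph);
    };

    // pipeline: build chunk 0 up front, then overlap building chunk i+1 on the
    // CPU with executing chunk i on the GPU
    build_chunk(0);
    for (int i = 0; i < n_chunks; ++i) {
        cudaGraphLaunch(execs[i], exec_stream); // GPU runs chunk i asynchronously
        if (i + 1 < n_chunks) {
            build_chunk(i + 1);                 // CPU builds chunk i+1 meanwhile
        }
    }
    cudaStreamSynchronize(exec_stream);

    for (cudaGraphExec_t e : execs) cudaGraphExecDestroy(e);
    cudaFree(d);
    cudaStreamDestroy(capture_stream);
    cudaStreamDestroy(exec_stream);
    printf("done\n");
    return 0;
}

After build_chunk(0) returns, every subsequent capture/instantiate on the CPU overlaps with the GPU executing the previously launched graph on exec_stream.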

Performance impact of switching between graphs during forward passes

My code mirrors the changes in Vulkan. In our testing, each forward pass uses dozens of graphs. One could argue that the last few switches between graphs are probably not required and hinder performance.

We investigated this. Switching between graphs is a non-issue for now, at about 2 µs per switch. However, we could discuss strategies to steadily increase the graph size to reduce the number of switches.
[Screenshot: Nsight Systems timeline showing the overhead of switching between graphs]

@mtavenrath @agray3

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Feb 14, 2025
@IMbackK (Collaborator) commented Feb 14, 2025

This also provides a small but measurable speed-up on ROCm / CDNA:

Master:

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |         tg128 |         85.00 ± 0.21 |

PR:

ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |         tg128 |         88.04 ± 0.16 |

This makes hipGraphs performance-positive for the first time.

Unfortunately, it also leads to random crashes in hipGraphDestroy, which on first examination seem to be ROCR's fault and not this PR's.

@slaren (Member) commented Feb 17, 2025

My understanding of the CUDA graph implementation is that a new graph is only created when there are incompatible changes to the graph; otherwise, only a small number of graph nodes are updated to reflect the new positions of the KV cache. Due to the KV padding, this should only happen every 32 or 256 tokens, with FA disabled and enabled respectively. Thus, in the worst case, this optimization could save 224 µs every 32 tokens, an average of 7 µs per token. Is this what you are observing?

@agray3 (Contributor) commented Feb 17, 2025

My understanding of the CUDA graph implementation is that a new graph is only created when there are incompatible changes to the graph; otherwise, only a small number of graph nodes are updated to reflect the new positions of the KV cache. Due to the KV padding, this should only happen every 32 or 256 tokens, with FA disabled and enabled respectively. Thus, in the worst case, this optimization could save 224 µs every 32 tokens, an average of 7 µs per token. Is this what you are observing?

That's not quite right: the cdst parameter of e.g. the cpy_f32_f16 kernel changes for every token, requiring cudaGraphKernelNodeSetParams and cudaGraphExecUpdate for every token (as well as cudaGraphLaunch). #9017 works around the need for these frequent updates (but was deemed to have maintainability issues).
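
For reference, a small standalone sketch of that update path (not llama.cpp's code: the copy kernel and the buffers are placeholders, error checking is omitted, and cudaGraphExecUpdate is called with its CUDA 12 signature): when the destination pointer of a captured kernel changes, the node's parameters have to be patched and the executable graph updated before every launch.

// Sketch of the per-token update path: the destination pointer of a captured
// copy kernel changes (standing in for cdst in cpy_f32_f16), so the node's
// parameters must be patched and the executable graph updated before launch.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void copy_kernel(const float * src, float * dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = src[i];
}

int main() {
    const int n = 256;
    float * src, * dst_a, * dst_b;
    cudaMalloc(&src,   n * sizeof(float));
    cudaMalloc(&dst_a, n * sizeof(float));
    cudaMalloc(&dst_b, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // capture a one-node graph whose kernel writes to dst_a
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    copy_kernel<<<1, 256, 0, stream>>>(src, dst_a, n);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graph_exec;
    cudaGraphInstantiateWithFlags(&graph_exec, graph, 0); // CUDA 11.4+

    // "next token": the destination moves to dst_b, so patch the kernel node
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes); // expect a single kernel node
    cudaGraphNode_t node;
    cudaGraphGetNodes(graph, &node, &num_nodes);

    cudaKernelNodeParams params;
    cudaGraphKernelNodeGetParams(node, &params);
    void * new_args[] = { &src, &dst_b, (void *) &n }; // dst_a -> dst_b
    params.kernelParams = new_args;
    cudaGraphKernelNodeSetParams(node, &params);

    cudaGraphExecUpdateResultInfo result;                // CUDA 12 signature
    cudaGraphExecUpdate(graph_exec, graph, &result);     // topology-preserving, but not free

    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);
    printf("updated graph launched\n");

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaFree(src); cudaFree(dst_a); cudaFree(dst_b);
    cudaStreamDestroy(stream);
    return 0;
}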

@slaren (Member) commented Feb 17, 2025

Thanks for the clarification. I understand that even in the case where only a few nodes need to be updated, the call to cudaGraphExecUpdate is expensive enough to make it worth splitting the graph into multiple parts.

I still think that this change adds a significant amount of complexity to code that is already too fragile and complex to reasonably maintain. I mentioned this in the initial PR that added support for CUDA graphs, and I still think it is the way to go: this could be implemented via the "graph plan" API in ggml-backend. With the graph plan API, llama.cpp would take on the responsibility of creating and updating these plans, simplifying the logic in the CUDA backend significantly. It would also be possible to add higher-level logic in llama.cpp to handle the plans; for example, it could prepare the plan for the next graph while a graph is being evaluated, effectively achieving the same thing that is done in this PR. Other backends could take advantage of the same optimizations by implementing the graph plan interface.

If that's something that you would be interested in implementing, I could help with that. As it is, I do not think that the performance difference is large enough to justify the added complexity, and I am not willing to take the responsibility of maintaining this code. Other maintainers may still review and merge this PR if they disagree.

@aendk (Contributor, Author) commented Feb 17, 2025

I created some logs and checked how often cuda_graph_update_required is set to true. Out of 200 checks, only 9 set the variable to true.

However, even with cuda_graph_update_required=false, a lot of checks are still being done.
In Nsight Systems, this overhead is very visible and very uniform. The improvement is also quite visible. Contrary to our previous understanding, the benefit of this change is not limited to the cases where the graph needs to be rebuilt; it applies to all other cases, too.

For example, on master, is_cuda_graph_update_required() iterates over the whole cgraph. The same is true for check_node_graph_compatibility_and_refresh_copy_ops(), regardless of the result of the former function. In our change set, the first graphs only iterate over 25-50 nodes, and from then on over 100 nodes, instead of the whole cgraph. With the trend towards larger model architectures, this issue will only worsen.
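
As an illustration, a minimal sketch of such a split (not the PR's actual code; the chunk sizes 25, 50, and 100 are simply the figures quoted above):

// Split [0, n_nodes) into sub-ranges: a couple of small leading chunks so the
// GPU can start early, then uniform chunks. Sizes are illustrative only.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

static std::vector<std::pair<int, int>> split_graph(int n_nodes) {
    const int leading[] = { 25, 50 };
    const int uniform   = 100;

    std::vector<std::pair<int, int>> ranges;
    int start = 0;
    for (int sz : leading) {
        if (start >= n_nodes) return ranges;
        int end = std::min(start + sz, n_nodes);
        ranges.push_back({ start, end });
        start = end;
    }
    while (start < n_nodes) {
        int end = std::min(start + uniform, n_nodes);
        ranges.push_back({ start, end });
        start = end;
    }
    return ranges;
}

int main() {
    // e.g. a ~1200-node cgraph ends up as chunks of 25, 50, 100, 100, ...
    for (auto [b, e] : split_graph(1200)) {
        printf("[%d, %d)\n", b, e);
    }
    return 0;
}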

The approach presented here is an elegant way to do these checks in parallel with the GPU, instead of as a blocking operation up front, as is currently done. A full rewrite of this area might fix the aforementioned issues, but the proposed change is, in my opinion, ultimately much easier and less risky, and introduces very little additional complexity.

Regarding code complexity:

I encourage you and the other interested maintainers to read the change set commit by commit. It might look complex, but it is actually reasonable.
The first commit changes some function definitions and adds a for-loop. This is the bulk of the entire change set.
To be specific, it removes cuda_ctx in favor of cuda_graph, leading to many deletions of cuda_context-> snippets. This improves the structure and readability of the function code.
Additionally, it wraps the lower part of ggml_backend_cuda_graph_compute() into a single for-loop and indents it. Because the debug ifdefs remain unindented, this looks like a lot of change, but the nature of both changes is very simple and repetitive.

The rest of the commits are much smaller.

Regarding performance:

The numbers presented here are the best case. I will try to get a low-tier system, retest and report back.

Graph API

Out of interest and for future reference, do you have any plans or resources you can share for the graph plan API?
We of course want to continue to support llama.cpp as best we can, so that inference on NVIDIA GPUs runs as fast and as energy-efficiently as possible!

@FSSRepo (Collaborator) commented Feb 18, 2025

@slaren Honestly, I think Flash Attention should be an optional feature in ggml since it doesn't introduce significant performance improvements, and the binary size has increased considerably—not to mention the compilation time, which, even though I only compile it for my GPU architecture, still takes 20 minutes on an i5-12400. It is not related to this PR, but it would be good to take it into account.

@IMbackK (Collaborator) commented Feb 18, 2025

I have not been a Collaborator here for very long, so my opinion should be taken with a grain of salt. My knowledge of ggml is also still very much restricted to my perspective of looking at single kernels in an effort to optimize them, without spending much time looking at the whole picture.
However, examining the cu/hipGraph support, I agree with @slaren that the whole graph support in the CUDA backend seems non-ideal. It seems awkward that we capture the graph instead of constructing it from the ggml_cgraph like the Metal backend does, and things like

// One of the arguments to the copy kernel is updated for each token, hence we need to
feel very hacky.

I also challenge the usefulness of the graph support in general: even on CUDA it seems to increase performance only very slightly, and only in situations where the t/s performance is already very high due to small models running on high-end hardware; even then, the list of exceptions in the code where it is disabled for not being useful is long.

On the other hand, I agree with @aendk that this PR hardly moves the needle at all in terms of the complexity of this feature's implementation, while giving a larger performance boost than the PR that introduced graph support in the first place.

In my opinion, not merging this PR makes no sense if the intention is to keep the graph support as is. Either it should be merged, or the current usage of hip/cuGraph should be eliminated entirely, potentially being replaced by code that explicitly constructs the graph.

@slaren (Member) commented Feb 18, 2025

I very much disagree that CUDA graphs are not useful. It's true that they mostly benefit small models and fast GPUs, but we use small models all the time, even if only as draft models for speculative decoding. Here is an overview of the performance I get on my system:

No CUDA graphs:

| model                          |        size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ----------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| qwen2 1.5B Q4_0                | 1013.62 MiB |     1.78 B | CUDA       |  99 |  1 |         tg128 |        128.14 ± 1.75 |
| gemma 2B Q4_0                  |    1.44 GiB |     2.51 B | CUDA       |  99 |  1 |         tg128 |        232.97 ± 6.30 |
| llama 7B Q4_0                  |    3.56 GiB |     6.74 B | CUDA       |  99 |  1 |         tg128 |        139.59 ± 1.48 |

CUDA graphs:

| model                          |        size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ----------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| qwen2 1.5B Q4_0                | 1013.62 MiB |     1.78 B | CUDA       |  99 |  1 |         tg128 |        362.72 ± 1.19 |
| gemma 2B Q4_0                  |    1.44 GiB |     2.51 B | CUDA       |  99 |  1 |         tg128 |        348.48 ± 2.33 |
| llama 7B Q4_0                  |    3.56 GiB |     6.74 B | CUDA       |  99 |  1 |         tg128 |        173.45 ± 0.33 |

CUDA graphs + this PR:

| model                          |        size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ----------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| qwen2 1.5B Q4_0                | 1013.62 MiB |     1.78 B | CUDA       |  99 |  1 |         tg128 |        364.40 ± 1.74 |
| gemma 2B Q4_0                  |    1.44 GiB |     2.51 B | CUDA       |  99 |  1 |         tg128 |        313.56 ± 1.46 |
| llama 7B Q4_0                  |    3.56 GiB |     6.74 B | CUDA       |  99 |  1 |         tg128 |        173.45 ± 0.66 |

Small print: this is under Windows (WSL), with hardware GPU scheduling enabled, which has a notoriously high kernel launch overhead. The difference is not as significant on Linux or with GPU scheduling disabled.

I do not disagree that the amount of complexity this change adds is not very significant compared to the overall complexity of the feature, but I am just not willing to continue going down this road where we keep developing and adding complexity to a feature that should have been refactored the moment it was added. At the time, I concluded that the performance difference was too big to ignore the PR, but that is not the case here.

About the graph plan API: this is meant to be a very simplified abstraction of features similar to CUDA graphs. Currently it looks like this:

// (optional) graph plans (not used currently)
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
void (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
// compute the graph with the plan
enum ggml_status (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);

As noted, this is not currently used, and can be changed if necessary. I would be willing to implement the minimum necessary changes to llama.cpp and ggml-backend to support this feature, but I would need your help implementing the interface on the CUDA backend.
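
To make the intended division of labour concrete, here is a rough, hypothetical sketch of how a caller could pipeline plan preparation with evaluation on top of this interface. Only graph_plan_create/free/compute correspond to existing ggml_backend_* wrappers; ggml_backend_graph_plan_update and prepare_next_graph are assumptions introduced for illustration, and the overlap further assumes the backend can evaluate a plan asynchronously (or that evaluation runs on a worker thread).

// Hypothetical caller-side pipelining on top of the graph plan interface above
// (illustration only, not existing llama.cpp code). Two plans are double-buffered:
// while plans[cur] is being evaluated, the plan for the next ggml_cgraph is
// created or updated on the CPU.
#include "ggml-backend.h"

// stand-in for llama.cpp building the cgraph for the next token (hypothetical)
struct ggml_cgraph * prepare_next_graph(int token);

static void eval_loop(ggml_backend_t backend, int n_tokens) {
    ggml_backend_graph_plan_t plans[2] = { NULL, NULL };

    plans[0] = ggml_backend_graph_plan_create(backend, prepare_next_graph(0));

    for (int t = 0; t < n_tokens; ++t) {
        const int cur = t % 2;
        const int nxt = (t + 1) % 2;

        // evaluate the current plan; overlapping with the preparation below
        // assumes the backend runs this asynchronously
        ggml_backend_graph_plan_compute(backend, plans[cur]);

        if (t + 1 < n_tokens) {
            struct ggml_cgraph * gf_next = prepare_next_graph(t + 1);
            if (plans[nxt] == NULL) {
                plans[nxt] = ggml_backend_graph_plan_create(backend, gf_next);
            } else {
                // same topology, only e.g. KV-cache offsets change: cheap update
                // (hypothetical public wrapper around graph_plan_update)
                ggml_backend_graph_plan_update(backend, plans[nxt], gf_next);
            }
        }
    }

    for (int i = 0; i < 2; ++i) {
        if (plans[i] != NULL) {
            ggml_backend_graph_plan_free(backend, plans[i]);
        }
    }
}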

@aendk (Contributor, Author) commented Feb 18, 2025

Hi slaren, can you share your hardware setup and how you ran these tests? Additionally, how many runs did you do per configuration? This is the first time I've seen a perf decrease (gemma 2B), and I would love to get some more info.

Just FYI, I plan to run on a low-end CPU in the coming week to see the worst case for CPU overhead, and thus the best improvement for this PR.

@slaren (Member) commented Feb 18, 2025

My hardware is an Intel 13900k and 3090 Ti, running in WSL under Windows 11. The command used to run this test was llama-bench -m ... -p 0 -fa 1. By default, llama-bench repeats each test 5 times (can be changed with -r).

The difference with gemma-2b is not caused by this PR; @JohannesGaessler has made some optimizations recently that are not present in this PR. After merging master into this PR, the performance drop disappears, but it is still not significantly faster.

| Model         |  Test | t/s master | t/s akieslinger/reduce_cuda_graph_cpu_overhead | Speedup |
| ------------- | ----: | ---------: | ---------------------------------------------: | ------: |
| gemma 2B Q4_0 | tg128 |     352.40 |                                          355.05 |    1.01 |

@IMbackK (Collaborator) commented Feb 18, 2025

I have debugged the hipGraph crash this PR triggers, tracked it down to #11949, and can confirm that it is not the fault of this PR.

@IMbackK (Collaborator) commented Feb 19, 2025

@slaren

I do not disagree that the amount of complexity this change adds is not very significant compared to the overall complexity of the feature, but I am just not willing to continue going down this road where we keep developing and adding complexity to a feature that should have been refactored the moment it was added. At the time, I concluded that the performance difference was too big to ignore the PR, but that is not the case here.

I guess that, in this case, the fact that this PR helps a lot on HIP systems with small models is something to consider:

Master:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| qwen2 1.5B Q4_0                | 888.43 MiB |     1.54 B | ROCm       |  99 |         tg128 |        156.20 ± 0.61 |

PR:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| qwen2 1.5B Q4_0                | 888.43 MiB |     1.54 B | ROCm       |  99 |         tg128 |        201.85 ± 0.32 |

Don't worry about the HIP runtime bug; I have a simple one-liner to avoid it that I will add soon.
