FlashMLA-2: reduce compute buffer size (CUDA and CPU) #260
Conversation
This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller.
I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only.
On CUDA there is just a quick hack that allows us to concatenate tensors with more than 65535 rows along the zeroth dimension, as needed by FlashMLA-2. Also needed some care in the perplexity tool to avoid int overflows when evaluating the computed logits.
Oh, also fix an int overflow in the CUDA concat implementation. It is funny how the llama.cpp 64-bit police have gone (almost) everywhere and replaced 32-bit ints with 64-bit ints, needed or not, but haven't done it where it is actually needed.
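For illustration, here is a hedged sketch (not the actual ggml CUDA kernel; the names, shapes, and launch parameters are assumptions) of why row counts above 65535 need special handling: CUDA caps `gridDim.y`/`gridDim.z` at 65535 blocks, so mapping one row per block in those dimensions fails for larger tensors, while a grid-stride loop over rows with 64-bit offsets handles any row count.

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical sketch: concatenate two contiguous f32 matrices along dim 0 (rows).
// CUDA limits gridDim.y/z to 65535 blocks, so instead of mapping one row to one
// block in those dimensions we stride over all rows with blockIdx.x, and we use
// 64-bit offsets so large tensors don't overflow 32-bit index arithmetic.
__global__ void concat_rows_f32(const float * a, int64_t rows_a,
                                const float * b, int64_t rows_b,
                                float * dst, int64_t ne0 /* elements per row */) {
    const int64_t total_rows = rows_a + rows_b;
    for (int64_t row = blockIdx.x; row < total_rows; row += gridDim.x) {
        const float * src = row < rows_a ? a + row*ne0 : b + (row - rows_a)*ne0;
        float       * out = dst + row*ne0;
        for (int64_t i = threadIdx.x; i < ne0; i += blockDim.x) {
            out[i] = src[i];
        }
    }
}

// Clamp the launch grid to a modest block count; the stride loop inside the
// kernel covers any remaining rows, so tensors with millions of rows still work.
static void launch_concat_rows(const float * a, int64_t rows_a,
                               const float * b, int64_t rows_b,
                               float * dst, int64_t ne0, cudaStream_t stream) {
    const int64_t total_rows = rows_a + rows_b;
    if (total_rows == 0) return;
    const int n_blocks = (int) (total_rows < 65535 ? total_rows : 65535);
    concat_rows_f32<<<n_blocks, 256, 0, stream>>>(a, rows_a, b, rows_b, dst, ne0);
}
```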
Will test and report back. Thank you @ikawrakow. PS: Those fixes for
First model load: Segfault with `-c 16384 -amb 1024 -fmoe -mla 2 -fa`
No. It is an integer overflow. The logit location in the array of logits was computed with 32-bit integers. As there are ~128k entries in the vocabulary, the integer multiplication overflows.
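For concreteness, a minimal standalone illustration of that overflow (the exact numbers are made up for the example, not taken from the actual `perplexity` code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_vocab = 128000;  // ~128k vocabulary entries (approximate)
    const int64_t i_token = 20000;   // some token position within the evaluated range

    // What the fixed code conceptually does: compute the offset in 64 bits.
    const int64_t offset64 = i_token * n_vocab;       // 2'560'000'000, fits easily

    // This mimics the old 32-bit computation: the product exceeds INT32_MAX
    // (2'147'483'647) and wraps, producing a bogus (negative) logit location.
    const int32_t offset32 = (int32_t) offset64;

    printf("64-bit offset: %lld\n", (long long) offset64);
    printf("32-bit offset: %d\n", offset32);
    return 0;
}
```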
It fails to allocate
I'll take a quick stab at it too, using a simple 1x RTX A6000 48GB GPU configuration.

Update

$ git checkout ik/flash_mla2_cuda_no_f32
$ git rev-parse --short HEAD
b147e31f
$ cmake -B ./build -DGGML_CUDA=ON -DGGML_BLAS=OFF
$ cmake --build ./build --config Release -j $(nproc)
$ ./build/bin/llama-server --version
version: 3601 (b147e31f)
built with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for x86_64-linux-gnu
Basic Command

CUDA_VISIBLE_DEVICES="0," \
./build/bin/llama-server \
--alias ubergarm/DeepSeek-R1-Q2_K_R4 \
--model /mnt/raid/models/ubergarm/DeepSeek-R1-GGUF/DeepSeek-R1-GGUF-Q2_K_R4.gguf \
--ctx-size 16384 \
-ctk f16 -ctv f16 \
-mla 2 -fa \
-amb 1024 \
-fmoe \
--n-gpu-layers 63 \
--override-tensor exps=CPU \
--parallel 1 \
--threads 24 \
--host 127.0.0.1 \
--port 8080
llm_load_tensors: offloading 61 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 62/62 layers to GPU
llm_load_tensors: CPU buffer size = 225736.00 MiB
llm_load_tensors: CPU buffer size = 938.98 MiB
llm_load_tensors: CUDA0 buffer size = 17744.02 MiB

Results
llama-bench

# Run this twice, once without specifying `-amb` at all and once like so:
CUDA_VISIBLE_DEVICES="0," \
CUDA_VISIBLE_DEVICES="0," \
./build/bin/llama-bench \
--model /mnt/raid/models/ubergarm/DeepSeek-R1-GGUF/DeepSeek-R1-GGUF-Q2_K_R4.gguf \
-ctk f16 -ctv f16 \
-mla 2 -fa 1 \
-amb 1024,128,64,32,16,8,4,1 \
-fmoe 1 \
--n-gpu-layers 63 \
--override-tensor exps=CPU \
--threads 24
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA RTX A6000, compute capability 8.6, VMM: yes
So, this looks quite a bit better than the main branch. It would seem that a single 24 GB GPU could handle the non-expert tensors and up to 32k context?
Update: Yes, it is better than the main branch, as shown in the table below. Also, the quant I'm using has

Oh, I thought I noticed it was reporting less with

Comparison Table

I ran enough to show it is working; gonna stop and not fill in the blanks for now.
I haven't put a guard against using quantized cache for

Based on the performance values @ubergarm posted, there doesn't seem to be any major performance impact, even with
The relevant part of the above table for this specific question:
Sorry for the delay here. As model loading takes quite a long time on 16 GPUs, and I'm near the limit so there have been some OOMs (my own fault, nothing to do with the PR), I've been quite slow to come back. From what I can see so far, there is zero notable difference in the performance of

TODO
Yeah, please double-check me, but I updated my chart and command above, which suggests going down to

Curious if you have a similar outcome across all your GPUs!
Interestingly, I got an error for

Haven't seen that error before! Also, you should test setting
Sorry, I wasn't clear enough with my request. The PP test should be done with

This is only relevant if the MoE experts are computed on CUDA. When the MoE part runs on the CPU, the default
Hrmm, I've seen some chatter about

I was kinda surprised that you were offloading shared experts onto GPUs with your config, given that doesn't work on ktransformers yet, per my own testing and their documentation:
I'll set that up and post the results here soon.
@davidsyoung has 16 x 3090s, so the entire model is run on the GPUs. CUDA graphs get disabled for MoE models (also true on mainline
Neither have I. It means that the back-end is miscalculating the required compute buffer size somehow. Not sure what to do about that.
I increased

CUDA_VISIBLE_DEVICES="0," \
./build/bin/llama-bench \
--model /mnt/raid/models/ubergarm/DeepSeek-R1-GGUF/DeepSeek-R1-GGUF-Q2_K_R4.gguf \
-ctk f16 -ctv f16 \
-mla 2 -fa 1 \
-amb 1024,128,64,32,16,8,4,2,1 \
-p 16384,8192 \
-n 0 \
-fmoe 1 \
-r 2 \
--n-gpu-layers 63 \
--override-tensor exps=CPU \
--threads 24
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA RTX A6000, compute capability 8.6, VMM: yes
(crashes when starting the 16k prompt with amb=16MiB)
So compute buffers are massively improved. I don't have an apples-to-apples comparison as I went down a rabbit hole after realising I could turn off pipeline parallelism and it would also give me more VRAM back (thanks @ubergarm!). But it is massively improved. Had some issues going below
Same error: |
Even without the direct comparison, I'm curious where you're at now. Also, you probably have fixed it by now, but CUDA15 was barely used here:
Damn, I don't have it right on me as I closed the laptop (night time here). I do have some data in my notes from a very early run. I was able to get to 24k context, with

Here are some very initial runs (this is without disabling pipeline parallelism). This is already quite improved from what I can remember.

Also, for GPU 16, unfortunately I can't really use it. I can't split the layers any more evenly (at least with what I've tried; it's a bit of a limitation without being able to split by row). I will add some more data tomorrow for you!

Compute Buffer Configuration Comparison
Example Device Buffer Changes (MiB):
Key Findings:
Thank you for testing! It looks like a winner, so merging it.
This PR

- reduces the compute buffer size needed for FlashMLA-2 (`-mla 2 -fa`), controlled via the `-amb` command line option;
- fixes integer overflows (in the `perplexity` tool, and in the CUDA implementation of `GGML_OP_CONCAT`).

For FlashMLA-2 one computes $X = W_{kv} K$, where $K$ is the K-cache and $W_{kv}$ is the `blk.*.attn_kv_b.weight` tensor. $X$ has the shape `(n_embd_k_nope + n_embd_v) x n_kv x n_head`, where `n_kv` is the number of tokens currently in the cache, `n_head` is the number of heads, and `n_embd_k_nope`, `n_embd_v` are the head dimensions. For DeepSeekV3/R1/Lite `n_embd_k_nope = n_embd_v = 128`. As I don't have the ability to run DeepSeekV3/R1, I'm experimenting with DeepSeek-Lite, where `n_head = 16`, so I had not noticed how large $X$ can become (it is "just" 1 GiB for a context of 65k tokens). But `n_head = 128` for DeepSeekV3/R1, so for a context of 65k tokens $X$ becomes 8 GiB ($X$ is computed as `fp32`). When attention is computed on the GPU the cache is `fp16` (quantized cache still does not work for FlashMLA-2 on CUDA), so $X$ gets converted to `fp16` tensors $V$ and $K_{\rm nope}$, both having half the elements of $X$. As all 3 tensors need to exist simultaneously before the memory used for $X$ can be reused for other data, we end up requiring 16 GiB for these 3 tensors for a context of 65k tokens. This severely limits the maximum context length that can be processed on a GPU with limited VRAM.

This PR solves the problem by splitting the attention computation into chunks. The number of chunks is determined by the size of $X$ and the maximum attention buffer size $B_{\rm max}$ specified on the command line via the `-amb` option (the argument following `-amb` is the maximum buffer size in MiB). We have $N_{\rm step} = {\rm sizeof}(X)/B_{\rm max}$. In each step, $n_{\rm head}/N_{\rm step}$ attention heads and the corresponding $1/N_{\rm step}$ portion of the $W_{kv}$ matrix are used, and the entire FlashMLA-2 series of operations is processed with this reduced dataset. The final attention result is obtained by concatenating the results of the individual steps along the head dimension.

For DeepSeek-Lite I need to use a quite low `-amb` threshold of 256 MiB to even trigger the multi-step attention calculation at 65k tokens (attention is computed with 4 steps at 65k tokens, 2 steps at 32k tokens, and 1 step for 16k tokens or less). I observe a 2-3% drop in performance on the CPU and on CUDA for a context of 32k tokens computed in 2 steps. I would really appreciate it if someone tested this PR with DeepSeekV3/R1 and reported performance with `-mla 2 -fa -amb 1024 -fmoe` versus without `-amb 1024` (only PP performance is required; TG in FlashMLA-2 is done the same way as without FA, so it does not go through this memory optimization).
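To make the arithmetic concrete, here is a small back-of-the-envelope sketch using the DeepSeekV3/R1 shapes quoted above (illustrative only; the helper code is made up, not the actual implementation). It reproduces the 8 GiB figure for $X$ at 65k tokens and shows how `-amb 1024` splits the computation.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only: estimate sizeof(X) and the number of attention steps
// N_step = sizeof(X) / B_max, using the DeepSeekV3/R1 shapes quoted above.
int main() {
    const int64_t n_embd_k_nope = 128, n_embd_v = 128;   // head dimensions
    const int64_t n_head        = 128;                   // DeepSeekV3/R1
    const int64_t n_kv          = 65536;                 // tokens in the cache
    const int64_t bytes_f32     = 4;                     // X is computed as fp32

    const int64_t size_X = (n_embd_k_nope + n_embd_v) * n_kv * n_head * bytes_f32;

    const int64_t B_max  = 1024ll * 1024 * 1024;         // -amb 1024 (MiB -> bytes)
    const int64_t n_step = (size_X + B_max - 1) / B_max; // round up

    printf("sizeof(X) = %.1f GiB\n", size_X / (1024.0 * 1024.0 * 1024.0));
    printf("N_step    = %lld -> %lld heads per step\n",
           (long long) n_step, (long long) (n_head / n_step));
    return 0;
}
```

With these numbers the per-step intermediate stays at the requested 1 GiB cap (16 of the 128 heads at a time), so the full 8 GiB tensor never has to exist at once.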