llama : add option to override model tensor buffers #11397

Draft
wants to merge 6 commits into base: master

Conversation

slaren
Member

@slaren slaren commented Jan 24, 2025

Adds a command line parameter --override-tensor (-ot) that allows changing the buffer type where a model tensor is allocated. This gives users fine-grained control over which tensors are offloaded to each device.

How is this useful? For example, to force the experts in MoE models to stay on the CPU while offloading the rest to the GPU, you could use -ngl 99 -ot exps=CPU. This may allow more efficient offloading schemes.

The syntax is <tensor name pattern>=<buffer type>. Currently the pattern is just a string search (edit: this is no longer the case, it is now a C++ regex search), i.e. any tensor whose name contains <tensor name pattern> will be matched and loaded into the given buffer type. Multiple overrides can be given by separating them with commas, or by passing the -ot option multiple times. To see which tensors are being matched, enable debugging output with -v.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

Edit: added regex support; for example, to keep the experts of layers 20-99 on the CPU you could use -ot "[2-9][0-9]\.ffn_.*_exps\.=CPU"
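
For example, assuming a model whose routed-expert tensors follow the usual ffn_gate_exps / ffn_up_exps / ffn_down_exps naming (the model path and -ngl value below are placeholders), the single-pattern, comma-separated, and repeated-flag forms are interchangeable:

# single substring/regex pattern: pin every routed-expert tensor to the CPU
llama-server -m model.gguf -ngl 99 -ot exps=CPU

# equivalent overrides given as a comma-separated list plus a repeated flag
llama-server -m model.gguf -ngl 99 -ot "ffn_up_exps=CPU,ffn_down_exps=CPU" -ot ffn_gate_exps=CPU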

@slaren slaren added the demo Demonstrate some concept or idea, not intended to be merged label Jan 24, 2025
@slaren slaren changed the title llama : add option to override tensor buffers llama : add option to override model tensor buffers Jan 24, 2025
@slaren slaren added the need feedback Testing and feedback with results are needed label Jan 24, 2025
@bmtwl
Contributor

bmtwl commented Jan 26, 2025

Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.

@slaren
Member Author

slaren commented Jan 26, 2025

No, that's something that would need to be handled at a lower level in the CPU backend.

@bmtwl
Contributor

bmtwl commented Jan 26, 2025

No, that's something that would need to be handled at a lower level in the CPU backend.

Thanks for the reply @slaren. I figured it wouldn't directly help, but that maybe you'd be adding useful metadata to tensor objects that could help coordinate affinity in the future. I'll start a fresh branch and see how far I get.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

I'll also try to pull this branch and test it to see what the speedup and sysmem savings look like.

@bmtwl
Contributor

bmtwl commented Jan 27, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s
-ngl 10 = 5.15t/s
-ngl 20 = 5.64t/s
-ngl 30 = 6.10t/s
-ngl 40 = 6.95t/s

So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB of VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read back into memory on each test run.

@saood06

saood06 commented Jan 27, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB of VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read back into memory on each test run.

@bmtwl
Do you mind testing performance with -nkvo?

@jukofyork
Contributor

What are the shared expert tensors called in llama.cpp - is there a pattern that catches the routed experts (that only activate 1/32 of the time), but doesn't catch the shared experts?

@slaren
Member Author

slaren commented Jan 28, 2025

I believe the pattern exps will not match the shared experts, since they are called ffn_xxx_shexp.weight. You can use the gguf preview feature in huggingface to see the names of the tensors. Also remember that you can use multiple patterns, it doesn't have to be a single one.
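
For example, a pattern anchored on the routed-expert suffix should leave the shared experts alone (a sketch based on the naming above; verify the matches with -v):

-ot "ffn_.*_exps\.=CPU"   # routed experts go to CPU; ffn_*_shexp tensors are not matched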

@jukofyork
Contributor

I believe the pattern exps will not match the shared experts, since they are called ffn_xxx_shexp.weight. You can use the gguf preview feature in huggingface to see the names of the tensors. Also remember that you can use multiple patterns, it doesn't have to be a single one.

Thanks - I'll give this a try later in the week.

This PR together with this Reddit post opens up an interesting possibility:

https://old.reddit.com/r/LocalLLaMA/comments/1ibbloy/158bit_deepseek_r1_131gb_dynamic_gguf/

of quantising up/gate projections to q2_k and down projections to q4_k (or something similar), then keeping everything else as q8_0.

Sadly I need to move some stuff about to get space to upscale the fp8 download to bf16 before I can try it, but will report back when I do.

@jukofyork
Contributor

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB of VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read back into memory on each test run.

It might be worth trying q4_0, as it should almost let you offload all the layers, and IIRC it should be slightly faster to dequantise than the K-quants?

@jukofyork
Contributor

Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.

Just being able to split the experts between NUMA nodes would make a big difference, but I'm not sure how easy that would be, as IIRC the experts' tensors are all stored in one huge tensor now?

@BarfingLemurs
Contributor

During normal operation, when I fit a model between RAM and VRAM, does the offloading follow a set layer sequence? (Is layer 0 chosen first to be offloaded to the GPU, then layer 1, etc.?)

Between GPU offloading and RAM, which takes priority?

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB of VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read back into memory on each test run.

Do you remember how much of a speedup? No need for extensive benchmarks, just the rough % estimate.

@saood06

saood06 commented Feb 2, 2025

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:

-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that it takes up nowhere near my GPU's limits, it still crashes.

@jukofyork
Contributor

Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU:
-ngl 0 = 4.65t/s -ngl 10 = 5.15t/s -ngl 20 = 5.64t/s -ngl 30 = 6.10t/s -ngl 40 = 6.95t/s

I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that it takes up nowhere near my GPU's limits, it still crashes.

I had a similar problem: if I used a single GPU (via CUDA_VISIBLE_DEVICES=0) it ran fine, and if I used both GPUs with the --no-kv-offload option it also ran fine (but much slower).

If I didn't use either of these it tried to allocate this 1.4TB monster buffer:

llama_init_from_model: pipeline parallelism enabled (n_copies=4)
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 1407257.91 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 1475616865280
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 351268.28 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 368331484928
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 353465.98 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 370635939584

After some searching I found this issue:

#7217

and recompiled using -DGGML_SCHED_MAX_COPIES=1 and now it's working fine.
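
For reference, the rebuild is just the usual CUDA build with the extra define (a sketch; adjust to whatever build flags you normally use):

cmake -B build -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1
cmake --build build --config Release -j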

(It's likely nothing to do with this PR, but thought it might help!)

@jukofyork
Contributor

@saood06

I figured it out: you have to reorder the devices so the local CUDA devices are last:

#11606
#11424

and mainly these:

#11435

You don't need to run RPC servers for local devices.

#9296
#11424

For those that don't get it (like me initially), you first need to check the device names using the --list-devices option (example below):

 $ llama.cpp/build/bin/llama-server --rpc <IP1>:<PORT1> --rpc <IP2>:<PORT2> --list-devices
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX XXXX, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce GTX YYYY, compute capability 7.5, VMM: yes
Available devices:
  CUDA0: NVIDIA GeForce RTX XXXX (A MiB, B MiB free)
  CUDA1: NVIDIA GeForce GTX YYYY (A MiB, B MiB free)
  RPC[IP1:PORT1]: RPC[IP1:PORT1] (A MiB, B MiB free)
  RPC[IP2:PORT2]: RPC[IP2:PORT2] (A MiB, B MiB free)

The device names are listed under Available devices. The next time you launch llama-server, use the --device option with the order you want for your devices. An example:

$ llama.cpp/build/bin/llama-server --rpc <IP1>:<PORT1> --rpc <IP2>:<PORT2> \
--device RPC[IP1:PORT1],CUDA0,CUDA1,RPC[IP2:PORT2] \
-ngl 33 --tensor_split 3/20/10/0 --device-draft CUDA1,RPC[IP2:PORT2] -ngld 99 [...]

This way, you can set up the order however you want. In the complicated example above, the main model is offloaded to the first RPC device (using IP1:PORT1 address), mostly on the CUDA0 device, and partially to the CUDA1 device, while the draft model is offloaded to the CUDA1 device and the second RPC device (using IP2:PORT2 address).

That means this works:

--device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP1:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1"

But if I don't do this I get OOM errors with plenty of VRAM left like you had.

@saood06

saood06 commented Feb 5, 2025

I'm testing this with and without #11446. Without it, on unsloth's UD-IQ2_XXS, I was only able to offload 29 layers, and with it I was able to allocate only 28 (on a Q4_K_S quant). This is not a VRAM issue: there is plenty of spare VRAM, and it even gets past allocation and into warmup, where the rpc-server then just crashes.

The other issue is performance: the more layers I allocate, the worse performance gets, whereas bmtwl shows a performance increase with more layers offloaded using non-RPC offloading.

@ro99

ro99 commented Feb 5, 2025

I am able to load the model with llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-00001-of-00005.gguf --threads 28 --host 0.0.0.0 --port 5001 -c 8192 -ngl 99 -ot exps=CPU :

PID DEV TYPE GPU MEM HOST MEM Command
16431 0 Compute 13294MiB 54% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 2 Compute 12088MiB 49% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 3 Compute 11616MiB 47% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000
16431 1 Compute 11488MiB 47% 215686MiB /opt/llama.cpp/build/bin/llama-server -m /mnt/models/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-000

But as soon as I send the prompt I receive:

/opt/llama.cpp/ggml/src/ggml-alloc.c:182: not enough space in the buffer
ggml_dyn_tallocr_alloc: not enough space in the buffer to allocate 18446744073709550624 bytes, largest block available 9223372036854775807 bytes
[New LWP 16444]
[New LWP 16445]
[New LWP 16446]
[New LWP 16447]
...
[New LWP 16533]
[New LWP 16534]
[New LWP 16535]
[New LWP 16536]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f1e950d0bd7 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#0  0x00007f1e950d0bd7 in wait4 () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007f1e95527fc1 in ggml_abort () from /opt/llama.cpp/build/bin/libggml-base.so
#2  0x00007f1e9553619c in ggml_gallocr_allocate_node () from /opt/llama.cpp/build/bin/libggml-base.so
#3  0x00007f1e955369d0 in ggml_gallocr_reserve_n () from /opt/llama.cpp/build/bin/libggml-base.so
#4  0x00007f1e9553c244 in ggml_backend_sched_alloc_graph () from /opt/llama.cpp/build/bin/libggml-base.so
#5  0x00007f1e95646030 in llama_decode_impl(llama_context&, llama_batch) () from /opt/llama.cpp/build/bin/libllama.so
#6  0x00007f1e95646f57 in llama_decode () from /opt/llama.cpp/build/bin/libllama.so
#7  0x000055f47d6647c9 in server_context::update_slots() ()
#8  0x000055f47d64f4d1 in server_queue::start_loop() ()
#9  0x000055f47d5fd067 in main ()
[Inferior 1 (process 16431) detached]
Aborted (core dumped)

Without the --override-tensor and offloading 20 layers to the GPU it works fine.

Testing with 4x RTX 3090 and 320GiB RAM. Built with cmake -B build -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1.

@jukofyork
Contributor

Without the --override-tensor and offloading 20 layers to the GPU it works fine.

Testing with 4x RTX 3090 and 320GiB RAM. Built with cmake -B build -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1.

Maybe try -ngl 61 to keep the output layer on the CPU too (that oddly worked for me earlier when I was having trouble with the RPC stuff).

@ro99

ro99 commented Feb 5, 2025

Maybe try -ngl 61 to keep the output layer on the CPU too (that oddly worked for me earlier when I was having trouble with the RPC stuff).

No luck, still the same issue.

Oddly enough, the issue only happens when sending more than 450 tokens.

@slaren
Member Author

slaren commented Feb 5, 2025

ggml_dyn_tallocr_alloc: not enough space in the buffer to allocate 18446744073709550624 bytes

It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable GGML_SCHED_DEBUG=2, it will print the graph before allocating it, which may give some indication of which tensor is causing this. Or just change the error message in ggml_dyn_tallocr_alloc to include the tensor name.
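
For example (hypothetical invocation, reusing flags from earlier in the thread; the graph dump is verbose, so capture the log):

GGML_SCHED_DEBUG=2 ./build/bin/llama-server -m model.gguf -ngl 99 -ot exps=CPU -v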

@ro99

ro99 commented Feb 6, 2025

It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable GGML_SCHED_DEBUG=2, it will print the graph before allocating it, which may give some indication of which tensor is causing this. Or just change the error message in ggml_dyn_tallocr_alloc to include the tensor name.

It is the CPU#ffn_moe_topk-60#0 tensor.

Is it possible to try to force this particular one to be allocated into the GPU buffer?

@slaren
Member Author

slaren commented Feb 6, 2025

This is most likely a bug; we need to understand why it is happening and fix it. Since you mentioned that it only happens with large prompts, I suspect that this is caused by a zero-sized tensor. When evaluating a batch where no logits are required (which happens when evaluating a prompt that needs to be split into multiple ubatches), zero-size tensors are created to skip the calculation of the logits.
I cannot run this model, so I would need your help to figure out why this is happening. Can you print more details about the tensor? Something like this should do it:

diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 9a3bf9f29..470ef13e6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -179,6 +179,9 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
             // this should never happen
             GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
+            GGML_LOG_ERROR("%s: tensor: %s, shape: %ld %ld %ld %ld, size: %zu",
+                __func__, tensor->name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+                ggml_nbytes(tensor));
             GGML_ABORT("not enough space in the buffer");
         }
     }

@slaren
Member Author

slaren commented Feb 6, 2025

Ok nvm, I think I see the problem. I will push a possible fix soon.

@jukofyork
Contributor

@slaren I've got the same bug now and only on large prompts too. I can test the fix tomorrow.

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Feb 6, 2025
@ro99

ro99 commented Feb 6, 2025

Ok nvm, I think I see the problem. I will push a possible fix soon.

I confirm that the fix worked, thank you @slaren.

For the record, I am getting ~2.5 t/s with -ngl 99 -ot exps=CPU (I was getting ~1.0 t/s before).

@jukofyork
Contributor

jukofyork commented Feb 8, 2025

@slaren

I'm just trying to understand how the batching works with regard to --override-tensor exps=CPU.

I tried tracing the GGML code to ggml_mul_mat_id and looked at the PRs (#6505 #6387), but I still can't see how the batching is working for prompt processing...

Am I right in thinking that if n_ubatch is, say, 512 and we use --override-tensor exps=CPU, this causes the full ffn_up_exps.weight tensor of dimension {7168, 2048, 256} to get copied from RAM to VRAM, and then CUDA processes the 512 tokens in parallel through the copied tensor (where each expert tensor should get hit 16 times on average: 8 active per token x 512 tokens in the batch / 256 experts total)?

Or is there something else happening, where the selected 8 experts for each token are getting swapped in/out of RAM/VRAM?

I only ask because I can't seem to find as much difference from changing n_ubatch as I expected I would.


Also, was the problem I mentioned above:

llama_init_from_model: pipeline parallelism enabled (n_copies=4)
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 1407257.91 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 1475616865280
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 351268.28 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 368331484928
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 353465.98 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 370635939584

which was solved by using -DGGML_SCHED_MAX_COPIES=1, a result of --override-tensor exps=CPU (possibly creating 57 stages or similar), or nothing to do with this PR?

@slaren
Member Author

slaren commented Feb 8, 2025

When evaluating batches of >=32 tokens, all the weights are copied to VRAM and everything is evaluated on the GPU. For MoE weights, this means the entire tensor is copied with all the experts. That usually works well for the sizes of the models typically used locally, but I never did any testing on a model of the size of R1. It may be worth reevaluating that. This can be easily adjusted in the offload_op function of the backends, e.g. in ggml_backend_cuda_device_offload_op for CUDA. This function determines which weights are copied to VRAM.
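
A rough sketch of the kind of check offload_op performs (paraphrased from memory rather than quoted from the CUDA backend, so names and details may differ from the actual source):

#include "ggml.h"
#include <stdbool.h>

// Weights are uploaded to VRAM only when the batch is large enough to amortize the copy.
// For GGML_OP_MUL_MAT_ID (the MoE expert matmul) the batch dimension is ne[2]; raising the
// threshold for that op, or scaling it with the tensor size, is the kind of adjustment
// discussed above for R1-sized expert tensors.
static bool sketch_offload_op(const struct ggml_tensor * op) {
    const int64_t min_batch_size = 32;
    const int64_t batch_size = op->op == GGML_OP_MUL_MAT_ID ? op->ne[2] : op->ne[1];
    return batch_size >= min_batch_size;
}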

The problem you mentioned happens because pipeline parallelism is being enabled, and this causes memory for multiple copies of the weights to be reserved. This shouldn't happen when not fully offloading a model, since the pipeline parallelism implementation doesn't work at all if a part of the model needs to be evaluated on the CPU. But currently it only checks the value of -ngl to determine if the model is being fully offloaded, so the check needs to be updated in this PR to also consider the overridden tensors. Building with -DGGML_SCHED_MAX_COPIES=1 effectively disables pipeline parallelism.
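
A sketch of the condition being described, using hypothetical names (this is not the actual llama.cpp code): pipeline parallelism should only be enabled when the model is genuinely fully offloaded, which after this PR means looking at the tensor overrides too, not just -ngl.

#include <stdbool.h>

// Hypothetical helper: enable pipeline parallelism only if every layer is on a GPU
// and no tensor has been overridden back onto a CPU buffer type.
static bool should_enable_pipeline_parallelism(int n_devices, int n_gpu_layers,
                                               int n_layer, bool any_tensor_on_cpu) {
    return n_devices > 1
        && n_gpu_layers > n_layer      // the current check: fully offloaded per -ngl
        && !any_tensor_on_cpu;         // the extra check this PR would need
}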

@jukofyork
Contributor

Is there any way we could extend this to specify RPC backends too?

I've managed to weave this abomination:

./llama-server --host 192.168.1.111 --port 8080 \
  --model ./DeepSeek-R1-Q5_K.gguf --chat-template deepseek3 --alias "DeepSeek-R1" --ctx_size 32768 \
  --n-gpu-layers 62 --threads 30 \
  --rpc 192.168.1.112:50050,192.168.1.112:50051,192.168.1.113:50050,192.168.1.113:50051 \
  --device "RPC[192.168.1.112:50050],RPC[192.168.1.112:50051],RPC[192.168.1.113:50050],RPC[192.168.1.113:50051],CUDA0,CUDA1"   --tensor-split 5,4,4,4,23,22 \
  --override-tensor 'blk\.(1[7-9]|[2-5][0-9]|6[0-1])\..*_exps\.=CPU'

which is working, but slower than if I had just kept it on one machine.

./llama-server --host 192.168.1.111 --port 8080 \
  --model ./DeepSeek-R1-Q5_K.gguf --chat-template deepseek3 --alias "DeepSeek-R1" --ctx_size 32768 \
  --n-gpu-layers 62 --threads 30 \
  --rpc 192.168.1.112:50050,192.168.1.112:50051,192.168.1.113:50050,192.168.1.113:50051 \
  --device "RPC[192.168.1.112:50050],RPC[192.168.1.112:50051],RPC[192.168.1.113:50050],RPC[192.168.1.113:50051],CUDA0,CUDA1"   
  --override-tensor "blk\.(([3-9]|1[0-2]))\.ffn_.*_exps\.=RPC[192.168.1.112:50050]" ...

To keep everything on the main machine but offload the sets of 3 "_exps" tensors to the 4 RPC servers - but I get this:

Available buffer types:
  CPU
  CUDA0
  CUDA1
error while handling argument "--override-tensor": unknown buffer type

I also tried this:

--n-gpu-layers XX \
--override-tensor '^(?!.*exps).*blk\.[0-2][0-9]=CUDA0' \
--override-tensor '^(?!.*exps).*blk\.([3-5][0-9]|60)=CUDA1'

which almost worked, but it ended up putting the KV-caches on the RPC backends as well as the "_exps" tensors.

I could see this being very useful for somebody with 1 (or more) Mac M1/M2 Studio Ultra(s), as they wouldn't have as much latency as what I'm attempting.

@slaren
Member Author

slaren commented Feb 8, 2025

It should work with RPC servers, as long as you pass --rpc before --override-tensor. I just did a quick test and it works as expected:

Available buffer types:
  CPU
  CUDA0
  RPC[127.0.0.1:50052]

but it ended up putting the KV-caches on the RPC backends as well as the "_exps" tensors

Using -ot does not change where the KV cache is allocated. I guess you could hack it in some way using -ts, but I could change it so that -ot also applies to the KV tensors.

@jukofyork
Contributor

jukofyork commented Feb 8, 2025

When evaluating batches of >=32 tokens, all the weights are copied to VRAM and everything is evaluated on the GPU. For MoE weights, this means the entire tensor is copied with all the experts. That usually works well for the sizes of the models typically used locally, but I never did any testing on a model of the size of R1. It may be worth reevaluating that. This can be easily adjusted in the offload_op function of the backends, e.g. in ggml_backend_cuda_device_offload_op for CUDA. This function determines which weights are copied to VRAM.

Thanks. I think it probably does still make sense to copy to VRAM, and 32 probably works well here too. If we simplify by assuming we select the 8 experts uniformly at random with replacement, as well as the 32 sets of selections, then we have 8 * 32 = 256 selections from 256 items, and this has a nice closed form of 1 - 1/e ≈ 63% expected coverage (see: bootstrapping).
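
Spelled out under that uniform-with-replacement simplification:

$$\mathbb{E}[\text{fraction of experts touched}] \;=\; 1-\left(1-\tfrac{1}{256}\right)^{8\times 32} \;\approx\; 1-e^{-1} \;\approx\; 63\%$$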

The problem you mentioned happens because pipeline parallelism is being enabled, and this causes memory for multiple copies of the weights to be reserved. This shouldn't happen when not fully offloading a model, since the pipeline parallelism implementation doesn't work at all if a part of the model needs to be evaluated on the CPU. But currently it only checks the value of -ngl to determine if the model is being fully offloaded, so the check needs to be updated in this PR to also consider the overridden tensors. Building with -DGGML_SCHED_MAX_COPIES=1 effectively disables pipeline parallelism.

Ah thanks, so I'm not losing anything by using -DGGML_SCHED_MAX_COPIES=1.

@jukofyork
Contributor

jukofyork commented Feb 8, 2025

It should work with RPC servers, as long as you pass --rpc before --override-tensor. I just did a quick test and it works as expected:

Available buffer types:
  CPU
  CUDA0
  RPC[127.0.0.1:50052]

Thanks - I figured it out!

but it ended up putting the KV-caches on the RPC backends as well as the "_exps" tensors

Using -ot this does not change where the KV cache is allocated. I guess you could hack it in some way using -ts, but I could change it so that -ot also applies to the KV tensors.

Yeah, I think a separate option for the KV-cache might work best eventually.

@jukofyork
Contributor

Oh, I see I have the old --override-tensor exps=CPU --override-tensor attn_kv_b=CPU in there now! doh.

@slaren
Member Author

slaren commented Feb 8, 2025

Ah yes, this happens because it builds the list of buffer types the first time -ot is passed, and stores it in a static, so newly added devices will not be recognized. Maybe it would be better to rebuild the list of buffer types every time; it is not likely to affect performance in a meaningful way anyway.

@jukofyork
Contributor

I've got it working:

numactl --interleave=all ./llama.cpp/build/bin/llama-server --host 192.168.1.111 --port 8080 \
  --model ./DeepSeek-R1-mla-Q5_K_XL.gguf --chat-template deepseek3 --alias "DeepSeek-R1-mla-Q5_K_XL" --ctx_size 32768 \
  --n-gpu-layers 62 --numa distribute --threads 30 \
  --temp 0.6 --min-p 0.0 --top-p 1.0 --top-k 0 --rpc 192.168.1.112:50050,192.168.1.112:50051,192.168.1.113:50050,192.168.1.113:50051 \
  --device "RPC[192.168.1.112:50050],RPC[192.168.1.112:50051],RPC[192.168.1.113:50050],RPC[192.168.1.113:50051],CUDA0,CUDA1" \
  --tensor-split 0,0,0,0,31,31 \
  --override-tensor 'blk\.([3-8])\..*_exps\.=RPC[192.168.1.112:50050]' \
  --override-tensor 'blk\.([9]|1[0-4])\..*_exps\.=RPC[192.168.1.112:50051]' \
  --override-tensor 'blk\.(1[5-9]|20)\..*_exps\.=RPC[192.168.1.113:50050]' \
  --override-tensor 'blk\.(2[1-6])\..*_exps\.=RPC[192.168.1.113:50051]' \
  --override-tensor 'blk\.(2[7-9]|[3-5][0-9]|60)\..*_exps\.=CPU'
llama_kv_cache_init:      CUDA0 KV buffer size =  2108.01 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =  2040.01 MiB
llama_init_from_model: KV self size  =    0.00 MiB, K (f16):    0.00 MiB, V (f16):    0.00 MiB
llama_init_from_model: KV self size  = 2196.00 MiB, K^R (f16):  244.00 MiB, c^KV (f16): 1952.00 MiB
llama_init_from_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_init_from_model: RPC[192.168.1.112:50050] compute buffer size =   159.00 MiB
llama_init_from_model: RPC[192.168.1.112:50051] compute buffer size =   159.00 MiB
llama_init_from_model: RPC[192.168.1.113:50050] compute buffer size =   159.00 MiB
llama_init_from_model: RPC[192.168.1.113:50051] compute buffer size =   159.00 MiB
llama_init_from_model:      CUDA0 compute buffer size = 16731.50 MiB
llama_init_from_model:      CUDA1 compute buffer size = 16666.00 MiB
llama_init_from_model:        CPU compute buffer size =    78.01 MiB
llama_init_from_model: graph nodes  = 5208 (with bs=512), 5330 (with bs=1)
llama_init_from_model: graph splits = 183 (with bs=512), 119 (with bs=1)

I think this could be a super powerful command line option when mixed with RPC! Thanks for adding this!

If anybody has a Mac Studio they want to test this on then I can help craft the regexes to test it - I'm interested to see what sort of boost you could get without so many stages of latency.

@Dango233

Dango233 commented Feb 10, 2025

I've got it working:

[...]

I think this could be a super powerful command line option when mixed with RPC! Thanks for adding this!

If anybody has a Mac Studio they want to test this on then I can help craft the regexes to test it - I'm interested to see what sort of boost you could get without so many stages of latency.

I'm up for testing - I have a Mac Studio M2 Ultra 192GB <---10Gbps---> 13700K + 192GB DDR5 + RTX 6000 Ada.
I'll try running this myself first and see if I can get it rolling.

Also, if it's helpful (it seems to be?), I can get a Thunderbolt Gen 4 eGPU case and plug my RTX 6000 Ada in there...

@jukofyork
Contributor

Also, if it's helpful (it seems to be?), I can get a Thunderbolt Gen 4 eGPU case and plug my RTX 6000 Ada in there...

It didn't help me, due to the latency between the parts all pushing the hidden state.

I used 10gbit Ethernet for all the machines, so I'm not sure upping to 40gbit (or whatever Thunderbolt is) will make that much difference - I think the problem is latency rather than bandwidth for this part, sadly.

Possibly using InfiniBand might help, as IIRC it has lower latency, but I'm not sure.

I think the eventual solution would be to have RPC use a better method of pipeline parallelism like Deepspeed:

deepspeedai/DeepSpeed#1110

It would definitely help the batch processing, and mixed data and pipeline parallelism would remove some latency for setups with multiple GPUs per machine like I have.

@Dango233

Just figured out that an eGPU won't help, as Apple silicon cannot run CUDA...
Not sure if RPC is the bottleneck here - my RTX got maxed out - probably due to the lack of flash attention?

@saood06

saood06 commented Feb 10, 2025

Not sure if RPC is the bottleneck here - my RTX got maxed out - probably due to the lack of flash attention?

Can you post speeds (with whatever configurations you tested)? Also, I'm not sure how much flash attention would impact speed, but it would shrink that compute buffer.

@jukofyork
Contributor

jukofyork commented Feb 10, 2025

Not sure if RPC is the bottleneck here - my RTX got maxed out - probably due to the lack of flash attention?

Can you post speeds (with whatever configurations you tested)? Also, I'm not sure how much flash attention would impact speed, but it would shrink that compute buffer.

I think the RPC stuff is never really gonna work properly until it can do async buffering: the way it is set up now, each stage in the pipeline stalls on every communication, and this adds the full latency. If it were async and buffered, the next job would start almost immediately with no wait, and you could probably optimise this even more by having the fastest devices at the start of the pipeline and the slowest at the end, to get almost no degradation from latency.

@Dango233

Not sure if RPC is the bottleneck here - my RTX got maxed out - probably due to the lack of flash attention?

Can you post speeds (with whatever configurations you tested)? Also, I'm not sure how much flash attention would impact speed, but it would shrink that compute buffer.

The GPU utilization could be an illusion. I'll try to get some numbers across the different setups.

@abc-nix

abc-nix commented Feb 15, 2025

Many many thanks, @slaren, for this PR. I really hope it gets merged.

I have used this --override-tensor option to improve token generation speeds for Mixtral 8x22 by over 70%.

What I have learned so far (I don't know if it is applicable to R1):

  • On Mixtral, there are 3 types of expert related tensors: ffn_gate_exps, ffn_up_exps and ffn_down_exps.
  • Try to offload as many layers as possible to GPU by keeping all expert related tensors on CPU (as explained in the merge description, -ot exps=CPU).
  • In most cases, you will have to keep the last layer on CPU (for mixtral 8x22 q4_k_m, that is 56 of 57 offloaded to GPU, the last layer will not fit).
  • Once you have all (minus one) layers offloaded to GPU, this is the order I found that improves token generation the most:
    1. Try to get as many ffn_down_exps tensors as possible on GPU. This means you need to override the ffn_gate_exps tensors and ffn_up_exps tensors and keep them on CPU (-ot ffn_gate_exps=CPU -ot ffn_up_exps=CPU)
    2. Once full, start offloading the ffn_up_exps tensors to GPU.
    3. Finally, offload any ffn_gate_exps tensors to GPU until full.

This is the order I found improves token generation the most for Mixtral on my system (a hypothetical command sketch following this order is below). This is with non-RPC (CUDA) devices, and may not carry over in the same way to other kinds of backends.
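
As a hypothetical sketch of that ordering (the layer range, quant, and -ngl values are placeholders to tune against the available VRAM):

# step 1: offload 56/57 layers, keep gate+up experts on the CPU so the down experts land on the GPU
llama-server -m mixtral-8x22b-Q4_K_M.gguf -ngl 56 -ot "ffn_gate_exps=CPU,ffn_up_exps=CPU"

# steps 2-3: if VRAM is left over, narrow the overrides so some up (and then gate) experts move to the GPU
llama-server -m mixtral-8x22b-Q4_K_M.gguf -ngl 56 -ot "ffn_gate_exps=CPU" -ot "blk\.(2[0-9]|[3-5][0-9])\.ffn_up_exps=CPU"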

Many thanks again for this PR.

@saood06

saood06 commented Feb 15, 2025

I have used this --override-tensor option to improve token generation speeds for Mixtral 8x22 by over 70%.

If you don't mind, can you post some more info:
How much VRAM / what GPU? Also, what quant did you use for this? Can you post the actual performance numbers with the configurations you tested?
