llama : add option to override model tensor buffers #11397
Conversation
Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds.
No, that's something that would need to be handled at a lower level in the CPU backend.
Thanks for the reply @slaren. I figured it wouldn't directly help, but that maybe you'd be adding useful metadata to tensor objects that could help coordinate affinity in the future. I'll start a fresh branch and see how far I get.
I'll also try to pull this branch and test it to see what the speedup and sysmem savings look like.
Quick, non-scientific initial test with DeepSeek R1 at q6 on llama-server with `-ot exps=CPU`: -ngl 0 = 4.65 t/s. So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have time to do a proper drop/reload cycle since it takes so long to read the model back into memory on each test run.
@bmtwl
What are the shared expert tensors called in |
I believe the pattern
Thanks - I'll give this a try later in the week. This PR, together with this Reddit post: https://old.reddit.com/r/LocalLLaMA/comments/1ibbloy/158bit_deepseek_r1_131gb_dynamic_gguf/ opens up the interesting possibility of quantising up/gate projections to q2_k and down projections to q4_k (or something similar), then keeping everything else as
Sadly I need to move some stuff about to get space to upscale the fp8 download to bf16 before I can try it, but will report back when I do.
It might be worth trying
Just being able to split the experts between NUMA nodes would make a big difference, but not sure how easy that would be as IIRC the experts' tensors are all in one huge tensor now?
During normal operation, when I fit a model between RAM and VRAM, does the offloading follow a set layer sequence (layer 0 is chosen first to be offloaded to the GPU, then layer 1, etc.)? Between GPU offloading and RAM, which takes priority?
Do you remember how much of a speedup? No need for extensive benchmarks, just a rough % estimate.
I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that it uses nowhere near my GPU's limits, it still crashes.
I had a similar problem where if I used a single GPU (via
If I didn't use either of these, it tried to allocate this 1.4TB monster buffer:
After some searching I found this issue: and recompiled using (It's likely nothing to do with this PR, but thought it might help!)
I figured it out: you have to reorder the devices so the local and mainly these:
That means this works: --device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP1:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1" But if I don't do this, I get OOM errors with plenty of VRAM left, like you had.
I'm testing this with and without #11446. Without it, on unsloth's UD-IQ2_XXS, I was only able to offload 29 layers, and with it I was able to allocate only 28 (on a Q4_K_S quant). This is not a VRAM issue: there is plenty of spare VRAM, and it even gets past allocation and reaches warmup, where the rpc-server then just crashes. The other issue is performance: the more layers I allocate, the worse performance gets, while bmtwl shows a performance increase with more layers offloaded with non-RPC-based offloading.
I am able to load the model with
But as soon as I send the prompt I receive:
Without the
Testing with 4x RTX 3090 and 320GiB RAM. Built with
Maybe try
No luck, still the same issue. Oddly enough, the issue only happens when sending more than 450 tokens.
It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable
It is the
Is it possible to try to force this particular one to be allocated into the GPU buffer?
This is most likely a bug, we need to understand why it is happening and fix it. Since you mentioned that it only happens with large prompts, I suspect that this is caused by a zero-sized tensor. When evaluating a batch where no logits are required (which happens when evaluating a prompt that needs to be split into multiple ubatches), zero-size tensors are created to skip the calculation of the logits.

```diff
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 9a3bf9f29..470ef13e6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -179,6 +179,9 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
         // this should never happen
         GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                 __func__, size, max_avail);
+        GGML_LOG_ERROR("%s: tensor: %s, shape: %ld %ld %ld %ld, size: %zu",
+                __func__, tensor->name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+                ggml_nbytes(tensor));
         GGML_ABORT("not enough space in the buffer");
     }
 }
```
Ok nvm, I think I see the problem. I will push a possible fix soon.
@slaren I've got the same bug now and only on large prompts too. I can test the fix tomorrow. |
I confirm that the fix worked, thank you @slaren. For the record, I am getting ~2.5 t/s with
I'm just trying to understand how the batching works with regard to
I tried tracing the GGML code to
Am I right in thinking that if
Or is there something else happening where the selected 8 experts for each token are getting swapped in/out of RAM/VRAM? I only ask as I can't seem to find much difference by changing
Also, was the problem I mentioned above:
which was solved by using
When evaluating batches of >= 32 tokens, all the weights are copied to VRAM and everything is evaluated on the GPU. For MoE weights, this means the entire tensor is copied with all the experts. That usually works well for the sizes of the models typically used locally, but I never did any testing on a model of the size of R1, so it may be worth reevaluating that. This can be easily adjusted in the
The problem you mentioned happens because pipeline parallelism is being enabled, and this causes memory for multiple copies of the weights to be reserved. This shouldn't happen when not fully offloading a model, since the pipeline parallelism implementation doesn't work at all if a part of the model needs to be evaluated on the CPU. But currently it only checks the value of
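To make the ">= 32 tokens" rule above concrete, here is a minimal, self-contained sketch of that kind of decision. The struct, function names, and where the threshold lives are illustrative assumptions for this example, not the actual ggml/llama.cpp code.

```cpp
// Sketch only: "copy CPU-resident weights to VRAM and run on the GPU when the
// batch is large enough to amortize the transfer". All names are made up.
#include <cstdint>
#include <cstdio>

struct op_info {
    int64_t n_tokens; // batch dimension of the matmul being scheduled
};

// Assumed threshold from the discussion above.
constexpr int64_t k_min_batch_size = 32;

static bool should_copy_weights_to_gpu(const op_info & op) {
    return op.n_tokens >= k_min_batch_size;
}

int main() {
    op_info decode { /*n_tokens=*/1 };   // token generation: weights stay on the CPU
    op_info prompt { /*n_tokens=*/512 }; // prompt processing: weights copied to VRAM
    std::printf("decode offloaded: %d, prompt offloaded: %d\n",
                should_copy_weights_to_gpu(decode), should_copy_weights_to_gpu(prompt));
    return 0;
}
```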
Is there any way we could extend this to specify RPC backends too? I've managed to weave this abomination:
which is working, but slower than if I had just kept it on one machine.
I wanted to keep everything on the main machine but offload the sets of 3 "_exps" tensors to the 4 RPC servers, but I get this:
I also tried this:
which almost worked, but it ended up putting the KV caches on the RPC backends as well as the "_exps" tensors. I could see this being very useful for somebody with 1 (or more) Mac M1/M2 Studio Ultra(s), as they wouldn't have as much latency as what I'm attempting.
It should work with RPC servers, as long as you pass
Using
Thanks. I think it probably does still make sense to copy to VRAM, and 32 probably works well here too. If we simplify by assuming the 8 experts are selected uniformly at random with replacement, and likewise across the 32 sets of selections, then 8 * 32 = 256 selections from 256 items, which has a nice closed form of 1 - 1/e ≈ 63% expected coverage (see: bootstrapping).
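For reference, the expected-coverage figure works out as follows (my restatement of the calculation above, not part of the original comment):

```latex
% Expected fraction of the 256 experts touched after 256 uniform
% selections with replacement (8 experts x 32 tokens):
\[
  \mathbb{E}[\text{coverage}]
    = 1 - \left(1 - \tfrac{1}{256}\right)^{256}
    \approx 1 - e^{-1}
    \approx 0.632
\]
```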
Ah thanks, so I'm not losing anything by using
Thanks - I figured it out!
Yeah, I think a separate option for the KV-cache might work best eventually.
Oh, I see I have the old
Ah yes, this happens because it builds the list of buffer types the first time
I've got it working:
I think this could be a super powerful command line option when mixed with RPC! Thanks for adding this! If anybody has a Mac Studio they want to test this on, then I can help craft the regexes to test it - I'm interested to see what sort of boost you could get without so many stages of latency.
I'm up for the testing - I have a Mac Studio M2 Ultra 192GB <---10Gbps---> 13700K + 192GB DDR5 + RTX 6000 Ada. Also, if it's helpful (seems to be?), I can get a Thunderbolt Gen4 eGPU case and plug my RTX 6000 Ada in there...
It didn't help me due to the latency between the parts all pushing the hidden state. I used 10Gbit Ethernet for all the machines, so I'm not sure upping to 40Gbit (or whatever Thunderbolt is) will make that much difference - I think the problem is latency rather than bandwidth for this part, sadly. Possibly using InfiniBand might help as IIRC it has lower latency, but I'm not sure. I think the eventual solution would be to have RPC use a better method of pipeline parallelism like DeepSpeed: it would definitely help batch processing, and mixed data and pipeline parallelism would remove some latency for those with multiple GPUs per machine like I have.
Just figured an eGPU won't help as Apple silicon cannot run CUDA...
Can you post speeds (with whatever configurations you tested)? Also, I'm not sure how much flash attention would impact speed, but it would shrink that compute buffer.
I think the RPC stuff is never really going to work properly until it can do async buffering: the way it is set up now, each stage in the pipeline stalls for every communication, and this adds the full latency. If it were async and buffered, the next job would start almost immediately with no wait, and you could probably optimise this even more by having the fastest devices at the start of the pipeline and the slowest at the end, getting almost no degradation from latency.
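A toy sketch of the async-buffering idea above (purely illustrative, not the llama.cpp RPC code): the compute of the next chunk overlaps with the send of the previous one, so the per-hop latency is mostly hidden. `compute_chunk` and `send_to_next_stage` are stand-ins I made up for this example.

```cpp
// Toy pipeline stage: kick off the send of a finished chunk asynchronously and
// keep computing the next chunk instead of blocking on the network round trip.
#include <cstdio>
#include <future>
#include <utility>
#include <vector>

static std::vector<float> compute_chunk(int i) {           // stand-in for one ubatch of work
    return std::vector<float>(1024, static_cast<float>(i));
}

static void send_to_next_stage(std::vector<float> chunk) { // stand-in for an RPC send
    std::printf("sent chunk of %zu floats\n", chunk.size());
}

int main() {
    std::future<void> in_flight;                            // previous send, possibly still running
    for (int i = 0; i < 4; ++i) {
        std::vector<float> chunk = compute_chunk(i);        // overlaps with the previous send
        if (in_flight.valid()) {
            in_flight.wait();                               // keep sends in order
        }
        in_flight = std::async(std::launch::async, send_to_next_stage, std::move(chunk));
    }
    if (in_flight.valid()) {
        in_flight.wait();
    }
    return 0;
}
```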
The GPU utilization could be an illusion. I'll try to get some numbers across different setups.
Many, many thanks, @slaren, for this PR. I really hope it gets merged. I have used this
What I have learned so far (I don't know if it is applicable to R1):
This is the order I found best improves token generation for Mixtral on my system. This is with non-RPC devices (CUDA), and may not correspond in the same way with other kinds of backends. Many thanks again for this PR.
If you don't mind, can you post some more info:
Adds command line parameter `--override-tensor` (`-ot`) that allows changing the buffer type where a model tensor is allocated. This gives the user fine-grained control over which tensors are offloaded to each device.

How is this useful: for example, to force the experts in MoE models to stay on the CPU while offloading the rest to the GPU, you could use `-ngl 99 -ot exps=CPU`. This may allow more efficient offloading schemes.

The syntax is `<tensor name pattern>=<buffer type>`. Currently the pattern is just a string search (edit: this is no longer the case, it is a C++ regex search), i.e. any tensor whose name contains the characters in `<tensor name pattern>` will be matched and loaded into the given buffer type. Multiple overrides can be given by separating them with commas, or by passing the `-ot` option multiple times. To see which tensors are being matched, enable debugging output with `-v`.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

Edit: added regex support. For example, to keep the experts of layers 20-99 on the CPU you could use `-ot "[2-9][0-9]\.ffn_.*_exps\.=CPU"`.