llama : add support for batched inference #2813
Comments
Would it be possible to design this API in such a way to support the following use case? In some cases, it's useful to do constrained evaluation of logits based on a union of possible text values, then pick the candidate with the highest sum of logits. E.g. a template (using MS
To definitively make the best choice, we'd need to calculate the probability of all 3 token sequences. It's easy if all the choices map to a single token, but with multiple tokens we'd need not just parallel generation but parallel logit evaluation of multiple possible paths. If we go greedy, we might get suboptimal results in cases where multiple choices start with the same first token. Apologies in advance if I misunderstood the feature. 😅
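To make the scoring rule concrete, here is a minimal sketch, assuming per-token log-probabilities for each candidate continuation are already available (producing them is exactly the parallel logit evaluation described above); the function name and data shapes are illustrative only:

```cpp
#include <cstddef>
#include <limits>
#include <numeric>
#include <vector>

// Each candidate is a sequence of per-token log-probabilities:
// log P(t1 | prompt), log P(t2 | prompt, t1), ...
// The candidate with the highest sum of log-probs is the most likely completion.
static size_t pick_best_candidate(const std::vector<std::vector<float>> & candidate_logprobs) {
    size_t best       = 0;
    float  best_score = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < candidate_logprobs.size(); ++i) {
        const auto & lp    = candidate_logprobs[i];
        const float  score = std::accumulate(lp.begin(), lp.end(), 0.0f); // sum over the whole path
        if (score > best_score) {
            best_score = score;
            best       = i;
        }
    }
    return best;
}
```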
Is there any issue / improvement with respect to supporting multiple requests in the server implementation similarly?
@AdityaSher The server serving multiple requests is tangentially related to this functionality. The problem there would be to have logic that batches the different requests together - but this is high-level logic not related to the `llama.cpp` library.

@spion This sounds related to the new beam-search functionality, although I'm not sure I fully understand the use case that you have in mind.
@ggerganov Sort of, but it's slightly more constrained, I think. There is a set number of well-defined potential paths that need to be evaluated in a manner similar to how beam search does it, but with the constraint that most follow-up tokens are impossible and don't need to be considered at all (as they would not fit the constraints of the specified union).
I guess you can combine beam search with grammar sampling. cc @ejones and @mattpulver for some insights
@ggerganov You can use shared memory/anonymous pages and mmap to map the same physical page to multiple virtual pages, allowing you to reuse the common prompt context without copying it. This only works on the CPU side, of course, and you can only share full pages of context like this; any cached values that end up in the partial page would need to be either copied manually or recalculated. So the workflow would go like this:

Disclaimer: I have not looked too deeply into the KV cache code; I just assumed it's a large contiguous segment.

References:
https://nullprogram.com/blog/2016/04/10/
https://man7.org/linux/man-pages/man2/mmap.2.html
https://groups.google.com/g/comp.os.linux.development.system/c/Prx7ExCzsv4
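For illustration, a minimal, self-contained sketch of the page-sharing mechanism described above (Linux-specific, and not the actual KV-cache layout in llama.cpp): the same memfd-backed pages are mapped at two virtual addresses, so the shared prefix exists only once physically.

```cpp
#include <sys/mman.h>
#include <unistd.h>
#include <cstdio>
#include <cstring>

int main() {
    const long   page = sysconf(_SC_PAGESIZE);
    const size_t size = 4 * (size_t) page;      // stand-in for the shared prompt-prefix KV data

    int fd = memfd_create("kv_prefix", 0);      // anonymous, file-backed pages
    if (fd < 0 || ftruncate(fd, (off_t) size) != 0) { perror("memfd"); return 1; }

    // Map the same physical pages at two virtual addresses (one per "sequence").
    char * seq0 = (char *) mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    char * seq1 = (char *) mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (seq0 == MAP_FAILED || seq1 == MAP_FAILED) { perror("mmap"); return 1; }

    strcpy(seq0, "shared prefix");              // written once through one mapping ...
    printf("%s\n", seq1);                       // ... and visible through the other, no copy made

    munmap(seq0, size); munmap(seq1, size); close(fd);
    return 0;
}
```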
@ggerganov Grammar sampling should work with any stream of tokens. Based on first impressions from the beam search code, I wonder if each beam could maintain a grammar (accepting tokens as they're added to a beam) and apply the grammar penalties in the top-k selection? cc @mattpulver
Currently:

Lines 3434 to 3440 in 53885d7

If each beam is to hold its own grammar state, then presumably they can share a

Lines 4334 to 4337 in 53885d7

Then we can abstract out the call to get the

Line 4457 in 53885d7

and replace it with a call to a user-supplied callback function that will determine the

Does that sound satisfactory?
Yeah, that sounds about right!
How big in terms of memory is
There was an earlier experiment w/ SQL grammar in which the

If/when we do decide to implement that optimization, consider replacing

with

Then the rules can be lightly/freely copied after instantiation.
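A rough sketch of the kind of change being suggested, using illustrative stand-in types rather than the actual llama.cpp definitions: keeping the rule set behind a shared_ptr means each beam can copy its grammar state by bumping a reference count instead of deep-copying every rule.

```cpp
#include <memory>
#include <vector>

// Stand-ins for the real grammar types (illustrative only).
struct grammar_element { int type; unsigned int value; };
using  grammar_rules = std::vector<std::vector<grammar_element>>;

struct beam_grammar_state {
    // Immutable rule set shared by all beams: copying this state only
    // increments a refcount instead of duplicating the rules.
    std::shared_ptr<const grammar_rules> rules;

    // Per-beam parse position; this part does have to be copied per beam.
    std::vector<std::vector<const grammar_element *>> stacks;
};

int main() {
    auto rules = std::make_shared<const grammar_rules>(grammar_rules{
        { {1, 'a'}, {1, 'b'} },   // toy rule
    });

    beam_grammar_state beam0{rules, {}};
    beam_grammar_state beam1 = beam0;   // cheap copy: the rules are shared
    (void) beam1;
    return 0;
}
```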
Moving the discussion about the advanced beam search to #2923.
At a glance, CUDA seems to have low-level virtual memory APIs now. A GPU-based draft model handing batches of 100 over to a CPU model to pick from - is that where this is going, potentially?
I have pored over the code for a while now, but as someone who doesn't write C every day, apologies if this is a dumb question. @ggerganov I have been trying to wrap my head around this point you made earlier:

Let's say the server has 100 outstanding requests. The ideal implementation of batching would batch 16 requests of similar length into one llama.cpp request. The goal of doing this would be to perform multiple parallel inferences in each forward pass, amortizing the cost of retrieving the weights from RAM across multiple requests (even though the tokens are different). It seems to me this PR goes 90% of the way there by enabling batched inference for the same shared prompt, but then allowing the contexts to diverge once tokens are generated (and we are already talking about separate inference states per batch anyway). In order to support proper cross-request batching in the server, wouldn't the llama.cpp API need to support at least a vector of KV caches (or a 2D one)?
Done via #3228
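For anyone landing here later, a rough sketch of what cross-request batching looks like with the sequence-id based batch API that grew out of that work; the field and function names follow the present-day llama.h, and the interface has evolved since #3228, so treat this as illustrative and check the current header:

```cpp
#include <vector>
#include "llama.h"

// Two independent requests decoded in one llama_decode() call: each token is
// tagged with a sequence id, so a single unified KV cache serves both requests.
static void decode_two_requests(llama_context * ctx,
                                const std::vector<llama_token> & req0,
                                const std::vector<llama_token> & req1) {
    llama_batch batch = llama_batch_init((int32_t) (req0.size() + req1.size()),
                                         /*embd =*/ 0, /*n_seq_max =*/ 1);

    auto push = [&](const std::vector<llama_token> & toks, llama_seq_id seq) {
        for (size_t i = 0; i < toks.size(); ++i) {
            const int j = batch.n_tokens++;
            batch.token   [j]    = toks[i];
            batch.pos     [j]    = (llama_pos) i;
            batch.n_seq_id[j]    = 1;
            batch.seq_id  [j][0] = seq;
            batch.logits  [j]    = (i == toks.size() - 1);  // request logits only for the last token
        }
    };

    push(req0, 0);
    push(req1, 1);

    if (llama_decode(ctx, batch) != 0) {
        // handle failure (e.g. not enough KV cache space)
    }

    llama_batch_free(batch);
}
```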
Does llama-bench support this for benchmarking now?

It doesn't, but @ggerganov is working on another tool for that in #3545.

How does one use this batched inference?

@segmond Refer to the server example.
We want to be able to generate multiple sequences sharing the same context (a.k.a. prompt) in parallel.
Demonstrated in one of the examples by @xaedes:
llama.cpp/examples/baby-llama/baby-llama.cpp, lines 785 to 794 in eff86d4

Should become part of the official `llama.cpp` API.

ref: #2789
Implementation details
Regarding the API for the batched inference functionality, one way is to add a function:
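Going by the `llama_context_set_parallel()` name mentioned further down, such a function might look roughly like this (the exact signature is a guess, not something specified in the issue):

```cpp
// Hypothetical signature, inferred from the name used below - not a final API.
// Sets the number of sequences to evaluate in parallel and reallocates the
// kv_self cache so that it can hold n_batches copies of the context.
LLAMA_API void llama_context_set_parallel(struct llama_context * ctx, int n_batches);
```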
This would reallocate the `kv_self` cache to fit `n_batches` batches. During `llama_eval`, we do what we normally do, with the extra step of batching the input as demonstrated in the example. We can probably avoid changing the `eval` API by adding the implicit assumption that `tokens` will contain the tokens for `n_batches` batches:

llama.cpp/llama.h, lines 315 to 320 in dd0dc36

In the end, we just need to update the API for accessing the logits of all the batches, or once again, without changing the API, have an implicit assumption that the results will be for `n_batches` batches:

llama.cpp/llama.h, line 341 in dd0dc36

So on first thought, we would just need a single new function added to `llama.h`: `llama_context_set_parallel()`. I think this should be enough, but I could be missing something.
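To make the implicit convention concrete, a caller might look roughly like this; `llama_context_set_parallel()` is the hypothetical function sketched above, and `llama_eval()` is used with its signature from that era (tokens, n_tokens, n_past, n_threads):

```cpp
#include <vector>
#include "llama.h"

// Sketch of the implicit convention: `tokens` holds n_batches runs of tokens
// laid out one sequence after another, and llama_get_logits() is then assumed
// to return n_batches blocks of logits in the same order.
static void eval_parallel(llama_context * ctx,
                          const std::vector<std::vector<llama_token>> & seqs,   // one token run per sequence
                          int n_past, int n_threads) {
    const int n_batches = (int) seqs.size();

    llama_context_set_parallel(ctx, n_batches);   // hypothetical, see above

    std::vector<llama_token> tokens;
    for (const auto & s : seqs) {
        tokens.insert(tokens.end(), s.begin(), s.end());   // assume equal lengths for simplicity
    }

    llama_eval(ctx, tokens.data(), (int) tokens.size(), n_past, n_threads);

    // Under the same implicit assumption, the logits for sequence b would start
    // at llama_get_logits(ctx) + b * n_vocab in the returned buffer.
    const float * logits = llama_get_logits(ctx);
    (void) logits;
}
```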
One optimization to consider is whether we can avoid having separate KV caches for the common prefix of the parallel runs. The straightforward implementation would create a copy of the prefix for each batch, while in theory we need just one. Not sure how complicated it would be to handle this. It might require implementing Paged Attention, which is probably a job for another time.