llama : add support for batched inference #2813

Closed
ggerganov opened this issue Aug 26, 2023 · 19 comments
Labels
performance Speed related topics

Comments

@ggerganov
Owner

ggerganov commented Aug 26, 2023

We want to be able to generate multiple sequences sharing the same context (a.k.a. prompt) in parallel.

Demonstrated in one of the examples by @xaedes :

struct ggml_tensor * forward_batch(
        struct llama_model    * model,
        struct llama_kv_cache * cache,
        struct ggml_context   * ctx0,
        struct ggml_cgraph    * gf,
        struct ggml_tensor    * tokens_input,
        const int               n_tokens,
        const int               n_past,
        const int               n_batch) {

Should become part of the official llama.cpp API

ref: #2789

Implementation details

Regarding the API for the batched inference functionality, one way is to add a function:

// TODO: better name?
void llama_context_set_parallel(struct llama_context * ctx, int n_batches);

This would reallocate the kv_self cache to fit n_batches batches.

During llama_eval, we do what we normally do, with the extra step of batching the input as demonstrated in the example. We can probably avoid changing the eval API by adding the implicit assumption that tokens will contain the tokens for n_batches batches:

llama.cpp/llama.h

Lines 315 to 320 in dd0dc36

LLAMA_API int llama_eval(
        struct llama_context * ctx,
        const llama_token    * tokens,
        int                    n_tokens,
        int                    n_past,
        int                    n_threads);

In the end, we just need to update the API for accessing the logits of all the batches, or once again - without changing the API, have an implicit assumption that the results will be for n_batches batches:

LLAMA_API float * llama_get_logits(struct llama_context * ctx);


So on first thought, we would just need a single new function added to llama.h - llama_context_set_parallel().
I think this should be enough, but I could be missing something
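
A minimal sketch of the implicit layout described above, not the actual llama.cpp implementation. It assumes that `tokens` holds the tokens of all `n_batches` batches back to back and that the logits come back as one row of `n_vocab` values per batch. The helpers `batch_tokens`, `batch_logits`, and `argmax` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed layout: tokens of batch b occupy [b * n_tokens, (b + 1) * n_tokens).
inline const int * batch_tokens(const std::vector<int> & tokens, int n_tokens, int i_batch) {
    assert(static_cast<std::size_t>(i_batch + 1) * n_tokens <= tokens.size());
    return tokens.data() + static_cast<std::size_t>(i_batch) * n_tokens;
}

// Assumed layout: one row of n_vocab logits per batch, rows stored contiguously.
inline const float * batch_logits(const float * logits, int n_vocab, int i_batch) {
    return logits + static_cast<std::size_t>(i_batch) * n_vocab;
}

// Index of the largest logit in one row (greedy pick for a single batch).
inline int argmax(const float * row, int n_vocab) {
    int best = 0;
    for (int i = 1; i < n_vocab; ++i) {
        if (row[i] > row[best]) best = i;
    }
    return best;
}
```

With this convention a caller would read the result for batch `b` via `batch_logits(llama_get_logits(ctx), n_vocab, b)`, which is why no new accessor would be strictly required.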

One optimization to consider is if we can avoid having separate KV caches for the common prefix of the parallel runs. The straightforward implementation would create a copy of this for each batch, while in theory we need just one. Not sure how complicated it would be to handle this. Might need to implement Paged Attention, which is probably a job for another time
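
To make the cost of the straightforward approach concrete, here is a rough back-of-the-envelope sketch. It counts KV cache cells (one per token position per sequence) and ignores per-layer and per-head factors, which scale both variants equally. The function names and the example numbers are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>

// Every batch keeps its own copy of the common prefix.
constexpr std::size_t kv_cells_copied(std::size_t n_prefix, std::size_t n_gen, std::size_t n_batches) {
    return n_batches * (n_prefix + n_gen);
}

// The common prefix is stored once and only the generated parts are per batch.
constexpr std::size_t kv_cells_shared(std::size_t n_prefix, std::size_t n_gen, std::size_t n_batches) {
    return n_prefix + n_batches * n_gen;
}

// Example: a 512-token prompt, 128 generated tokens, 8 parallel sequences.
static_assert(kv_cells_copied(512, 128, 8) == 5120, "8 * (512 + 128)");
static_assert(kv_cells_shared(512, 128, 8) == 1536, "512 + 8 * 128");
```

The longer the shared prompt relative to the generated text, the larger the gap between the two.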

@ggerganov ggerganov added the performance Speed related topics label Aug 26, 2023
@ggerganov ggerganov moved this to Todo in ggml : roadmap Aug 26, 2023
@spion

spion commented Aug 26, 2023

Would it be possible to design this API in such a way as to support the following use case?

In some cases, it's useful to do constrained evaluation of logits based on a union of possible text values, then pick the candidate whose summed log-probabilities (i.e. the product of its token probabilities) is highest, which gives the most probable outcome overall.

E.g. template (using MS guidance)

{{#select 'armor'}}leather{{or}}chainmail{{or}}plate{{/select}}

To definitely make the best choice, we'd need to calculate the probability of all 3 token sequences. It's easy if all the choices map to a single token, but with multiple tokens we'd need not just parallel generation but parallel logit evaluation of multiple possible paths.

If we go greedy, we might get suboptimal results in cases where multiple choices start with the same logit.

Apologies if I misunderstood the feature in advance. 😅
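
The selection described in this comment can be sketched as follows. This is an illustration under stated assumptions rather than an existing llama.cpp or guidance API: `logprob_fn` is a hypothetical scorer that returns log P(next | prefix), which in practice would come from a softmax over the model's logits, and `pick_best_choice` is a made-up helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

using token = int;

// Hypothetical scorer: log-probability of `next` given the tokens in `prefix`.
using logprob_fn = std::function<float(const std::vector<token> & prefix, token next)>;

// Returns the index of the choice with the highest summed log-probability,
// or -1 if there are no choices (or every choice scores -infinity).
inline int pick_best_choice(const std::vector<std::vector<token>> & choices,
                            const std::vector<token> & context,
                            const logprob_fn & logprob) {
    int   best       = -1;
    float best_score = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < choices.size(); ++i) {
        std::vector<token> prefix = context;
        float score = 0.0f;
        for (token t : choices[i]) {
            score += logprob(prefix, t);  // sum of logs == log of the product
            prefix.push_back(t);
        }
        if (score > best_score) {
            best_score = score;
            best       = static_cast<int>(i);
        }
    }
    return best;
}
```

Each choice is scored in full before comparing, so a choice with a weak first token but a very likely continuation can still win, which is exactly the case a greedy pick gets wrong. Batched evaluation would let the per-choice loops run in a single forward pass instead of sequentially.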

@AdityaSher

Is there any issue / improvement with respect to supporting multiple requests in the server implementation similarly?

@ggerganov
Owner Author

@AdityaSher The server serving multiple requests is tangentially related to this functionality. The problem there would be to have a logic that batches the different requests together - but this is high-level logic not related to the llama.cpp library

@spion This sounds related to the new beam-search functionality, although I'm not sure I fully understand the use case that you have in mind.

@spion

spion commented Aug 28, 2023

@ggerganov Sort of, but it's slightly more constrained, I think. There is a set number of well-defined potential paths that need to be evaluated in a manner similar to how beam search does it, but with the constraint that most follow-up tokens are impossible and don't need to be considered at all (as they would not fit the constraints of the specified union).

@ggerganov
Owner Author

@spion

I guess you can combine beam search with grammar sampling.
Interesting if we can make this work

cc @ejones and @mattpulver for some insights

@AutonomicPerfectionist
Contributor

One optimization to consider is if we can avoid having separate KV caches for the common prefix of the parallel runs.

@ggerganov You can use shared memory/anonymous pages and mmap to map the same physical page to multiple virtual pages, allowing you to reuse the common prompt context without copying it. Only works for CPU side of course, and you can only share full pages of context like this, any cached values that end up in the partial page would need to be either copied manually or recalculated.

So the workflow would go like this:

  1. Initially create a shared memory segment and ftruncate to the calculated cache size
  2. mmap it for prompt processing
  3. When it's time to begin the batched run, calculate how many full pages the prompt context consumes so you know how big you want the shared map
  4. Map an anonymous segment with the calculated cache size so the kernel finds a place in the virtual memory space with enough contiguous, free addresses
  5. Memory map the prompt context shared segment over the start of the new anonymous segment with MAP_FIXED to force clobber the overlapping pages

Disclaimer: I have not looked too deep into the KV cache code, I just assumed it's a large contiguous segment
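
The five steps above can be sketched roughly as follows. This is a Linux-only illustration under several assumptions: `memfd_create` stands in for the shared memory segment (the comment does not name a specific call), the "cache" is just a flat byte buffer rather than the real KV cache, and `map_view` and `demo_shared_prefix` are hypothetical names. The partial trailing page is copied by hand, as the comment suggests.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <sys/mman.h>
#include <unistd.h>

// Steps 4-5: reserve `cache_size` bytes of contiguous address space, then overlay
// the first `shared_size` bytes (a multiple of the page size) with the shared segment.
static unsigned char * map_view(int fd, std::size_t shared_size, std::size_t cache_size) {
    void * base = mmap(nullptr, cache_size, PROT_READ | PROT_WRITE,
                       MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (base == MAP_FAILED) return nullptr;
    if (shared_size > 0 &&
        mmap(base, shared_size, PROT_READ, MAP_SHARED | MAP_FIXED, fd, 0) == MAP_FAILED) {
        munmap(base, cache_size);
        return nullptr;
    }
    return static_cast<unsigned char *>(base);
}

static bool demo_shared_prefix() {
    const std::size_t page         = static_cast<std::size_t>(sysconf(_SC_PAGESIZE));
    const std::size_t cache_size   = 8 * page;        // illustrative cache size
    const std::size_t prompt_bytes = 2 * page + 100;  // 2 full pages + a partial one

    // Steps 1-2: create the shared segment, size it, and map it for prompt processing.
    const int fd = memfd_create("kv_prompt", 0);
    if (fd < 0) return false;
    if (ftruncate(fd, static_cast<off_t>(cache_size)) != 0) { close(fd); return false; }
    void * pm = mmap(nullptr, cache_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (pm == MAP_FAILED) { close(fd); return false; }
    unsigned char * prompt = static_cast<unsigned char *>(pm);
    std::memset(prompt, 0xAB, prompt_bytes);          // stand-in for prompt processing

    // Step 3: only whole pages of prompt context can be shared.
    const std::size_t shared_size = (prompt_bytes / page) * page;
    const std::size_t tail        = prompt_bytes - shared_size;

    unsigned char * a = map_view(fd, shared_size, cache_size);
    unsigned char * b = map_view(fd, shared_size, cache_size);
    bool ok = a != nullptr && b != nullptr;
    if (ok) {
        // The partial page is private to each view, so copy it manually.
        std::memcpy(a + shared_size, prompt + shared_size, tail);
        std::memcpy(b + shared_size, prompt + shared_size, tail);
        a[prompt_bytes] = 1;                          // the two runs diverge here
        b[prompt_bytes] = 2;
        prompt[0] = 0x11;                             // visible through both views
        ok = a[0] == 0x11 && b[0] == 0x11 &&
             a[shared_size] == 0xAB && b[shared_size] == 0xAB &&
             a[prompt_bytes] == 1 && b[prompt_bytes] == 2;
    }
    if (a) munmap(a, cache_size);
    if (b) munmap(b, cache_size);
    munmap(prompt, cache_size);
    close(fd);
    return ok;
}
```

The shared prefix is mapped read-only in each view so that a stray write from one run cannot corrupt the prompt context seen by the others.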

Reference:

https://nullprogram.com/blog/2016/04/10/

https://man7.org/linux/man-pages/man2/mmap.2.html

https://groups.google.com/g/comp.os.linux.development.system/c/Prx7ExCzsv4

@ejones
Collaborator

ejones commented Aug 29, 2023

@ggerganov grammar sampling should work with any stream of tokens. Based on first impressions from the beam search code I wonder if each beam could maintain a grammar (accepting tokens as they're added to a beam) and apply the grammar penalties in the top-k selection? cc @mattpulver

@mattpulver
Contributor

Currently llama_grammar has 3 member variables:

llama.cpp/llama.cpp

Lines 3434 to 3440 in 53885d7

struct llama_grammar {
    const std::vector<std::vector<llama_grammar_element>>   rules;
    std::vector<std::vector<const llama_grammar_element *>> stacks;
    // buffer for partially generated UTF-8 sequence from accepted tokens
    llama_partial_utf8                                      partial_utf8;
};

If each beam is to hold its own grammar state, then presumably they can share a const reference to rules but must have their own copies of stacks and partial_utf8. These will be copied as beams branch. These can be added to

llama.cpp/llama.cpp

Lines 4334 to 4337 in 53885d7

struct llama_beam {
    std::vector<llama_token> tokens;
    float p;   // Cumulative beam probability (renormalized relative to all beams)
    bool eob;  // Initialize end-of-beam to false. Callback sets this to true.

Then we can abstract out the call to get the next_tokens for each beam, at or around

std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);

and replace it with a call to a user-supplied callback function that will determine the next_tokens for each beam. The callback will have access to the above llama_grammar variables, and will only need to select the top n_beams tokens for each beam.

Does that sound satisfactory?
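
A rough sketch of what such a callback could look like, using stand-in types rather than the real llama.cpp ones: `token_data` mimics `llama_token_data`, `beam_state` mimics a beam that carries its own grammar state, and `next_tokens_fn` and `make_filtered_top_k` are hypothetical names. The `allowed` predicate plays the role of the grammar check.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

struct token_data { int id; float logit; };  // stand-in for llama_token_data

struct beam_state {                          // stand-in for a beam
    std::vector<int> tokens;
    // per-beam grammar state (stacks, partial_utf8) would live here
};

// Hypothetical callback type: choose up to n_beams next tokens for one beam.
using next_tokens_fn = std::function<std::vector<token_data>(
    const beam_state & beam, std::vector<token_data> candidates, std::size_t n_beams)>;

// Example callback: drop tokens the predicate rejects, then keep the top n_beams by logit.
inline next_tokens_fn make_filtered_top_k(std::function<bool(const beam_state &, int)> allowed) {
    return [allowed](const beam_state & beam, std::vector<token_data> cands, std::size_t n_beams) {
        cands.erase(std::remove_if(cands.begin(), cands.end(),
                                   [&](const token_data & t) { return !allowed(beam, t.id); }),
                    cands.end());
        const std::size_t k = std::min(n_beams, cands.size());
        std::partial_sort(cands.begin(), cands.begin() + static_cast<std::ptrdiff_t>(k), cands.end(),
                          [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
        cands.resize(k);
        return cands;
    };
}
```

Filtering before the top-k step matters: taking the top `n_beams` first and filtering afterwards could leave a beam with no valid continuation even though one exists further down the ranking.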

@ejones
Collaborator

ejones commented Aug 29, 2023

Yeah, that sounds about right! rules can definitely be shared. And then as the beams are created/tokens added to beams, the grammar state needs to be updated with (the equivalent of) llama_grammar_accept_token.

@ggerganov
Owner Author

If each beam is to hold its own grammar state, then presumably they can share a const reference to rules but must have their own copies of stacks and partial_utf8.

How big in terms of memory is rules? If not too big, I would rather just have a copy of the entire grammar in each beam to make things simpler.

@mattpulver
Contributor

mattpulver commented Aug 30, 2023

There was an earlier experiment w/ SQL grammar in which the rules got somewhat large. However to your point, sharing rules is an optimization we can keep separate from this PR either way. Thus the first iteration of this PR can continue to use the existing struct llama_grammar without defining a light-weight llama_grammar_view, etc.

If/when we do decide to implement that optimization, consider replacing

const std::vector<std::vector<llama_grammar_element>> rules;

with

using llama_grammar_rules = std::vector<std::vector<llama_grammar_element>>;
std::shared_ptr<const llama_grammar_rules> rules;

Then the rules can be lightly/freely copied after instantiation.
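
A small self-contained sketch of that idea, with stand-in types: `grammar_element` replaces `llama_grammar_element`, and `grammar_state` and `branch` are hypothetical names. Copying the state copies the per-beam stacks but only bumps a reference count for the rules.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct grammar_element { int type; unsigned value; };  // stand-in for llama_grammar_element

using grammar_rules = std::vector<std::vector<grammar_element>>;

struct grammar_state {
    std::shared_ptr<const grammar_rules>               rules;   // shared, immutable
    std::vector<std::vector<const grammar_element *>>  stacks;  // owned by each beam
};

// Branching a beam: stacks are deep-copied, rules are shared by reference count.
inline grammar_state branch(const grammar_state & parent) {
    return parent;
}
```

Because the stack entries point into the shared rules rather than into a per-copy vector, the pointers stay valid in every copy for as long as any beam holds the `shared_ptr`.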

@ggerganov
Owner Author

Moving the discussion about the advanced beam search in #2923

@dougmoscrop

Only works for CPU side of course,

At a glance, CUDA seems to have low level virtual memory APIs now?

A GPU-based draft model handing batches of 100 over to a cpu model to pick from, is that where this is going potentially?

@kiratp

kiratp commented Sep 4, 2023

I have pored over the code for a while now, but as someone who doesn't write C every day, apologies if this is a dumb question.

@ggerganov I have been trying to wrap my head around this point you made earlier:

One optimization to consider is if we can avoid having separate KV caches for the common prefix of the parallel runs. The straightforward implementation would create a copy of this for each batch, while in theory we need just one. Not sure how complicated it would be to handle this. Might need to implement Paged Attention, which is probably a job for another time

The server serving multiple requests is tangentially related to this functionality. The problem there would be to have a logic that batches the different requests together - but this is high-level logic not related to the llama.cpp library

Let's say the server has 100 requests outstanding.
Let's say batch size == 16

The ideal implementation of batching would batch 16 requests of similar length into one call to llama.cpp eval(), i.e. continuous batching like vLLM.ai and HF text inference do. (https://github.com/huggingface/text-generation-inference/tree/main/router)

The goal of doing this would be to perform multiple parallel inferences for each forward batch pass, amortizing the cost of retrieving things from RAM across multiple batches (even though tokens are different). It seems to me this PR is going 90% of the way there by enabling batched inference for the same shared prompt but then allowing the context to diverge once tokens are generated (and we are already talking about separate inference states per batch anyway).

In order to support proper cross-request batching in the server, wouldn't the llama.cpp API need to support at least a vector of kv caches (or 2D?)?
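
For illustration, the grouping step in the scenario above (100 outstanding requests, batch size 16) could look like the sketch below. It is deliberately simpler than continuous batching as done by the projects mentioned: it groups once by prompt length and never refills a batch mid-flight. The `request` struct and `group_requests` are hypothetical names, not part of llama.cpp or its server.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct request { int id; std::size_t n_tokens; };  // hypothetical pending request

// Sort by prompt length so each batch holds requests of similar size,
// then cut the sorted list into chunks of at most `batch_size`.
inline std::vector<std::vector<request>> group_requests(std::vector<request> pending,
                                                        std::size_t batch_size) {
    assert(batch_size > 0);
    std::stable_sort(pending.begin(), pending.end(),
                     [](const request & a, const request & b) { return a.n_tokens < b.n_tokens; });
    std::vector<std::vector<request>> batches;
    for (std::size_t i = 0; i < pending.size(); i += batch_size) {
        const std::size_t end = std::min(i + batch_size, pending.size());
        batches.emplace_back(pending.begin() + static_cast<std::ptrdiff_t>(i),
                             pending.begin() + static_cast<std::ptrdiff_t>(end));
    }
    return batches;
}
```

With 100 requests and a batch size of 16 this yields seven batches, the last one holding only four requests. Each batch would still need its own per-sequence KV state, which is the point of the question above.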

@ggerganov
Owner Author

Done via #3228

@ggerganov ggerganov moved this from Todo to Done in ggml : roadmap Sep 28, 2023
@jammm
Contributor

jammm commented Oct 11, 2023

Does llama-bench support this for benchmarking now?

@slaren
Collaborator

slaren commented Oct 11, 2023

It doesn't, but @ggerganov is working on another tool for that in #3545.

@segmond

segmond commented Jan 23, 2024

How does one use this batched inference?

@PenutChen

@segmond refer to server example --parallel and --cont-batching arguments.
