Add llama_beam_search(). #2267
Conversation
This is just going to use up all memory, especially on the GPU. We have new models now like LLaMA2 with 4096 token context and some others with even 8192. The KV cache could be gigabytes. Is it not possible to just save the tokens+n_past info in the beams?
Thank you, that's the kind of hint/reassurance I was looking for. That does seem to work, and avoids making any copies of the llama_context. I'm noticing that as the beams grow, so does the time to get the next token for each beam. Since the beams tend to converge, that is, share a common prefix vector of token_ids, the common trunk can be "ingested" into the shared llama_context. I'll try that next, and if that works I'll move this PR out of draft mode.
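For illustration, a minimal sketch (hypothetical helper, not code from this PR) of identifying that shared trunk: the longest token prefix common to every beam, which is what would be evaluated once in the shared llama_context.

#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = std::int32_t; // stand-in for llama.h's token type

// Length of the longest token prefix shared by every beam.
static std::size_t common_prefix_length(const std::vector<std::vector<llama_token>> & beams) {
    if (beams.empty()) {
        return 0;
    }
    std::size_t len = beams[0].size();
    for (const std::vector<llama_token> & beam : beams) {
        std::size_t i = 0;
        while (i < len && i < beam.size() && beam[i] == beams[0][i]) {
            ++i;
        }
        len = i;
    }
    return len;
}

Tokens up to that length could be evaluated once against the shared context and then dropped from each beam's own token vector.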
Alternatively, we can save ... In any case, this change is welcome.
Thanks for the tip @ggerganov. I tried copying the ... This wasn't sufficient to satisfy the memory requirements, e.g. ... Does ... What are your thoughts on this?
For sure no. Not sure about all the details. The closest thing we currently have is the beam-search decoder in ... Though it does not solve the problem of having ... I think there should be a "common" KV cache that corresponds to ...
A separate
I could use assistance on the following items, in order of highest priority to least:
For the most part, the ... Also, I used the public llama_eval API and not batch eval, because:
Since all beams share the same prefix, it sounds like sharing the prefix of the cache with some sort of "split cache" is possible, but it's probably hard to satisfy all backends. Alternatively, use the same cache and then sequentially eval each beam, lol. The speed would probably be awful though.
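As a rough sketch of that "sequentially eval each beam" fallback, assuming the llama_eval()/llama_get_logits() API of that era; eval_beams_sequentially, beam_suffixes and n_past_common are hypothetical names, not from the PR:

#include <cstdio>
#include <vector>
#include "llama.h"

// Evaluate every beam's divergent suffix against the one shared context, one
// beam at a time. Each step re-pays the cost of the whole suffix, which is why
// the speed is expected to be poor.
static bool eval_beams_sequentially(llama_context * ctx,
                                    const std::vector<std::vector<llama_token>> & beam_suffixes,
                                    int n_past_common, int n_threads) {
    for (const std::vector<llama_token> & suffix : beam_suffixes) {
        if (llama_eval(ctx, suffix.data(), (int) suffix.size(), n_past_common, n_threads) != 0) {
            std::fprintf(stderr, "%s: llama_eval() failed\n", __func__);
            return false;
        }
        const float * logits = llama_get_logits(ctx); // logits for this beam's last token
        // ... pick the top-k candidate continuations for this beam from `logits` ...
        (void) logits;
    }
    return true;
}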
This is getting closer. It seems the main bottleneck involves calling ... This can be run/tested using ... After integrating this into ... Thanks @bullno1 for explaining the logic re: ...
Warning: a lot of brain dump, might not be coherent.

On efficiency

I feel like there is a lot of cache invalidation with how the common prefix keeps getting re-evaluated: https://github.com/ggerganov/llama.cpp/pull/2267/files#diff-150dc86746a90bad4fc2c3334aeb9b5887b3adad3cc1459446717638605348efR3071

If I understand this correctly: let's say we have n beams; there should only be at most n KV caches at a time, because after expansion we limit the list to n beams anyway. If we utilize the cache of the original context, it's only n - 1 extra caches to be created. This sounds possible with the current public API. If n_beam = 2, it has the same memory overhead as CFG anyway.

On API design

I think it should return a span of tokens instead of text:

struct llama_beam {
size_t length;
llama_token* tokens;
};

We can add a ... Also, since this runs until completion or until it reaches a limit, I have no idea how to even "stream" the result. Talking about chat, we often need an "antiprompt" stop condition, which would take a bit more work to support. We would need something like a predicate on the beam search:

typedef bool (*llama_beam_predicate_fn_t)(const struct llama_beam * beam, void * userdata);

This is actually general enough for both an anti-prompt and/or a length limit. Putting it all together, I think the API should be:

struct llama_beam {
size_t length;
const llama_token* tokens;
};
typedef bool (*llama_beam_predicate_fn_t)(const struct llama_beam * beam, void * userdata);
struct llama_beam_stop_condition {
llama_beam_predicate_fn_t fn;
void * userdata;
};
LLAMA_API const struct llama_beam * llama_beam_search(
struct llama_context * ctx,
int n_past,
int n_beams,
struct llama_beam_stop_condition stop_condition,
int n_threads
);
// Stop after a number of tokens has been generated
// userdata is a uintptr_t which is the number of tokens
LLAMA_API bool llama_beam_stop_at_n_tokens(const struct llama_beam * beam, void * userdata);
// Stop after a token is encountered such as eos
// userdata is a uintptr_t which is the token id
LLAMA_API bool llama_beam_stop_at_token(const struct llama_beam * beam, void * userdata);
// Stop at a suffix string for anti-prompt
// userdata is a char* pointing to a null-terminated string
LLAMA_API bool llama_beam_stop_at_suffix(const struct llama_beam * beam, void * userdata);
// Logically "OR" all conditions
struct llama_beam_stop_condition_one_of {
int num_conditions;
struct llama_beam_stop_condition * conditions;
};
// When any of the conditions are met
// userdata is a struct llama_stop_condition_one_of*
LLAMA_API bool llama_beam_stop_at_one_of(const struct llama_beam * beam, void * userdata);
// Free the returned beam
LLAMA_API void llama_beam_free(struct llama_beam * beam);

On cache reusing

Say n_beams=3. At first we have 3 beams:
Then we expand each by 3:
Prune this list down to top 3:
Since both AA and AB are on 0, we need to split:
I think ggml_cpy can do this even for a GPU-backed cache, as long as both tensors are on the GPU.
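A hypothetical bookkeeping sketch of that slot reuse (beam_candidate and assign_cache_slots are made-up names, not from this PR): after pruning, one child per parent keeps the parent's cache slot in place, and every other child of the same parent takes a free slot whose KV data has to be copied, e.g. via ggml_cpy as suggested above.

#include <vector>

struct beam_candidate {
    int parent_slot;        // cache slot of the beam this candidate was expanded from
    int assigned_slot = -1; // slot to use after pruning; -1 until assigned
};

// Assumes pruned.size() <= n_slots, i.e. the candidate list was already cut down to n beams.
static void assign_cache_slots(std::vector<beam_candidate> & pruned, int n_slots) {
    std::vector<bool> taken(n_slots, false);
    // First pass: one child per parent inherits the parent's slot untouched.
    for (beam_candidate & c : pruned) {
        if (!taken[c.parent_slot]) {
            c.assigned_slot = c.parent_slot;
            taken[c.parent_slot] = true;
        }
    }
    // Second pass: remaining children take a free slot and need a KV-cache copy.
    int next_free = 0;
    for (beam_candidate & c : pruned) {
        if (c.assigned_slot != -1) {
            continue;
        }
        while (taken[next_free]) {
            ++next_free;
        }
        c.assigned_slot = next_free;
        taken[next_free] = true;
        // copy_kv_slot(c.parent_slot, c.assigned_slot); // hypothetical; e.g. ggml_cpy on the backing tensors
    }
}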
Actually it does so happen that every once in a while, all beams will partially converge into sharing a common token prefix, of say ...

On API design

Based on your API suggestions @bullno1, I've added a callback type:

This is used, for example:

The idea is that ... This seems to cover the use cases you mentioned. These struct definitions can be extended to cover future beam-evolution control features.

On efficiency + cache reuse

I did an experiment in which each beam keeps its own ... However this appeared to produce gibberish (due to either a bug in the code, or the ...).

Questions

Thanks again for the helpful feedback.

Edits made on Aug 11: Replace ...
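To make the stop-condition use case concrete, a hedged sketch of a user-side callback; the field names n_beams, n_tokens and eob are inferred from this discussion and may not match the final header exactly.

#include "llama.h"

// Caller-owned state: stop a beam once it emits this token.
struct stop_token_state {
    llama_token stop_token; // e.g. the model's EOS token, or an anti-prompt sentinel
};

// Matches the callback type: void (*)(void * callback_data, llama_beams_state).
static void mark_end_of_beam(void * callback_data, llama_beams_state beams_state) {
    const stop_token_state & state = *static_cast<const stop_token_state *>(callback_data);
    for (size_t i = 0; i < beams_state.n_beams; ++i) {
        llama_beam_view & view = beams_state.beam_views[i];
        // A beam whose last token is the stop token should not be extended any further.
        if (!view.eob && view.n_tokens > 0 && view.tokens[view.n_tokens - 1] == state.stop_token) {
            view.eob = true;
        }
    }
}

A pointer to a stop_token_state would then be passed as the callback_data argument of llama_beam_search().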
Ready-for-review Notes

This can be tested using the included ...

Additionally, there is a llama-cpp-python beam_search branch which is/was working prior to recent breaking changes/improvements related to the new GGUF format; I intend to open a PR for it once this PR gets merged.

The helpful feedback above involving stop suggestions is satisfied by a more general ...

The concerns about caching are addressed by the fact that beams often converge to a common token prefix subvector. When this happens, the callback is notified, and at that time the common prefix should be stored, as the following iteration will shift it away from all beams, thereby reducing the token vector lengths. Having said this, there are likely additional optimizations w.r.t. beam token caching that can be accomplished in follow-up PRs.

When run with CUBLAS, the qualitative experience is quite satisfactory.
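A sketch of that usage pattern; the field name common_prefix_length and the helper names here are assumptions based on the description above, not necessarily the final API.

#include <vector>
#include "llama.h"

// Caller-owned output: tokens that have become final (shared by all beams).
struct beam_search_output {
    std::vector<llama_token> tokens;
};

// Whenever the beams share a common prefix, append it to the output immediately:
// the next iteration shifts those tokens out of every beam, and the beam_view
// pointers are only valid during this synchronous callback.
static void collect_common_prefix(void * callback_data, llama_beams_state beams_state) {
    beam_search_output & out = *static_cast<beam_search_output *>(callback_data);
    if (beams_state.common_prefix_length > 0) {
        const llama_beam_view & view = beams_state.beam_views[0]; // any beam supplies the shared prefix
        out.tokens.insert(out.tokens.end(),
                          view.tokens,
                          view.tokens + beams_state.common_prefix_length);
    }
    // A complete implementation would also append the winning beam's remaining
    // tokens on the final callback invocation.
}

This would be invoked as llama_beam_search(ctx, collect_common_prefix, &out, n_beams, n_past, n_predict, n_threads), per the declaration reviewed below.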
Interesting work - I'm still trying to understand the details.
Some minor style changes requested and we can look into merging.
Also, make sure that the implementation in llama.cpp is ordered correctly. I see it is after the "grammar" stuff, but in the header the declarations are after the "sampling" stuff. So either move the beam search definitions after the sampling or update the header to match the order.
We'll also have to figure out some simple tests for the CI in order to keep this functionality working in the long run. We can do that in a separate PR.
llama.h
Outdated
@@ -476,6 +476,39 @@ extern "C" {
    /// @details Accepts the sampled token into the grammar
    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);

    struct llama_beam_view {
        llama_token const* tokens;
Suggested change:
-        llama_token const* tokens;
+        llama_token const * tokens;
llama.h
Outdated
    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
    // These pointers are valid only during the synchronous callback, so should not be saved.
    struct llama_beams_state {
        llama_beam_view* beam_views;
Suggested change:
-        llama_beam_view* beam_views;
+        llama_beam_view * beam_views;
llama.h
Outdated
    /// @details Deterministically returns entire sentence constructed by a beam search.
    /// @param ctx Pointer to the llama_context.
    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
    ///        The return beam_search_control can be used to control the beam_search execution.
beam_search_control should probably be changed to eos flag?
llama.h
Outdated
    /// @param n_past Number of tokens already evaluated.
    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
    /// @param n_threads Number of threads as passed to llama_eval().
    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
Suggested change:
-    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
llama.cpp
Outdated
};

// A struct for calculating logit-related info.
struct logit_info {
Suggested change:
-struct logit_info {
+struct llama_logit_info {
llama.cpp
Outdated
    std::vector<llama_token_data> top_k(size_t k) {
        std::vector<llama_token_data> min_heap;  // min-heap by logit
        llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
        min_heap.reserve(k_min);
        for (llama_token token_id=0 ; token_id<k_min ; ++token_id) {
            min_heap.push_back(get_token_data(token_id));
        }
        auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
        std::make_heap(min_heap.begin(), min_heap.end(), comp);
        for (llama_token token_id=k_min ; token_id<n_vocab ; ++token_id) {
            if (min_heap.front().logit < logits[token_id]) {
                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
                min_heap.back().id = token_id;
                min_heap.back().logit = logits[token_id];
                std::push_heap(min_heap.begin(), min_heap.end(), comp);
            }
        }
Style changes examples - please apply to rest of non-server code:
Suggested change (the same code with the requested style):

    std::vector<llama_token_data> top_k(size_t k) {
        std::vector<llama_token_data> min_heap;  // min-heap by logit
        const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab);
        min_heap.reserve(k_min);
        for (llama_token token_id = 0; token_id < k_min; ++token_id) {
            min_heap.push_back(get_token_data(token_id));
        }
        auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
        std::make_heap(min_heap.begin(), min_heap.end(), comp);
        for (llama_token token_id = k_min; token_id < n_vocab; ++token_id) {
            if (min_heap.front().logit < logits[token_id]) {
                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
                min_heap.back().id = token_id;
                min_heap.back().logit = logits[token_id];
                std::push_heap(min_heap.begin(), min_heap.end(), comp);
            }
        }
llama.cpp
Outdated
    }
};

struct beam_search {
Suggested change:
-struct beam_search {
+struct llama_beam_search {
@@ -3354,6 +3354,253 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}

struct llama_beam {
Suggested change:
-struct llama_beam {
+//
+// beam search
+//
+struct llama_beam {
llama.h
Outdated
    // Type of pointer to the beam_search_callback function.
    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
    typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
Suggested change:
-    typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
+    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
@@ -476,6 +476,39 @@ extern "C" {
    /// @details Accepts the sampled token into the grammar
    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);

    struct llama_beam_view {
Suggested change:
-    struct llama_beam_view {
+    //
+    // Beam search
+    //
+    struct llama_beam_view {
Do I understand correctly that I can also use the ... Wondering if this is correct; maybe we should rename the flag to something more descriptive. Like ...
Changes

...

Responses

Thanks for the feedback. I wasn't able to rename ...

The beam search code is immediately after the 3 functions:

in both ... Thus the 3 sections are ordered as follows, in both the ...

Yes. The ...

Yes, good point. End-of-beam (eob) ... I've attempted to address every issue. Please feel free to point out anything I have missed, as it was not intended.

p.s. Food for thought: ...
* master: (773 commits)
  server : add `/detokenize` endpoint (ggerganov#2802)
  convert.py : advanced option (ggerganov#2753)
  llama : use Unicode Escape Sequence to replace encoded characters (ggerganov#2814)
  flake.nix : add rocm support and cleanup (ggerganov#2808)
  llama : move #includes out of _GNU_SOURCE conditional (ggerganov#2817)
  main : fix bug (penalize_nl=false doesn't work) + suppress warning on mingw (ggerganov#1528)
  llama : use std::abs in llama_sample_tail_free (ggerganov#2800)
  k-quants : remove unnecessary tensor shape restrictions (ggerganov#2811)
  Better perplexity for 2- and 3-bit quantization for LLaMA-v2-70B (ggerganov#2807)
  Fix HellaSwag (ggerganov#2805)
  flake : build llama.cpp on Intel with nix (ggerganov#2795)
  Handle null rope scaling value (ggerganov#2793)
  Fix spm whitespaces (ggerganov#2806)
  examples : skip unnecessary external lib in server README.md how-to (ggerganov#2804)
  llama : fix struct decl (ggerganov#2790)
  Faster perplexity computation (ggerganov#2786)
  llama : add llama_beam_search() (ggerganov#2267)
  convert.py : Get rope scale from HuggingFace models (ggerganov#2772)
  llama-bench : add model sizes (ggerganov#2771)
  convert.py : export rope freq_base when converting CodeLlama from an HF model (ggerganov#2773)
  ...
Ah yeah sorry, I put ...
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
    llama.beginCompletion();

    if (!llama.stream) {
        size_t stop_pos = std::string::npos;
        if (llama.params.n_beams) {
@mattpulver I noticed this check for llama.params.n_beams, but the n_beams param doesn't seem to be set anywhere. Am I misinterpreting? If I set it myself, will it work along with the grammar for this server example?
Here is an example of where it is set and used:
params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

In examples/server/server.cpp, I believe it should be set via the command line by server_params_parse(), but it seems that was not yet done. Feel free to submit that as a PR.
I don't think beam search and grammar will currently work together. That is currently an open item: #2923
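For context, a hedged sketch of how a caller might wire n_beams through to the call, reusing the hypothetical beam_search_output/collect_common_prefix helpers sketched earlier; n_past, n_predict and n_threads come from the surrounding generation state:

#include <vector>
#include "llama.h"

// Hypothetical wrapper: run a deterministic beam-search completion and return its tokens.
static std::vector<llama_token> generate_with_beams(llama_context * ctx, size_t n_beams,
                                                    int n_past, int n_predict, int n_threads) {
    beam_search_output out; // from the earlier callback sketch
    llama_beam_search(ctx, collect_common_prefix, &out, n_beams, n_past, n_predict, n_threads);
    return out.tokens;
}

In the server example this would sit behind the if (llama.params.n_beams) check quoted above, with n_beams eventually coming from server_params_parse().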
Related issue: #1392
This is an initial attempt at beam search. It does appear to work as intended, insofar as generating higher quality deterministic responses.
Currently the execution time seems to slow down noticeably as the beams grow their token vectors. I'm going to see if ingesting the common trunk into the shared llama_context improves this, and if so I will move this PR out of draft mode.

Thoughts/feedback are welcome.