Add self‑speculative decoding (no draft model required) #18471

ggerganov merged 34 commits into ggml-org:master
Conversation
@srogmann Is there a paper for this PR?

Is this similar to ngram? #4235

Thanks for contributing - we should definitely revive the lookup decoding functionality and integrate it into the server. The patch here is simple which is nice, but it's better to build on top of the existing `common/speculative` implementation. Make sure to check @JohannesGaessler's previous work (#6828, #8648) on this topic. The server integration got stale, but I think now it should be much simpler to implement. The lookup example is a good starting point of how to use the n-gram cache. Ideally, you should wrap the new functionality in `common_speculative`.
I previously did some work along these lines and will happily share my experiences. If you would like to talk via VoIP, please send me an email (can be found on my Github profile).

Generally speaking, when it comes to speculative techniques I think it would be useful to establish a testing methodology for token generation latency as a function of the number of draft tokens. We can expect the model evaluation to become slower as more draft tokens are added, and we can use the ratio vs. 0 draft tokens to calculate the minimum number of draft tokens we would need to accept on average in order to break even. Ideally we could then in turn use these numbers in our drafting code to control how confident a draft should be to even try it. Very long-term I think the way to go will be to distill the original model into a very small neural network and to use that to generate drafts.

There are possibly papers related to this technique but I'm not sure they would be of much use. This is 95% an engineering problem rather than a science problem.
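For illustration, the break-even reasoning above as a tiny computation (the slowdown number is made up, not a measurement from this PR): a speculative step produces `1 + n_accepted` tokens but costs as much as `slowdown_ratio` plain single-token steps, so it only pays off when `n_accepted >= slowdown_ratio - 1`.

```cpp
#include <cstdio>

// Break-even estimate for speculative decoding (illustrative numbers only).
// slowdown_ratio: time of one decode step that also verifies the draft tokens,
//                 divided by the time of a plain single-token decode step.
// Returns the average number of draft tokens that must be accepted per step
// so that speculation does not lose time.
static double break_even_accepted(double slowdown_ratio) {
    // one speculative step yields 1 + n_accepted tokens and costs
    // slowdown_ratio plain steps, so we need 1 + n_accepted >= slowdown_ratio
    return slowdown_ratio - 1.0;
}

int main() {
    const double slowdown_ratio = 1.4; // hypothetical: drafting makes a step 1.4x slower
    printf("need to accept >= %.2f draft tokens on average to break even\n",
           break_even_accepted(slowdown_ratio));
    return 0;
}
```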
The current implementation of this PR looks for the last matching n-gram in the token history.

As a first step I moved the self-speculation call into speculative.cpp.
ggerganov left a comment
Likely we have to maintain 2 instances of `common_speculative` for each slot:

```cpp
common_speculative * spec_dft  = nullptr;
common_speculative * spec_self = nullptr;
```

The `spec_dft` will be used as we currently do when a draft model is specified.
The `spec_self` will always be present - need to update `server_context_impl::load_model()` to construct it.

LMK if these comments make sense.
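A minimal sketch of how the two instances could be combined per slot (the type and function names here are placeholders, not the actual server code; the ordering follows the combination described later in the thread - cheap prompt-based speculation first, draft model as fallback):

```cpp
#include <cstdint>
#include <vector>

using tokens = std::vector<int32_t>;

struct speculator; // stand-in for common_speculative

// stand-in for the draft-generation call; the real API differs
tokens gen_draft(speculator * spec, const tokens & prompt, int32_t id_last);

struct slot_sketch {
    speculator * spec_dft  = nullptr; // only set when a draft model is specified
    speculator * spec_self = nullptr; // always constructed when the server loads the model

    tokens draft(const tokens & prompt, int32_t id_last) const {
        // try the cheap prompt-based speculation first ...
        tokens d = gen_draft(spec_self, prompt, id_last);
        // ... and fall back to the draft model, if one is available
        if (d.empty() && spec_dft != nullptr) {
            d = gen_draft(spec_dft, prompt, id_last);
        }
        return d;
    }
};
```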
tools/server/server-context.cpp
Outdated
```diff
 int get_n_draft_max() const {
-    if (!can_speculate()) {
+    if (!can_speculate() && !task->params.speculative.use_self) {
```
Refactor this slightly:
- Move the `llama_context * ctx_dft` from `struct server_slot` into `common_speculative * spec`, so that now `spec` would manage the `ctx_dft` lifetime
- Avoid direct references to `ctx_dft` - instead reference `spec` where needed.

  For example, `can_speculate()` becomes:

  ```cpp
  bool can_speculate() const {
      return !!spec;
  }
  ```

- This condition here remains simply `if (!can_speculate()) {`
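A rough sketch of the suggested ownership change (the struct name and destructor placement are assumptions for illustration; `llama_free()` is the existing llama.cpp call for releasing a context):

```cpp
#include "llama.h"

// Sketch only: common_speculative owning the draft context, so that
// server_slot no longer references ctx_dft directly.
struct common_speculative_sketch {
    llama_context * ctx_dft = nullptr; // moved here from struct server_slot

    common_speculative_sketch() = default;

    // non-copyable: this object owns ctx_dft
    common_speculative_sketch(const common_speculative_sketch &) = delete;
    common_speculative_sketch & operator=(const common_speculative_sketch &) = delete;

    ~common_speculative_sketch() {
        if (ctx_dft != nullptr) {
            llama_free(ctx_dft); // spec manages the ctx_dft lifetime
        }
    }
};
```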
common/speculative.cpp
Outdated
```cpp
if (params.self_mode == 1) {
    // Look in the current context for a n-gram and return the following tokens as the draft.
    llama_tokens draft_self = common_speculative_gen_self_draft(prompt_tgt_main_model, id_last,
            params.self_ngram_size, params.n_draft);
    if (!draft_self.empty()) {
        return draft_self;
    }
}
if (spec == nullptr) {
    return {};
}
```
Eventually, the `common_speculative` should become an abstraction over different types of implementation:

- The current one using a draft model
- The new one using the prompt
- Another one using `ngram-cache`
- etc.
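For illustration, one way such an abstraction could look (a sketch with made-up names, not the interface that was eventually merged; the merged code uses an enum `common_speculative_type` plus per-type state objects instead of virtual dispatch):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

using tokens = std::vector<int32_t>;

// Sketch of an abstraction over speculative-decoding strategies.
struct speculative_impl {
    virtual ~speculative_impl() = default;
    // propose draft tokens given the target-model prompt and the last sampled token
    virtual tokens gen_draft(const tokens & prompt_tgt, int32_t id_last) = 0;
};

struct speculative_draft_model : speculative_impl {
    tokens gen_draft(const tokens &, int32_t) override { return {}; } // would query the draft model
};

struct speculative_prompt_lookup : speculative_impl {
    tokens gen_draft(const tokens &, int32_t) override { return {}; } // would search the token history
};

// callers only see the common interface, e.g.:
// std::unique_ptr<speculative_impl> spec = std::make_unique<speculative_prompt_lookup>();
```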
@ngxson Do we need to whitelist the new argument for the router mode?
All flags are allowed by default in router mode; we only implement a blacklist, not a whitelist. Whitelisting only applies to remote presets (since they are downloaded from the internet). But I think it's not relevant for now.
I've been running some tests with this PR for coding stuff and think it works well. I like the functionality of being able to combine a fast draftless speculation based on the context + a draft-model speculation as a fallback. The former is very useful when the model has to reproduce an existing chunk of text (such as code during a refactor) and is essentially a free speedup, while the latter helps in the free-form thinking/reasoning/generation parts of the inference.

I think the explicit "draftless" categorization of the speculative approaches is not really necessary and we can remove it in order to simplify. In bc33838 I've removed the term from the parameters in the code. I propose to rename `--spec-draftless` to `--spec-type`.
common/ngram-map.cpp
Outdated
```cpp
if (map.idx_last_check + map.check_rate > cur_len && cur_len > map.idx_last_check) {
    return;
}
```
@srogmann This improves performance with reasoning models because the prompt for the new message does not include the reasoning from the previous message, so with the old logic we skipped a lot of speculations.
> because the prompt for the new message does not include the reasoning from the previous message

Oh, so `ngram-map-k` and `ngram-map-k4v` should delete the last map entries if `cur_len < map.idx_last_check`.
@ggerganov Implementations like `ngram-map-k` and `ngram-map-k4v` should be notified when the server truncates or modifies the token history, as they need to remove the keys of the n-grams of a skipped reasoning part.
The server could inform the speculative decoder when tokens have been removed (to be addressed in a follow-up PR?):

```cpp
// informs the speculative decoder if the token history has been reduced by the server
void common_speculative_size(struct common_speculative * spec, size_t new_size);
```
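For illustration, what an n-gram map might do when it receives such a notification (the map layout below is a guess for the sketch, not the actual `ngram-map` data structure):

```cpp
#include <cstddef>
#include <cstdint>
#include <map>

// Sketch only: a hypothetical n-gram map that remembers, for each key,
// the position in the token history where its continuation starts.
struct ngram_map_sketch {
    std::map<uint64_t, size_t> key_to_pos; // n-gram hash -> position of the draft continuation
    size_t idx_last_check = 0;             // how far the history was scanned so far

    // called when the server reports that the token history shrank to new_size
    void on_truncate(size_t new_size) {
        for (auto it = key_to_pos.begin(); it != key_to_pos.end(); ) {
            if (it->second >= new_size) {
                it = key_to_pos.erase(it); // the continuation no longer exists
            } else {
                ++it;
            }
        }
        if (idx_last_check > new_size) {
            idx_last_check = new_size; // re-scan from the new end next time
        }
    }
};
```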
Ok, I'll revert this change for now and will follow up with a new PR that does some refactoring and introduces this extra interface for notifying the speculative context.
common/speculative.cpp
Outdated
```cpp
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
    {
        // Use common_ngram_map_draft to generate a draft from the current context.
        auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
        if (state) {
            common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
        } else {
            GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
        }
    } break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
    {
        // Use common_ngram_map_draft to generate a draft from the current context.
        auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
        if (state) {
            common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
        } else {
            GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
        }
    } break;
```
It looks like currently there is no difference between COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K and COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, or am I missing something?
Edit: nvm - found the difference in key_only:
llama.cpp/common/speculative.cpp, lines 283 to 291 in 003c903
> Edit: nvm - found the difference in key_only:

Yes, that is the difference. I should have added a comment for the `COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V` case (explaining why `common_speculative_ngram_map_k` is used there).
- In `ngram-map-k` only one value is stored per key. This is very similar to `ngram-simple`, but in `ngram-map-k` the number of accepted tokens is stored per key.
- In `ngram-map-k4v` up to four values are stored per key so the implementation can check whether one value is preferred.
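To illustrate the difference, a rough sketch of what the two entry layouts could look like (field names and the fixed value count are assumptions based on the description above, not the actual `ngram-map` structs):

```cpp
#include <array>
#include <cstdint>

// Sketch of an ngram-map-k entry: one candidate continuation per key,
// plus a counter of how many of its tokens were accepted so far.
struct ngram_map_k_entry {
    uint32_t value_pos  = 0; // position of the single stored continuation
    uint16_t n_accepted = 0; // accepted tokens for this key (used to judge the key)
};

// Sketch of an ngram-map-k4v entry: up to four candidate continuations per key,
// so the implementation can prefer the one that has worked best.
struct ngram_map_k4v_entry {
    std::array<uint32_t, 4> value_pos  = {}; // positions of up to four continuations
    std::array<uint16_t, 4> n_accepted = {}; // per-value acceptance counts
    uint8_t                 n_values   = 0;  // how many of the four slots are in use
};
```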
@srogmann I added the `common/ngram-map.*` entry to CODEOWNERS.
Just a heads up, with this PR I get a build failure when building llama.cpp with `cmake -B lenux` (Fedora 42, CUDA 13.1). Reverting to the previous commit (ebf5725) works fine. On the other hand, on Windows 11 it builds fine.
Would it be possible to port this to ik?
…ctor (ggml-org#18471)

* server: introduce self-speculative decoding
* server: moved self-call into speculative.cpp
* can_speculate() includes self-speculation

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: can_speculate() tests self-spec
* server: replace can_speculate() with slot.can_speculate()

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common: use %zu format specifier for size_t in logging

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* server: can_speculate() requires a task instance
* common: ngram map, config self-speculative decoding
* common: add enum common_speculative_type
* common: add vector of speculative states
* common: add option --spec-draftless
* server: cleanup (remove slot.batch_spec, rename)
* common: moved self-spec impl to ngram-map
* common: cleanup (use common_speculative_state_draft)
* spec : refactor
* cont : naming
* spec: remove --spec-config
* doc: (draftless) speculative decoding
* common: print performance in spec decoding
* minor : cleanup
* common : better names
* minor : cleanup + fix build
* minor: comments
* CODEOWNERS: add common/ngram-map.* (ggml-org#18471)
* common : rename speculative.draftless_type -> speculative.type
* ngram-map : fix uninitialized values
* ngram-map : take into account the input can become shorter
* ngram-map : revert len check for now
* arg : change `--spec-draftless` -> `--spec-type`
* spec : add common_speculative_state::accept()
* spec : refactor + add common_speculative_begin()
* spec : fix begin() call with mtmd
* spec : additional refactor + remove common_speculative_params

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This PR introduces self-speculative decoding: instead of using a dedicated draft model (which is good, if available, see #18039), the current token history is used to predict future tokens. This can provide a speedup in cases where the output contains repeated parts of the prompt. A typical example is making many small changes in a large source file.
Example 1 (`gpt-oss-120b` in VRAM): Translation of a few comments in a Python script (chosen as a favorable case).

Same prompt with `--draft-min 12 --draft-max 48 --spec-self 1`:

To keep the PR simple, the new argument `--spec-self` reuses the same `draft-min` and `draft-max` values as used for a potential draft model. When combining both speculative decoding methods, these values are shared (no independent tuning of min/max for each method).

Example 2 (`Qwen3-235B`, with heavy offloading):

Same prompt with `--draft-min 15 --draft-max 40 --spec-self 1`:

This speedup factor (from ~12 to ~21 tokens/s) occurs only in favorable cases with large repeated sections!

The algorithm is simple: search for a pattern of length `draft-min` in the token history and use the subsequent `draft-max` tokens for speculation. No further optimizations are implemented. I had the idea for this PR while waiting for a source file to finish at 5 t/s ;-)
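For illustration, a self-contained sketch of this kind of prompt-lookup drafting (simplified, not the code from this PR; the actual implementation lives in `common/ngram-map.*` and additionally tracks acceptance statistics and a check rate):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using tokens = std::vector<int32_t>;

// Search the token history for the most recent earlier occurrence of its last
// n_min tokens and, if found, return up to n_max tokens that followed that
// occurrence as the draft. Simplified sketch of prompt-lookup drafting.
static tokens gen_self_draft(const tokens & history, size_t n_min, size_t n_max) {
    if (history.size() < n_min + 1) {
        return {};
    }
    const size_t start_pat = history.size() - n_min; // the pattern = last n_min tokens
    // scan backwards, excluding the trailing occurrence itself
    for (size_t pos = start_pat; pos-- > 0; ) {
        bool match = true;
        for (size_t i = 0; i < n_min; ++i) {
            if (history[pos + i] != history[start_pat + i]) {
                match = false;
                break;
            }
        }
        if (match) {
            const size_t begin = pos + n_min;                // continuation after the match
            const size_t end   = std::min(begin + n_max, start_pat);
            return tokens(history.begin() + begin, history.begin() + end);
        }
    }
    return {};
}

int main() {
    // the history ends with the bigram (3, 4), which occurred earlier followed by 9 7 8
    const tokens history = {1, 2, 3, 4, 9, 7, 8, 1, 2, 3, 4};
    const tokens draft   = gen_self_draft(history, /*n_min=*/2, /*n_max=*/3);
    for (int32_t t : draft) {
        printf("%d ", t); // expected output: 9 7 8
    }
    printf("\n");
    return 0;
}
```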