
Add self‑speculative decoding (no draft model required)#18471

Merged
ggerganov merged 34 commits into ggml-org:master from srogmann:feature/self-speculative on Jan 28, 2026

Conversation

@srogmann
Collaborator

This PR introduces self-speculative decoding: instead of using a dedicated draft model (which works well when available, see #18039), the current token history is used to predict future tokens. This can provide a speedup in cases where the output contains repeated parts of the prompt. A typical example is making many small changes in a large source file.

Example 1 (gpt-oss-120b in VRAM): Translation of a few comments in a Python script (chosen as a favorable case).

slot update_slots: id  3 | task 0 | created context checkpoint 1 of 8 (pos_min = 324, pos_max = 2883, size = 90.030 MiB)
slot print_timing: id  3 | task 0 | 
prompt eval time =     436.48 ms /  2948 tokens (    0.15 ms per token,  6754.03 tokens per second)
       eval time =   18886.86 ms /  3423 tokens (    5.52 ms per token,   181.24 tokens per second)
      total time =   19323.34 ms /  6371 tokens
slot      release: id  3 | task 0 | stop processing: n_tokens = 6370, truncated = 0

Same prompt with --draft-min 12 --draft-max 48 --spec-self 1:

slot update_slots: id  3 | task 0 | created context checkpoint 1 of 8 (pos_min = 324, pos_max = 2883, size = 90.030 MiB)
slot print_timing: id  3 | task 0 | 
prompt eval time =     431.85 ms /  2948 tokens (    0.15 ms per token,  6826.38 tokens per second)
       eval time =    7163.27 ms /  3193 tokens (    2.24 ms per token,   445.75 tokens per second)
      total time =    7595.13 ms /  6141 tokens
draft acceptance rate = 0.76827 ( 2397 accepted /  3120 generated)
slot      release: id  3 | task 0 | stop processing: n_tokens = 6140, truncated = 0

To keep the PR simple, the new argument --spec-self reuses the same draft-min and draft-max values as used for a potential draft model. When combining both speculative decoding methods, these values are shared (no independent tuning of min/max for each method).

Example 2 (Qwen3-235B, with heavy offloading):

slot update_slots: id  3 | task 0 | prompt done, n_tokens = 2962, batch.n_tokens = 914
slot print_timing: id  3 | task 0 |
prompt eval time =   15606.37 ms /  2962 tokens (    5.27 ms per token,   189.79 tokens per second)
       eval time =  252551.71 ms /  2973 tokens (   84.95 ms per token,    11.77 tokens per second)
      total time =  268158.08 ms /  5935 tokens
srv  log_server_r: request: POST /v1/chat/completions 192.168.32.208 200

Same prompt with --draft-min 15 --draft-max 40 --spec-self 1:

slot update_slots: id  3 | task 0 | prompt done, n_tokens = 2962, batch.n_tokens = 914
slot print_timing: id  3 | task 0 | 
prompt eval time =   15474.80 ms /  2962 tokens (    5.22 ms per token,   191.41 tokens per second)
       eval time =  141116.29 ms /  2963 tokens (   47.63 ms per token,    21.00 tokens per second)
      total time =  156591.09 ms /  5925 tokens
draft acceptance rate = 0.86304 ( 2382 accepted /  2760 generated)

This speedup factor (from ~12 to ~21 tokens/s) occurs only in favorable cases with large repeated sections!

The algorithm is simple: search for a pattern of length draft-min in the token history and use the subsequent draft-max tokens for speculation. No further optimizations are implemented. I had the idea for this PR while waiting for the generation of a source file to finish at 5 t/s ;-)
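For illustration, here is a minimal sketch of that lookup. It is not the code from this PR: the function and type names are made up, and a plain std::vector stands in for the token history.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using token_vec = std::vector<int32_t>; // stand-in for llama.cpp's token vector type

    // Find the most recent earlier occurrence of the last n_min tokens of the
    // history and return the up to n_max tokens that followed that occurrence.
    static token_vec self_draft_sketch(const token_vec & history, size_t n_min, size_t n_max) {
        token_vec draft;
        if (history.size() <= n_min) {
            return draft;
        }
        const size_t tail = history.size() - n_min; // start of the pattern (the tail of the history)
        for (size_t i = tail; i-- > 0; ) {          // search backwards, skipping the tail itself
            if (std::equal(history.begin() + i, history.begin() + i + n_min,
                           history.begin() + tail)) {
                const size_t start = i + n_min;     // tokens right after the match
                const size_t end   = std::min(start + n_max, history.size());
                draft.assign(history.begin() + start, history.begin() + end);
                break;
            }
        }
        return draft; // empty if no earlier occurrence was found
    }

In the examples above, --draft-min 12 --draft-max 48 would correspond to n_min = 12 and n_max = 48.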

@congson1293

@srogmann Is there a paper for this PR?

@malaiwah

Is this similar to ngram? #4235

@ggerganov
Member

Thanks for contributing - we should definitely revive the lookup decoding functionality and integrate it into llama-server now that we have relatively good small-batch decoding performance across most backends.

The patch here is simple, which is nice, but it's better to build on top of the existing ngram-cache functionality (#5479).

Make sure to check @JohannesGaessler's previous work (#6828, #8648) on this topic. The server integration got stale, but I think now it should be much simpler to implement. The lookup example is a good starting point of how to use the ngram-cache functionality.

Ideally, you should wrap the ngram-cache states in the common_speculative object and then we can use the same common_speculative_gen_draft() call for both self and non-self drafting (i.e. common_speculative_gen_self_draft() should not be necessary).

@JohannesGaessler
Collaborator

I previously did some work along these lines and will happily share my experiences. If you would like to talk via VoIP, please send me an email (my address can be found on my GitHub profile). The code in common/ngram-cache.cpp was developed in the context of models with a vocabulary size of 32k, where correctly guessing the next token was comparatively much easier. There is no guarantee that the heuristics I came up with are optimal, and I would also be fine with the code being scrapped if someone else is willing to develop and maintain a competing version.

Generally speaking, when it comes to speculative techniques I think it would be useful to establish a testing methodology for token generation latency as a function of the number of draft tokens. We can expect the model evaluation to become slower as more draft tokens are added, and we can use the ratio vs. 0 draft tokens to calculate what minimum number of draft tokens we would need to accept on average in order to break even. Ideally we could then in turn use these numbers in our drafting code to control how confident a draft should be to even try it.
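To make the break-even point concrete, here is one rough way to express it (my framing, not existing llama.cpp code): if a decode step with k draft tokens takes t_k and a plain single-token step takes t_0, a speculative step yields 1 + a tokens (a = accepted drafts) in t_k versus 1 token in t_0, so speculation pays off once a exceeds t_k / t_0 - 1.

    // Hypothetical helper, not part of llama.cpp: minimum average number of accepted
    // draft tokens per step needed to break even, given two measured step latencies.
    static double breakeven_accepted_drafts(double t_step_with_drafts_ms, double t_step_plain_ms) {
        // (1 + a) / t_k >= 1 / t_0  =>  a >= t_k / t_0 - 1
        return t_step_with_drafts_ms / t_step_plain_ms - 1.0;
    }

For example, if adding the drafts makes a step 1.5x slower, at least 0.5 draft tokens per step must be accepted on average to break even.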

Very long-term I think the way to go will be to distill the original model into a very small neural network and to use that to generate drafts.

Is there a paper for this PR?

There are possibly papers related to this technique but I'm not sure they would be of much use. This is 95% an engineering problem rather than a science problem.

@srogmann
Collaborator Author

The current implementation of this PR looks for the last matching draft-min-gram in the context and uses the following draft-max tokens as the draft.

then we can use the same common_speculative_gen_draft() call for both self and non-self drafting

As a first step, I moved the common_speculative_gen_self_draft() call into common/speculative.cpp.

Member

@ggerganov ggerganov left a comment

Likely we have to maintain 2 instances of common_speculative for each slot:

    common_speculative * spec_dft  = nullptr;
    common_speculative * spec_self = nullptr;

The spec_dft will be used as we currently do when a draft model is specified.

The spec_self will always be present - need to update server_context_impl::load_model() to construct it.

LMK if these comments make sense.


    int get_n_draft_max() const {
-       if (!can_speculate()) {
+       if (!can_speculate() && !task->params.speculative.use_self) {
Member

Refactor this slightly:

  • Move the llama_context * ctx_dft from struct server_slot into common_speculative * spec, so that now spec would manage the ctx_dft lifetime
  • Avoid direct references to ctx_dft - instead reference spec where needed.

For example, can_speculate() becomes:

    bool can_speculate() const {
        return !!spec;
    }
  • This condition here remains simply if (!can_speculate()) {

Comment on lines 190 to 200
    if (params.self_mode == 1) {
        // Look in the current context for a n-gram and return the following tokens as the draft.
        llama_tokens draft_self = common_speculative_gen_self_draft(prompt_tgt_main_model, id_last,
                params.self_ngram_size, params.n_draft);
        if (!draft_self.empty()) {
            return draft_self;
        }
    }
    if (spec == nullptr) {
        return {};
    }
Member

Eventually, the common_speculative should become an abstraction over different types of implementations (a rough sketch follows the list below):

  • The current one using a draft model
  • The new one using the prompt
  • Another one using ngram-cache
  • etc.
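For illustration only, a rough sketch of what such an abstraction could look like (all names here are hypothetical, not the interface that was eventually merged):

    #include <cstdint>
    #include <memory>
    #include <vector>

    enum class spec_backend_sketch {
        DRAFT_MODEL, // dedicated draft model
        PROMPT,      // reuse the current prompt/token history (this PR)
        NGRAM_CACHE, // ngram-cache based lookup
    };

    struct spec_impl_sketch {
        virtual ~spec_impl_sketch() = default;
        // produce draft tokens from the target-model token history and the last sampled token
        virtual std::vector<int32_t> gen_draft(const std::vector<int32_t> & history, int32_t id_last) = 0;
    };

    struct common_speculative_sketch {
        spec_backend_sketch               type;
        std::unique_ptr<spec_impl_sketch> impl;

        // callers use a single entry point regardless of which backend is configured
        std::vector<int32_t> gen_draft(const std::vector<int32_t> & history, int32_t id_last) {
            return impl->gen_draft(history, id_last);
        }
    };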

@CISC CISC closed this Jan 1, 2026
@CISC CISC reopened this Jan 1, 2026
@github-actions github-actions bot added the documentation, model, script, testing and Nvidia GPU labels Jan 1, 2026
Member

@ggerganov ggerganov left a comment

Let's merge after the CI is green.

@srogmann Consider adding yourself to CODEOWNERS if you wish to help with maintenance in the future. Thanks.

@ggerganov
Member

@ngxson Do we need to whitelist the --spec-draftless argument somewhere so it works with the model router? Haven't tested this yet - it might already work, not sure.

@ngxson
Collaborator

ngxson commented Jan 26, 2026

All flags should be allowed by default in router mode; we only implement a blacklist, not a whitelist.

The whitelisting only applies to remote presets (since they are downloaded from the internet), but I think that's not relevant for now.

@ggerganov
Member

I've been running some tests with this PR for coding stuff and think it works well.

I like the functionality of being able to combine a fast draftless speculation based on the context + a draft-model speculation as a fallback. The former is very useful when the model has to reproduce an existing chunk of text (such as code during a refactor) and is essentially a free speedup, while the latter helps in the free-form thinking/reasoning/generation parts of the inference.
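As a hypothetical sketch of that combination (simplified names, not the merged server code), the context-based speculator could be consulted first and the draft model only when it finds nothing:

    #include <cstdint>
    #include <vector>

    using token_vec = std::vector<int32_t>;

    // Hypothetical glue: prefer the cheap context-based draft, fall back to the
    // draft model (if one is configured) when no matching pattern is found.
    template <typename SelfSpec, typename DraftSpec>
    token_vec pick_draft_sketch(SelfSpec * spec_self, DraftSpec * spec_dft,
                                const token_vec & prompt_tgt, int32_t id_last) {
        token_vec draft;
        if (spec_self) {
            draft = spec_self->gen_draft(prompt_tgt, id_last);
        }
        if (draft.empty() && spec_dft) {
            draft = spec_dft->gen_draft(prompt_tgt, id_last);
        }
        return draft;
    }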

I think the explicit "draftless" categorization of the speculative approaches is not really necessary and we can remove it in order to simplify. In bc33838 I've removed the term from the parameters in the code.

Propose to rename the --spec-draftless CLI arg to --spec-type?

Comment on lines 122 to 124
    if (map.idx_last_check + map.check_rate > cur_len && cur_len > map.idx_last_check) {
        return;
    }
Member

@srogmann This improves performance with reasoning models because the prompt for the new message does not include the reasoning from the previous message, so with the old logic we skipped a lot of speculations.

Collaborator Author

because the prompt for the new message does not include the reasoning from the previous message

Oh, so ngram-map-k and ngram-map-k4v should delete the last map entries if cur_len < map.idx_last_check.

Collaborator Author

@ggerganov
Implementations like ngram-map-k and ngram-map-k4v should be notified when the server truncates or modifies the token history, as they need to remove the keys of the n-grams of a skipped reasoning part.

The server could inform the speculative decoder when tokens have been removed (to be addressed in a follow-up PR?):

    // informs the speculative decoder if the token history has been reduced by the server
    void common_speculative_size(struct common_speculative * spec, size_t new_size);

Member

Ok, I'll revert this change for now and will follow up with a new PR that does some refactoring and introduces this extra interface for notifying the speculative context.

Comment on lines 777 to 796
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
            {
                // Use common_ngram_map_draft to generate a draft from the current context.
                auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
                if (state) {
                    common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
                } else {
                    GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
                }
            } break;
        case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
            {
                // Use common_ngram_map_draft to generate a draft from the current context.
                auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
                if (state) {
                    common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
                } else {
                    GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
                }
            } break;
Member

@ggerganov ggerganov Jan 27, 2026

It looks like currently there is no difference between COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K and COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, or am I missing something?

Edit: nvm - found the difference in key_only:

    static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
        uint16_t size_key = config.params.ngram_size_n;
        uint16_t size_value = config.params.ngram_size_m;
        bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
        uint16_t check_rate = config.params.ngram_check_rate;
        uint16_t min_hits = config.params.ngram_min_hits;
        return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
    }

Collaborator Author

Edit: nvm - found the difference in key_only:

Yes, that is the difference. I should have added a comment to the COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V case (explaining why it also uses common_speculative_state_ngram_map_k); the two layouts are sketched below.

  • In ngram-map-k only one value is stored per key. This is very similar to ngram-simple, but in ngram-map-k the number of accepted tokens is stored per key.
  • In ngram-map-k4v up to four values are stored per key so the implementation can check whether one value is preferred.
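For illustration, the two layouts could be pictured roughly like this (hypothetical field names, not the structs from the PR):

    #include <array>
    #include <cstdint>
    #include <vector>

    // ngram-map-k: one stored continuation per key, plus an acceptance counter.
    struct ngram_entry_k_sketch {
        std::vector<int32_t> value;      // the single continuation stored for this key
        uint32_t             n_accepted; // how often its tokens have been accepted
    };

    // ngram-map-k4v: up to four continuations per key, so the most successful
    // one can be preferred when drafting.
    struct ngram_entry_k4v_sketch {
        std::array<std::vector<int32_t>, 4> values;
        std::array<uint32_t, 4>             n_accepted;
    };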

@srogmann
Collaborator Author

Propose to rename the --spec-draftless CLI arg to --spec-type?

--spec-type would be easier to understand and remember, and it would be open to other implementations.

@ggerganov
Member

@srogmann I added the begin() API and did some refactoring. I think this is a good point to merge. In the next PR you can try to address the issue from #18471 (comment), since the existing logic is suboptimal for reasoning models and when re-generating from an earlier point in the conversation.

@ggerganov ggerganov merged commit 72d3b18 into ggml-org:master Jan 28, 2026
77 of 78 checks passed
@ggerganov ggerganov mentioned this pull request Jan 28, 2026
1 task
@Panchovix

Just a heads up: with this PR I get the following issue when building llama.cpp with:

cmake -B lenux \
    -DGGML_CUDA=ON \
    -DGGML_CUDA_FA_ALL_QUANTS=ON \
    -DGGML_BLAS=OFF \
    -DGGML_RPC=ON \
    -DCMAKE_CUDA_ARCHITECTURES="86;89;120" \
    -DGGML_MAX_CONTEXTS=2048 \
    -DGGML_SCHED_MAX_COPIES=1

Fedora 42, CUDA 13.1

[ 99%] Generating index.html.gz.hpp
[ 99%] Built target llama-imatrix
[ 99%] Built target llama-batched-bench
[ 99%] Built target llama-completion
[ 99%] Built target llama-perplexity
[ 99%] Built target llama-tts
/usr/bin/ld: ../server/libserver-context.a(server-task.cpp.o): in function `server_task::params_from_json_cmpl(llama_vocab const*, common_params const&, int, nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> >, void> const&)':
server-task.cpp:(.text+0xb634): undefined reference to `_ZNSt7__cxx1112basic_stringIcSt11char_traitsicESaIcEE10_M_disposeEv'
[ 99%] Built target llama-cvector-generator
[ 99%] Built target llama-mtmd-cli
[ 99%] Built target llama-export-lora
[ 99%] Built target llama-fit-params
collect2: error: ld returned 1 exit status
gmake[2]: *** [tools/cli/CMakeFiles/llama-cli.dir/build.make:114: bin/llama-cli] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:5468: tools/cli/CMakeFiles/llama-cli.dir/all] Error 2
gmake[1]: *** Waiting for unfinished jobs....
[ 99%] Building CXX object tools/server/CMakeFiles/llama-server.dir/server.cpp.o
[ 99%] Building CXX object tools/server/CMakeFiles/llama-server.dir/server-models.cpp.o
[ 99%] Building CXX object tools/server/CMakeFiles/llama-server.dir/server-http.cpp.o
[100%] Linking CXX executable ../../bin/llama-server
/usr/bin/ld: libserver-context.a(server-task.cpp.o): in function `server_task::params_from_json_cmpl(llama_vocab const*, common_params const&, int, nlohmann::json_abi_v3_12_0::basic_json<nlohmann::json_abi_v3_12_0::ordered_map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::json_abi_v3_12_0::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> >, void> const&)':
server-task.cpp:(.text+0xb634): undefined reference to `_ZNSt7__cxx1112basic_stringIcSt11char_traitsicESaIcEE10_M_disposeEv'
collect2: error: ld returned 1 exit status
gmake[2]: *** [tools/server/CMakeFiles/llama-server.dir/build.make:154: bin/llama-server] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:5551: tools/server/CMakeFiles/llama-server.dir/all] Error 2
gmake: *** [Makefile:146: all] Error 2

Reverting to the previous commit (ebf5725) works fine.

On the other hand, on Windows 11 it builds fine with

cmake -B windows -DGGML_CUDA=ON -DGGML_RPC=ON -DGGML_CUDA_FA_ALL_QUANTS=ON -DGGML_BLAS=OFF -DCMAKE_CUDA_ARCHITECTURES="86;89;120" -DLLAMA_CURL=OFF -DGGML_MAX_CONTEXTS=2048

@MrHills-rs

This PR introduces self-speculative decoding: instead of using a dedicated draft model (which works well when available, see #18039), the current token history is used to predict future tokens. [...]

Would it be possible to port this to ik?

ikawrakow/ik_llama.cpp#1197

4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
…ctor (ggml-org#18471)

* server: introduce self-speculative decoding

* server: moved self-call into speculative.cpp

* can_speculate() includes self-speculation

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: can_speculate() tests self-spec

* server: replace can_speculate() with slot.can_speculate()

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common: use %zu format specifier for size_t in logging

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* server: can_speculate() requires a task instance

* common: ngram map, config self-speculative decoding

* common: add enum common_speculative_type

* common: add vector of speculative states

* common: add option --spec-draftless

* server: cleanup (remove slot.batch_spec, rename)

* common: moved self-spec impl to ngram-map

* common: cleanup (use common_speculative_state_draft)

* spec : refactor

* cont : naming

* spec: remove --spec-config

* doc: (draftless) speculative decoding

* common: print performance in spec decoding

* minor : cleanup

* common : better names

* minor : cleanup + fix build

* minor: comments

* CODEOWNERS: add common/ngram-map.* (ggml-org#18471)

* common : rename speculative.draftless_type -> speculative.type

* ngram-map : fix uninitialized values

* ngram-map : take into account the input can become shorter

* ngram-map : revert len check for now

* arg : change `--spec-draftless` -> `--spec-type`

* spec : add common_speculative_state::accept()

* spec : refactor + add common_speculative_begin()

* spec : fix begin() call with mtmd

* spec : additional refactor + remove common_speculative_params

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026
…ctor (ggml-org#18471)
