spec : add ngram-mod #19164

Merged
ggerganov merged 7 commits into master from gg/spec-ngram-mod on Jan 30, 2026

Conversation

@ggerganov
Member

@ggerganov ggerganov commented Jan 28, 2026

cont #18471

Add basic ngram hasher for speculative decoding:

  • For each ngram, compute a hash using an LCG (linear congruential generator)
  • For each computed hash, store the next token
  • During speculation, iteratively compute the rolling hash of the last n tokens and pick the next token from the storage

Some characteristics:

  • Lightweight (~16 MB)
  • Constant memory and complexity
  • Can generate variable draft lengths (i.e. m is not fixed)

Currently, a single hash pool is shared across all server slots, so different requests can benefit from each other.

Sample usage:

# notes:
# - small `n` are not recommended
# - MoEs require long drafts
# - dense models: can reduce `--draft-min` and `--draft-max`

llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64

Example:

spec-mod-0.mov

TODO:

  • Reset criteria?

@HDembinski

Wow, that is wild.

if (!ngram_mod && params_base.speculative.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD) {
    ngram_mod = std::make_unique<common_ngram_mod>(params_base.speculative.ngram_size_n, 1024*1024);

    params_base.speculative.ngram_mod = ngram_mod.get();
}
Collaborator

I guess you have to do it this way because unique_ptr doesn't accept a forward-declared struct?

If that's the case, then using std::shared_ptr or std::optional might be a better hack

void add(const int32_t * tokens);
int32_t get(const int32_t * tokens, int32_t offs) const; // return -1 if not found

uint16_t n; // ngram size to hash
Collaborator

@ngxson ngxson Jan 28, 2026

In multiple places in the code we need to cast this to size_t, so I think it's probably better to use size_t

Suggested change
uint16_t n; // ngram size to hash
size_t n; // ngram size to hash

@ngxson
Collaborator

ngxson commented Jan 28, 2026

  • Reset criteria?

Could the EOG/EOS token be a good criterion? The ngrams can differ between the user message and the assistant message

@ggerganov
Member Author

ggerganov commented Jan 29, 2026

  • Reset criteria?

Could the EOG/EOS token be a good criterion? The ngrams can differ between the user message and the assistant message

EOG/EOS seems to occur way too often. The hash container can store a lot of ngram hashes (hundreds of thousands at the current size) before collisions start to occur.

I'm thinking more about logic such as: if more than some fraction X ~ 0.25-0.50 of all possible hashes has been seen so far, randomly erase some fraction of the entries from the container.
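This reset logic could look roughly like the following. A hypothetical sketch, not the merged code; the function name, threshold defaults, and flat slot representation are assumptions:

```cpp
#include <cstdint>
#include <random>
#include <vector>

// Sketch: once occupancy crosses a threshold fraction of the table,
// randomly evict a fraction of the occupied slots (-1 marks an empty slot).
void maybe_prune(std::vector<int32_t> & table, size_t & used,
                 double thold = 0.25, double f_erase = 0.50) {
    if ((double) used < thold * table.size()) {
        return; // occupancy still low enough, nothing to do
    }
    std::mt19937 rng(1234); // fixed seed for reproducibility in this sketch
    std::bernoulli_distribution drop(f_erase);
    for (auto & slot : table) {
        if (slot != -1 && drop(rng)) {
            slot = -1;
            --used;
        }
    }
}
```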

std::vector<common_ngram_mod_ext_entry> entries;
};

using common_ngram_mod_ext_ptr = std::unique_ptr<common_ngram_mod_ext>;
Collaborator

seems to be unused now

std::vector<entry_t> entries;
};

using common_ngram_mod_ptr = std::unique_ptr<common_ngram_mod>;
Collaborator

also unused

@ggerganov ggerganov merged commit dabaa2e into master Jan 30, 2026
52 of 75 checks passed
@ggerganov ggerganov deleted the gg/spec-ngram-mod branch January 30, 2026 16:21
@MikeLP

MikeLP commented Jan 30, 2026

Looks like llama-bench doesn't know about the --spec-type ngram-mod param.

@characharm
Contributor

It seems this PR has an additional positive side effect: in the case of GPT-OSS in high mode, when the model falls into a reasoning loop, it can now recover much faster. Token generation jumps to around 200 tokens/s, and the model even produces a meaningful result.

@ggerganov
Member Author

@MikeLP This does not affect llama-bench - it needs actual context to do anything.

@characharm Yes, I also noticed that. Overall, I think this speculator could eventually be enabled by default in llama-server.

@dagbdagb

Sample usage:

# notes:
# - small `n` are not recommended
# - MoEs require long drafts
# - dense models: can reduce `--draft-min` and `--draft-max`

llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64

Is this example for a MoE or a dense model? I have no intuitive feel for what constitutes a 'small n' or a 'long draft'.

I assume the optimal values depend on the model, its architecture, and the task at hand.

@bfroemel

On the same prompt (just repeating verbatim 200 lines of given source code) I sometimes see a draft acceptance rate of 0, while on most other runs it's 0.90+ on gpt-oss-120b with llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64. The generated model output (the repeated 200 lines of source code) is exactly the same in the good and bad cases. Is this somewhat expected, or should I submit a bug report?

Below are logs of a bad case followed by a good case. (I also observed a good case right after starting llama-server, so it's not the case that the first request is always "bad".)

Log
main: model loaded
main: server is listening on http://0.0.0.0:8060
main: starting the main loop...
srv  update_slots: all slots are idle
srv  params_from_: Chat format: GPT-OSS
slot get_availabl: id  3 | task -1 | selected slot by LRU, t_last = -1
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> ?top-k -> ?typical -> ?top-p -> min-p -> ?xtc -> ?temp-ext -> dist 
slot launch_slot_: id  3 | task 0 | processing task, is_child = 0
slot update_slots: id  3 | task 0 | new prompt, n_ctx_slot = 64000, n_keep = 0, task.n_tokens = 1543
slot update_slots: id  3 | task 0 | n_tokens = 0, memory_seq_rm [0, end)
slot update_slots: id  3 | task 0 | prompt processing progress, n_tokens = 1479, batch.n_tokens = 1479, progress = 0.958522
slot update_slots: id  3 | task 0 | n_tokens = 1479, memory_seq_rm [1479, end)
slot update_slots: id  3 | task 0 | prompt processing progress, n_tokens = 1543, batch.n_tokens = 64, progress = 1.000000
slot update_slots: id  3 | task 0 | prompt done, n_tokens = 1543, batch.n_tokens = 64
slot init_sampler: id  3 | task 0 | init sampler, took 0.19 ms, tokens: text = 1543, total = 1543
slot update_slots: id  3 | task 0 | created context checkpoint 1 of 8 (pos_min = 455, pos_max = 1478, size = 36.012 MiB)
begin: ngram_mod occupancy = 1519/4194304 (0.00)
slot print_timing: id  3 | task 0 | 
prompt eval time =     399.04 ms /  1543 tokens (    0.26 ms per token,  3866.76 tokens per second)
       eval time =    8750.09 ms /  1589 tokens (    5.51 ms per token,   181.60 tokens per second)
      total time =    9149.14 ms /  3132 tokens
draft acceptance rate = 0.00000 (    0 accepted /    64 generated)
statistics ngram_mod: #calls = 1588, #gen drafts = 1, #acc drafts = 0, #gen tokens = 64, #acc tokens = 0, dur = 0.731 ms
slot      release: id  3 | task 0 | stop processing: n_tokens = 3131, truncated = 0
srv  update_slots: all slots are idle
srv  log_server_r: done request: POST /v1/chat/completions xxx.xxx.xxx.xxx 200
srv  params_from_: Chat format: GPT-OSS
slot get_availabl: id  3 | task -1 | selected slot by LCP similarity, sim_best = 1.000 (> 0.100 thold), f_keep = 0.493
srv  get_availabl: updating prompt cache
srv   prompt_save:  - saving prompt with length 3131, total state size = 146.123 MiB
srv          load:  - looking for better prompt, base f_keep = 0.493, sim = 1.000
srv        update:  - cache state: 1 prompts, 182.135 MiB (limits: 8192.000 MiB, 64000 tokens, 140825 est)
srv        update:    - prompt 0x5567835dd300:    3131 tokens, checkpoints:  1,   182.135 MiB
srv  get_availabl: prompt cache update took 36.46 ms
slot launch_slot_: id  3 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> ?top-k -> ?typical -> ?top-p -> min-p -> ?xtc -> ?temp-ext -> dist 
slot launch_slot_: id  3 | task 1591 | processing task, is_child = 0
slot update_slots: id  3 | task 1591 | new prompt, n_ctx_slot = 64000, n_keep = 0, task.n_tokens = 1543
slot update_slots: id  3 | task 1591 | n_past = 1543, slot.prompt.tokens.size() = 3131, seq_id = 3, pos_min = 2107, n_swa = 128
slot update_slots: id  3 | task 1591 | restored context checkpoint (pos_min = 455, pos_max = 1478, size = 36.012 MiB)
slot update_slots: id  3 | task 1591 | n_tokens = 1478, memory_seq_rm [1478, end)
slot update_slots: id  3 | task 1591 | prompt processing progress, n_tokens = 1479, batch.n_tokens = 1, progress = 0.958522
slot update_slots: id  3 | task 1591 | n_tokens = 1479, memory_seq_rm [1479, end)
slot update_slots: id  3 | task 1591 | prompt processing progress, n_tokens = 1543, batch.n_tokens = 64, progress = 1.000000
slot update_slots: id  3 | task 1591 | prompt done, n_tokens = 1543, batch.n_tokens = 64
slot init_sampler: id  3 | task 1591 | init sampler, took 0.18 ms, tokens: text = 1543, total = 1543
begin: ngram_mod occupancy = 3101/4194304 (0.00)
slot print_timing: id  3 | task 1591 | 
prompt eval time =      53.72 ms /    65 tokens (    0.83 ms per token,  1210.09 tokens per second)
       eval time =    1864.30 ms /  1517 tokens (    1.23 ms per token,   813.71 tokens per second)
      total time =    1918.01 ms /  1582 tokens
draft acceptance rate = 0.92188 ( 1416 accepted /  1536 generated)
statistics ngram_mod: #calls = 1688, #gen drafts = 25, #acc drafts = 24, #gen tokens = 1600, #acc tokens = 1416, dur = 1.167 ms
slot      release: id  3 | task 1591 | stop processing: n_tokens = 3059, truncated = 0
srv  update_slots: all slots are idle
srv  log_server_r: done request: POST /v1/chat/completions xxx.xxx.xxx.xxx 200

@jacekpoplawski
Contributor

jacekpoplawski commented Jan 31, 2026

I’m experimenting with the n/min/max settings, but I don’t understand the balance yet. Does a large min–max range hurt us somehow?

Qwen 30B and settings from the post:

Details
       eval time =    4838.87 ms /   630 tokens (    7.68 ms per token,   130.20 tokens per second)
statistics ngram_mod: #calls = 629, #gen drafts = 0, #acc drafts = 0, #gen tokens = 0, #acc tokens = 0, dur = 0.474 ms
       eval time =    9139.89 ms /  1140 tokens (    8.02 ms per token,   124.73 tokens per second)
draft acceptance rate = 0.06250 (   12 accepted /   192 generated)
statistics ngram_mod: #calls = 1756, #gen drafts = 3, #acc drafts = 3, #gen tokens = 192, #acc tokens = 12, dur = 1.403 ms
       eval time =    9317.38 ms /  1150 tokens (    8.10 ms per token,   123.43 tokens per second)
draft acceptance rate = 0.25000 (   32 accepted /   128 generated)
statistics ngram_mod: #calls = 2873, #gen drafts = 5, #acc drafts = 5, #gen tokens = 320, #acc tokens = 44, dur = 2.443 ms
       eval time =    5384.53 ms /   640 tokens (    8.41 ms per token,   118.86 tokens per second)
draft acceptance rate = 0.04688 (    3 accepted /    64 generated)
statistics ngram_mod: #calls = 3509, #gen drafts = 6, #acc drafts = 6, #gen tokens = 384, #acc tokens = 47, dur = 2.975 ms
       eval time =    9901.95 ms /  1247 tokens (    7.94 ms per token,   125.93 tokens per second)
draft acceptance rate = 0.37187 (  119 accepted /   320 generated)
statistics ngram_mod: #calls = 4636, #gen drafts = 11, #acc drafts = 11, #gen tokens = 704, #acc tokens = 166, dur = 4.006 ms
       eval time =   12292.57 ms /  1407 tokens (    8.74 ms per token,   114.46 tokens per second)
draft acceptance rate = 0.04688 (    3 accepted /    64 generated)
statistics ngram_mod: #calls = 6039, #gen drafts = 12, #acc drafts = 12, #gen tokens = 768, #acc tokens = 169, dur = 5.146 ms
       eval time =    9116.82 ms /  1523 tokens (    5.99 ms per token,   167.05 tokens per second)
draft acceptance rate = 0.63379 (  649 accepted /  1024 generated)
statistics ngram_mod: #calls = 6912, #gen drafts = 28, #acc drafts = 28, #gen tokens = 1792, #acc tokens = 818, dur = 6.118 ms

I see accept: low acceptance streak (3) – resetting ngram_mod almost constantly across all models. After a quick look at the code, I’m not sure n_low should be logged at all, since it’s always 3.

@ggerganov
Member Author

Larger ngram sizes and larger drafts increase the chances that we will draft only when the LLM is repeating existing text. Basically, we are trying to detect long repeating blocks without doing exhaustive searches. So unless your use case involves such repeating blocks of text, this method won't help.

Yes, the 3 is currently hardcoded. Btw, you can set the env variable LLAMA_TRACE for some extra information about how many tokens are being accepted.

@jacekpoplawski
Contributor

Larger ngram sizes and larger drafts increase the chances that we will draft only when the LLM is repeating existing text. Basically, we are trying to detect long repeating blocks without doing exhaustive searches.

Do I understand correctly that

statistics ngram_mod: #calls = 6039, #gen drafts = 12, #acc drafts = 12, #gen tokens = 768, #acc tokens = 169, dur = 5.146 ms

means that the total "cost" of ngram_mod was only about 5 ms?

My point is: should I try to increase that time by changing --spec-ngram-size-n / --draft-min / --draft-max, since even 500 ms still wouldn’t be noticeable?

@ggerganov
Member Author

So far, I think --draft-min and --draft-max likely don't need to be changed from the recommended values.

Regarding --spec-ngram-size-n, I am thinking about trying larger values like 32 or 48. It might also be better to implement multi-level ngram speculation, in such a way that if short ngrams fail, we continue trying with larger ngrams. Still thinking about this.

@bfroemel

bfroemel commented Feb 1, 2026

ad #19164 (comment)

I think I just observed the effects of an early low acceptance streak (3): the triggered reset clears the still very useful ngrams from prompt processing. Naively, I would rather not clear the complete hash pool, but instead keep the ngrams from prompt processing of the current request, even when there are streaks.

@bfroemel

bfroemel commented Feb 1, 2026

Just an idea to get more consistent, sustained speedups and to avoid the downsides of early low-acceptance streaks in the current pruning mechanism: track a capped score for each ngram in the pool, initially set to 1 on insert. If an ngram was used successfully in a draft, count it up; if the draft was rejected, count it down. On streaks, remove all ngrams with a score of 0 or less. Not sure if it's important to keep occupancy below a certain threshold.
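A minimal sketch of this scoring scheme, assuming a hash-to-entry map and an illustrative score cap (the names and structure here are hypothetical, not merged code):

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <unordered_map>

// Sketch: each ngram entry carries a small capped score instead of being a
// bare slot, so pruning can spare entries that recently paid off.
struct scored_pool {
    static constexpr int32_t score_cap = 4; // assumed cap, for illustration

    struct entry {
        int32_t next;  // stored continuation token
        int32_t score; // starts at 1 on insert
    };

    std::unordered_map<uint64_t, entry> pool;

    void insert(uint64_t h, int32_t next) {
        pool.emplace(h, entry{next, 1});
    }

    // called after a draft: bump or penalize the ngram that produced it
    void feedback(uint64_t h, bool accepted) {
        auto it = pool.find(h);
        if (it == pool.end()) return;
        it->second.score = accepted
            ? std::min(it->second.score + 1, score_cap)
            : it->second.score - 1;
    }

    // on a low-acceptance streak, drop only entries that never paid off
    void prune() {
        for (auto it = pool.begin(); it != pool.end(); ) {
            it = (it->second.score <= 0) ? pool.erase(it) : std::next(it);
        }
    }
};
```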

@jacekpoplawski
Contributor

It works pretty well in OpenCode (GLM 4.7 Flash with thinking enabled), but I'm not sure if it's real or a placebo. I assume that a draft acceptance rate above 0.1 indicates some speedup. (I also see rates above 0.5.)

@easyfab

easyfab commented Feb 1, 2026

To add to my message #19231: I think there is still a problem.
It can also be reproduced simply by regenerating the message in the chat.
On my end, the regeneration is no longer accelerated. Is it just me?

At start :
image

and after regenerating:

image

@srogmann
Collaborator

srogmann commented Feb 1, 2026

Also, it might be better if I implement a multi-level ngram speculation in such a way that if short ngrams fail, we continue to try with large ngrams.

This is one of the ideas that led to the vector std::vector<std::unique_ptr<common_speculative_state>> impls in common/speculative.cpp: this vector can hold the same implementation several times, but with different ngram-size configurations.
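Such a multi-level scheme could be sketched as follows, assuming a generic speculator interface where the vector is ordered from short to long ngram sizes (the interface and names are illustrative, not the actual common_speculative_state API):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical minimal speculator interface
struct speculator {
    virtual ~speculator() = default;
    virtual std::vector<int32_t> gen_draft(const std::vector<int32_t> & ctx) = 0;
};

// Try each configured implementation in order (e.g. ordered by ngram size)
// and take the first non-empty draft; fall through to the next level when
// the current one fails to produce anything.
std::vector<int32_t> gen_draft_multi(
        const std::vector<std::unique_ptr<speculator>> & impls,
        const std::vector<int32_t> & ctx) {
    for (const auto & impl : impls) {
        std::vector<int32_t> draft = impl->gen_draft(ctx);
        if (!draft.empty()) {
            return draft;
        }
    }
    return {};
}
```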

anubhavgupta added a commit to anubhavgupta/llama-cpp-manager that referenced this pull request Feb 1, 2026
@ddh0
Contributor

ddh0 commented Feb 1, 2026

To add to my message #19231: I think there is still a problem. It can also be reproduced simply by regenerating the message in the chat. On my end, the regeneration is no longer accelerated. Is it just me?

At start : image

and after regenerating:

image

I see the same thing.

4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
* spec : add ngram-mod

* cont : simplify + keep track of occupancy

* cont : cleanup

* cont : move initialization to common/speculative

* cont : cleanup

* cont : cleanup

* cont : fix
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026
* spec : add ngram-mod

* cont : simplify + keep track of occupancy

* cont : cleanup

* cont : move initialization to common/speculative

* cont : cleanup

* cont : cleanup

* cont : fix
@EndeavoringOrb

Not a bug, just something I noticed: when the prompt contains an uploaded/pasted file with CRLF line endings, the drafted ngrams often don't get accepted (even if the task is to repeat the file verbatim), because models prefer LF endings.

Running with:

llama-server -m Devstral-2-123B-Instruct-2512-UD-Q5_K_XL-00001-of-00002.gguf --no-mmap --temp 0.15 --port 55553 --metrics --min-p 0.01 -c 32768 --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 32 --draft-max 48

build: 7992 (612db61) with GNU 13.3.0 for Linux aarch64

Stats for CRLF file + prompt "Repeat verbatim." (temp set to 0 in UI)

       eval time =  240138.40 ms /   557 tokens (  431.13 ms per token,     2.32 tokens per second)
draft acceptance rate = 0.11458 (   11 accepted /    96 generated)

Stats for LF file + prompt "Repeat verbatim." (temp set to 0 in UI)

       eval time =   18696.70 ms /   557 tokens (   33.57 ms per token,    29.79 tokens per second)
draft acceptance rate = 0.83013 (  518 accepted /   624 generated)
