
llama: allow partial seq_rm for GDN models for speculative decoding#22400

Draft
am17an wants to merge 4 commits into ggml-org:master from am17an:hybrid-mem-slot-rollback

Conversation

@am17an
Contributor

@am17an am17an commented Apr 26, 2026

Overview

Currently, speculative decoding needs to restart from a checkpoint after some draft tokens are not accepted, which wastes work re-running the target. This PR adds the ability to roll back up to draft_max tokens by storing the GDN intermediates. This is similar to how vLLM does it.

This is a breaking change in the GDN API, so it only works on CPU and CUDA for now. There also seems to be some issue with -fit.
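
As a rough illustration of the bookkeeping this enables (a minimal standalone sketch, not the actual llama.cpp implementation; all names below are made up):

  // Illustrative only: the target keeps one recurrent-state snapshot per
  // verified draft token, so rejecting the tail is just a resize instead of
  // re-running the target from a checkpoint.
  #include <cstddef>
  #include <vector>

  struct gdn_rollback_buffer {
      // states[0] = state before the drafts, states[i] = state after draft token i
      std::vector<std::vector<float>> states;

      void begin(const std::vector<float> & base) { states = { base }; }
      void push (const std::vector<float> & s)    { states.push_back(s); }

      // keep the first n_accept drafted tokens and drop the rejected tail;
      // returns the state to resume generation from
      const std::vector<float> & rollback_to(size_t n_accept) {
          states.resize(n_accept + 1);
          return states.back();
      }
  };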

Additional information

I'm not sure how to measure speculative decoding comprehensively (probably with different types of prompts), but here is a simple prompt: Write a complete Python implementation of a doubly linked list with insert, delete, search, and reverse methods. Include type hints and docstrings.

On a 5090: llama-server -m /opt/models/Qwen3.5-27B-Q4_K_M.gguf -md /opt/models/Qwen3.5-0.8B-Q8_0.gguf --draft-max 16 -c 8192 -ngl 99

Master:

  "timings": {
    "cache_n": 0,
    "prompt_n": 100,
    "prompt_ms": 122.986,
    "prompt_per_token_ms": 1.22986,
    "prompt_per_second": 813.1006781259655,
    "predicted_n": 400,
    "predicted_ms": 5637.426,
    "predicted_per_token_ms": 14.093565000000002,
    "predicted_per_second": 70.95436818150695,
    "draft_n": 344,
    "draft_n_accepted": 344
  }

PR:

  "timings": {
    "cache_n": 0,
    "prompt_n": 100,
    "prompt_ms": 108.536,
    "prompt_per_token_ms": 1.08536,
    "prompt_per_second": 921.3532837031032,
    "predicted_n": 400,
    "predicted_ms": 3173.053,
    "predicted_per_token_ms": 7.9326324999999995,
    "predicted_per_second": 126.06155648834105,
    "draft_n": 374,
    "draft_n_accepted": 339
  }

So about 1.8x faster.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes; initially for prototyping, then for debugging and reviewing, and also for understanding the speculative decoding codebase

@github-actions github-actions bot added the model, testing, Nvidia GPU, Vulkan, examples, server, and ggml labels Apr 26, 2026
@am17an am17an force-pushed the hybrid-mem-slot-rollback branch from c5a52bc to 9332573 Compare April 26, 2026 18:15
@pwilkin
Member

pwilkin commented Apr 26, 2026

Okay, I'm trying to understand the machinery here, correct me if I'm wrong (I don't have any PoC code yet):

It seems the main gains from this approach stem from the fact that instead of using full checkpointing, we keep a "sliding window"-like buffer of recurrent states of size N, so we can do seq_rm instead of reloading the full checkpoint: a normal seq_rm on the dense layers and a rewind to an earlier state (so, an internal "light" checkpoint) on the recurrent layers.

However, the approach requires major changes to the fused GDN kernels, needed so that when you have X prefill tokens but fewer than N generation tokens so far, you can still roll back N tokens (i.e. you need to save the recurrent state and convolution state for the last N tokens of prefill). But if I understand correctly how speculative decoding works, this is not needed at all for this exact use case, since you never roll back beyond the prefill barrier.

So, instead of having this major rewrite, what could work is something like this:

  • increase size of recurrent cache from 1 to N like here
  • prefill writes just last token of state like before
  • generation writes a token to the first free slot
  • if no free slot, oldest slot is evicted and the entire cache moved 1 back to make space for the slot
  • since we never need to read more than one token from the cache anyway, we can make it a cyclic buffer to make the overhead as low as possible, i.e. for example for N=5, writing order would look like this (P = prefill, 1...K = generation): [P____] -> [P1___] -> [P12__] -> [P123_] -> [P1234] -> [51234] -> [56234] -> ... (see the sketch after this list)
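
A standalone toy sketch of that write order (illustrative only; the slot bookkeeping and names are not from this PR):

  // Toy illustration of the cyclic buffer idea above. P (the last prefill
  // state) is written first, then generation states 1..K; once the buffer is
  // full the oldest slot is overwritten.
  #include <cstdio>

  struct rs_ring {
      static constexpr int N = 5;
      int slot[N];     // which token's state lives in each slot (-1 = empty)
      int next = 0;    // next slot to write (cyclic)

      rs_ring() { for (int i = 0; i < N; ++i) slot[i] = -1; }

      void write(int tok) { slot[next] = tok; next = (next + 1) % N; }
  };

  int main() {
      rs_ring r;
      r.write(0);                      // P: last prefill state
      for (int t = 1; t <= 6; ++t) {   // generation tokens 1..6
          r.write(t);                  // at t=5 the prefill slot is evicted
      }
      for (int i = 0; i < rs_ring::N; ++i) printf("%d ", r.slot[i]); // prints: 5 6 2 3 4
      printf("\n");
      return 0;
  }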

@am17an
Contributor Author

am17an commented Apr 27, 2026

I'm not exactly clear on what you mean, but rolling back the rejected tokens and skipping the accepted tokens for the next draft is an optimization. However, I think a lot of the benefit can be attributed to having no device-to-host copies in this PR (i.e. the slots are all kept on device), so perhaps something simpler would be keeping a device buffer for the spec checkpoint first.

@pwilkin
Member

pwilkin commented Apr 27, 2026

@am17an what I mean is that speculative decoding works, as far as I understand, like this:

-> we prefill X tokens
-> we start generating speculative tokens (so we're at X + N)
-> we need to roll back the speculated tokens to verify (so we're back at X + N - N = X)

So if I get the process correctly, we will never move back beyond X, which is the case this whole machinery of modified GDN kernels is intended to handle.

I'll try to put up some PoC later on.

@am17an
Contributor Author

am17an commented Apr 27, 2026

No, it is not like that. We first generate N drafts auto-regressively via the draft model, then the target model verifies them in a batch (prefill); say K are accepted, so the rollback required is N-K.
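
In other words, as a purely illustrative toy of that flow (none of these functions are llama.cpp API):

  // Purely illustrative: the draft model proposes N tokens auto-regressively,
  // the target verifies them in one batch, and the rejected tail (N - K tokens)
  // is rolled back. All functions here are stand-ins, not llama.cpp API.
  #include <cstdio>
  #include <vector>

  static int  draft_next(int last)     { return last + 1; }      // stub draft model
  static bool target_accepts(int tok)  { return tok % 7 != 0; }  // stub verification

  int main() {
      const int N  = 16;           // draft_max
      int last     = 100;          // last committed token
      std::vector<int> drafts;
      for (int i = 0; i < N; ++i) { drafts.push_back(draft_next(last + i)); } // draft loop

      // the target runs all N drafts in a single batch; acceptance stops at the
      // first mismatch, so K tokens are kept and N - K states are rolled back
      int K = 0;
      for (int tok : drafts) { if (!target_accepts(tok)) break; ++K; }

      printf("accepted %d of %d drafts, rolling back %d states\n", K, N, N - K);
      return 0;
  }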

@pwilkin
Member

pwilkin commented Apr 27, 2026

No, it is not like that. We first generate N drafts auto-regressively via the draft model, then the target model verifies them in a batch (prefill); say K are accepted, so the rollback required is N-K.

Yeah, but the rollback of N-K is within the auto-regressively generated tokens (which can normally just save their states one-token-at-a-time to the bigger cache). That's what I mean. The GDN kernel changes would only be needed if you actually wanted to roll back into the original prefill content.

@am17an
Contributor Author

am17an commented Apr 27, 2026

Yeah, but the rollback of N-K is within the auto-regressively generated tokens (which can normally just save their states one-token-at-a-time to the bigger cache).

These aren't auto-regressively generated by the target model, whose states we actually want. They are auto-regressively generated by the draft model; the target model verifies them via prefill (i.e. all drafts together in a batch). However, the fused GDN we currently use is token by token; we just need to expose where to write the state at each token, which is what this PR does.
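
To make the last point concrete, a minimal standalone sketch of what "expose where to write the state at each token" means inside a token-by-token loop (illustrative only; the update rule is a stand-in, not the actual fused GDN kernel):

  // Illustrative only: the fused GDN path already processes tokens one at a
  // time; keep_intermediates just makes each iteration write its updated state
  // into a per-token slot instead of overwriting a single final-state slot.
  #include <vector>

  void gdn_like_loop(std::vector<float> state,                    // running recurrent state
                     const std::vector<std::vector<float>> & x,   // one input per token
                     std::vector<std::vector<float>> & state_out, // n_tokens slots (or 1)
                     bool keep_intermediates) {
      const size_t n_tokens = x.size();
      for (size_t t = 0; t < n_tokens; ++t) {
          for (size_t i = 0; i < state.size(); ++i) {
              state[i] = 0.9f * state[i] + x[t][i];               // stand-in for the GDN update
          }
          const size_t dst = keep_intermediates ? t : 0;          // where to write this token's state
          state_out[dst] = state;
      }
  }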

@pwilkin
Member

pwilkin commented Apr 27, 2026

Okay, now I get it :)

@Ziqiao-git

I'm working on a Metal port of the keep_intermediates path locally. I have a draft kernel that mirrors the CPU/CUDA logic: the token loop writes per-token snapshots into the dst state region, with pipeline lookup gated by op_params[0]. test-backend-ops passes 6/6 keep_intermediates cases bit-exact vs the CPU reference (head_size 16/32/64/128, n_seq_tokens 2/4, n_seqs 1/2, kda on/off, v_repeat 1/2).

Happy to push to your branch, open a PR against it, or wait until merge and submit a follow-up — whichever fits your workflow.

@am17an am17an force-pushed the hybrid-mem-slot-rollback branch 2 times, most recently from 86d9978 to 4e95702 Compare April 28, 2026 12:56
Comment thread include/llama.h Outdated
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

LLAMA_API bool llama_model_supports_recurrent_partial_rollback(const struct llama_model * model);
Member


Instead of adding the concept of "rollback", I think you can add an integer parameter:

uint32_t n_rs_seq; // number of recurrent states per sequence

@ggerganov
Member

perhaps something simpler would be keeping a device buffer for the spec checkpoint first.

Yeah, it would be nice to explore this. I prefer the current approach on master as it does not introduce memory overhead for extra recurrent states, so we can have large drafts (e.g. 64 tokens long) for prompt-based speculative decoding without adding a lot of memory.

@am17an
Contributor Author

am17an commented Apr 28, 2026

@ggerganov yes, I will explore keeping the recurrent checkpoint on device. In terms of memory, we still use less than a normal attention model. My initial idea was to support MTP with this, since I think the highest-quality draft is the MTP layer, which will require at most 5 draft tokens. In that case I think being able to do a partial rollback will matter for performance.

@ruixiang63

Here is a target-side deferred commit (SGLang implementation); maybe this can give you some more insights :)

@am17an
Contributor Author

am17an commented Apr 28, 2026

@ruixiang63 this is what this PR also does

@ggerganov
Member

@am17an Try to rerun the tests on master now that #22506 is merged. It should improve the baseline performance.

@am17an
Contributor Author

am17an commented Apr 30, 2026

@ggerganov on the same prompt using the same parameters I get this (pasted below). I also tried --spec-n-draft-max 64

  "timings": {
    "cache_n": 0,
    "prompt_n": 100,
    "prompt_ms": 121.479,
    "prompt_per_token_ms": 1.21479,
    "prompt_per_second": 823.1875468187917,
    "predicted_n": 100,
    "predicted_ms": 1505.489,
    "predicted_per_token_ms": 15.05489,
    "predicted_per_second": 66.42360057097727,
    "draft_n": 96,
    "draft_n_accepted": 78
  }

@am17an
Contributor Author

am17an commented Apr 30, 2026

There seems to be a slowdown for this PR as well; I'm not sure why the drafts went from ~340 to ~84.

  "timings": {
    "cache_n": 0,
    "prompt_n": 100,
    "prompt_ms": 91.193,
    "prompt_per_token_ms": 0.91193,
    "prompt_per_second": 1096.5753950412861,
    "predicted_n": 100,
    "predicted_ms": 874.735,
    "predicted_per_token_ms": 8.74735,
    "predicted_per_second": 114.3203370163535,
    "draft_n": 84,
    "draft_n_accepted": 79
  }

@ggerganov
Member

@am17an Could you reevaluate the baseline using #22679?

@am17an
Contributor Author

am17an commented May 4, 2026

@ggerganov after #22679 there is no difference between master and the PR (I also cherry-picked your changes) for high acceptance rates (> 85%) on a 5090. For lower acceptance rates I see some difference, but it could just be noise.

@am17an
Contributor Author

am17an commented May 4, 2026

On DGX Spark, though, there is still a difference, as this PR doesn't help with unified memory as much:

  code_python        pred= 192 draft= 186 acc= 158 rate=0.850 tok/s=27.2
  code_cpp           pred= 192 draft= 138 acc= 120 rate=0.870 tok/s=15.5
  explain_concept    pred= 192 draft= 170 acc= 101 rate=0.594 tok/s=10.7
  summarize          pred=  55 draft=  48 acc=  36 rate=0.750 tok/s=14.5
  qa_factual         pred= 177 draft= 126 acc= 106 rate=0.841 tok/s=14.0
  translation        pred=  22 draft=  13 acc=  13 rate=1.000 tok/s=16.8
  creative_short     pred= 192 draft= 136 acc= 104 rate=0.765 tok/s=13.5
  stepwise_math      pred= 192 draft= 172 acc= 147 rate=0.855 tok/s=23.2
  long_code_review   pred= 192 draft= 160 acc= 111 rate=0.694 tok/s=13.1

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1406,
  "total_draft": 1149,
  "total_draft_accepted": 896,
  "aggregate_accept_rate": 0.7798,
  "wall_s_total": 96.47
}

The Python script I'm using to benchmark is here: https://gist.github.com/am17an/228edfb84ed082aa88e3865d6fa27090

@ggerganov
Member

With #22679 I get these results on DGX Spark:

$ python spec.py 
  code_python        pred= 192 draft= 168 acc= 129 rate=0.768 tok/s=22.8
  code_cpp           pred= 192 draft= 137 acc= 117 rate=0.854 tok/s=21.5
  explain_concept    pred= 192 draft= 167 acc= 103 rate=0.617 tok/s=16.3
  summarize          pred=  55 draft=  47 acc=  37 rate=0.787 tok/s=24.2
  qa_factual         pred= 192 draft= 142 acc= 121 rate=0.852 tok/s=22.3
  translation        pred= 192 draft= 124 acc= 113 rate=0.911 tok/s=20.5
  creative_short     pred= 192 draft= 126 acc=  99 rate=0.786 tok/s=18.0
  stepwise_math      pred= 192 draft= 171 acc= 145 rate=0.848 tok/s=30.3
  long_code_review   pred= 192 draft= 148 acc= 109 rate=0.737 tok/s=18.1

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1591,
  "total_draft": 1230,
  "total_draft_accepted": 973,
  "aggregate_accept_rate": 0.7911,
  "wall_s_total": 80.14
}

The server command:

make -j && ./bin/llama-server -hf unsloth/Qwen3.6-27B-GGUF:Q4_K_M unsloth/Qwen3.5-0.8B-GGUF:Q8_0 

@am17an
Contributor Author

am17an commented May 4, 2026

I'm using the Q8_0 target model. Also --chat-template-kwargs "{\"preserve_thinking\": true}"

@ggerganov
Member

@am17an Could you prepare a branch with MTP and using just #22679, without the partial seq_rm changes from this PR? I'd like to see if the fast checkpointing is now good enough to proceed with it and avoid the extra complexity of partial rollback.

@am17an
Contributor Author

am17an commented May 5, 2026

@ggerganov you can try 7f46fe6. I'm seeing a 15-20% slowdown but higher acceptance rates, so I'm investigating whether there's a bug in the partial rollback.

@am17an
Contributor Author

am17an commented May 5, 2026

on a 5090,

with no rollback + #22679 (i.e. 7f46fe6)

  code_python        pred= 192 draft= 153 acc= 140 rate=0.915 tok/s=103.2
  code_cpp           pred= 192 draft= 156 acc= 139 rate=0.891 tok/s=94.4
  explain_concept    pred= 192 draft= 168 acc= 135 rate=0.804 tok/s=76.4
  summarize          pred=  54 draft=  48 acc=  37 rate=0.771 tok/s=70.5
  qa_factual         pred= 178 draft= 144 acc= 129 rate=0.896 tok/s=89.6
  translation        pred=  23 draft=  21 acc=  15 rate=0.714 tok/s=61.8
  creative_short     pred= 192 draft= 165 acc= 136 rate=0.824 tok/s=78.2
  stepwise_math      pred= 192 draft= 159 acc= 138 rate=0.868 tok/s=87.7
  high_entropy       pred= 192 draft= 156 acc= 139 rate=0.891 tok/s=85.9
  long_code_review   pred= 192 draft= 168 acc= 135 rate=0.804 tok/s=76.6
  
Aggregate: {
  "n_requests": 10,
  "total_predicted": 1599,
  "total_draft": 1338,
  "total_draft_accepted": 1143,
  "aggregate_accept_rate": 0.8543,
  "wall_s_total": 21.03
}

MTP PR:

  code_python        pred= 192 draft= 153 acc= 140 rate=0.915 tok/s=129.3
  code_cpp           pred= 192 draft= 174 acc= 133 rate=0.764 tok/s=114.9
  explain_concept    pred= 192 draft= 195 acc= 125 rate=0.641 tok/s=101.2
  summarize          pred=  53 draft=  51 acc=  35 rate=0.686 tok/s=107.2
  qa_factual         pred= 177 draft= 174 acc= 118 rate=0.678 tok/s=106.7
  translation        pred=  22 draft=  24 acc=  13 rate=0.542 tok/s=93.6
  creative_short     pred= 192 draft= 201 acc= 123 rate=0.612 tok/s=98.9
  stepwise_math      pred= 192 draft= 174 acc= 133 rate=0.764 tok/s=114.4
  high_entropy       pred= 192 draft= 177 acc= 132 rate=0.746 tok/s=112.4
  long_code_review   pred= 192 draft= 183 acc= 130 rate=0.710 tok/s=108.3

Aggregate: {
  "n_requests": 10,
  "total_predicted": 1596,
  "total_draft": 1506,
  "total_draft_accepted": 1082,
  "aggregate_accept_rate": 0.7185,
  "wall_s_total": 16.68
}

The difference is non-negligible; IMO it's worth keeping.

@ggerganov
Member

During drafting, do you stop when the probability of the drafted token drops below p_min (i.e. #22506)? This is important for checkpoint-based speculative decoding, to avoid the majority of rollbacks.
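
For reference, the p_min early stop is roughly the following (a sketch of the idea only, not the code from #22506):

  // Sketch only: stop drafting once the draft model's confidence in its next
  // token drops below p_min, so low-probability tails don't force rollbacks.
  #include <functional>
  #include <vector>

  struct draft_tok { int id; float p; };

  std::vector<int> draft_with_p_min(int n_max, float p_min,
                                    const std::function<draft_tok()> & sample_next) {
      std::vector<int> out;
      for (int i = 0; i < n_max; ++i) {
          draft_tok t = sample_next();  // one auto-regressive step of the draft model
          if (t.p < p_min) break;       // confidence too low: end the draft early
          out.push_back(t.id);
      }
      return out;
  }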

I think we have to focus on bringing in the checkpoint-based version first, because it is broadly compatible across backends, uses a lot less memory, works in combination with ngram-based speculative decoding, and fits the current llama.cpp interface. Partial seq_rm could be considered after this, though generally I think it is too complex and I'm not really sure it is worth adding.

@am17an
Contributor Author

am17an commented May 5, 2026

I added the p_min break in c8ee2ef; the acceptance rate goes up but the timing is slightly worse:

  code_python        pred= 192 draft= 142 acc= 139 rate=0.979 tok/s=100.2
  code_cpp           pred= 192 draft= 136 acc= 134 rate=0.985 tok/s=84.8
  explain_concept    pred= 192 draft= 135 acc= 129 rate=0.956 tok/s=66.8
  summarize          pred=  54 draft=  37 acc=  37 rate=1.000 tok/s=72.2
  qa_factual         pred= 178 draft= 126 acc= 122 rate=0.968 tok/s=75.5
  translation        pred=  23 draft=  15 acc=  15 rate=1.000 tok/s=65.7
  creative_short     pred= 192 draft= 126 acc= 126 rate=1.000 tok/s=69.1
  stepwise_math      pred= 192 draft= 143 acc= 137 rate=0.958 tok/s=90.6
  high_entropy       pred= 192 draft= 140 acc= 134 rate=0.957 tok/s=77.9
  long_code_review   pred= 192 draft= 134 acc= 130 rate=0.970 tok/s=73.3

Aggregate: {
  "n_requests": 10,
  "total_predicted": 1599,
  "total_draft": 1134,
  "total_draft_accepted": 1103,
  "aggregate_accept_rate": 0.9727,
  "wall_s_total": 22.62
}

I think we have to focus on bringing in the checkpoint-based version first, because it is broadly compatible across backends, uses a lot less memory, works in combination with ngram-based speculative decoding, and fits the current llama.cpp interface. Partial seq_rm could be considered after this, though generally I think it is too complex and I'm not really sure it is worth adding.

Yes, other models like GLM don't have recurrent memory, so MTP will be the "best" version there; we can focus on getting MTP in without the partial rollback. However, to your other points, I think the rollback could be enabled only for MTP + hybrid memory: it requires ~500MB of extra VRAM for --spec-draft-n-max 3, but provides a very tangible speed-up in tokens/second on fast GPUs (i.e. on par with or better than vLLM, vs. 20% slower without it). For backends which support the fused AR GDN it should be easy to add the change; it's just a few lines of code to keep the intermediates, and I can take responsibility for adding it across all of them.

For draft models or ngram, I don't think it makes sense, because of the memory use, as you said.

Comment thread ggml/include/ggml.h Outdated
Comment on lines +2540 to +2541
struct ggml_tensor * state,
bool keep_intermediates);
Member


Wouldn't it be simpler (and backwards-compatible) to keep the signature as it is and extend the internal kernel logic to determine keep_intermediates from the shape of the state tensor? This way we keep the graph static, and the idea is that even if we process a batch of n_tokens > n_rs_seq, we would still keep the last states, i.e. [n_tokens - n_rs_seq, n_tokens - 1].
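
A minimal sketch of that shape-based detection (the dimension layout here is an assumption, not the actual ggml code):

  // Sketch: infer keep_intermediates from the destination state tensor's shape
  // instead of an extra flag. Dimension order here is an assumption.
  struct shape { long long ne[4]; };

  static bool infer_keep_intermediates(const shape & state_dst, long long n_seq_tokens) {
      // state_dst assumed as [head_size * n_heads, n_states, n_seqs, 1]:
      //   n_states == n_seq_tokens -> per-token snapshots requested
      //   n_states == 1            -> legacy behaviour, final state only
      return state_dst.ne[1] == n_seq_tokens;
  }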

Contributor Author


Yes, this should be possible; this is how we do kda detection as well.

Contributor Author


One question with this: how do we give build_rs the correct shape? I think the cleanest is having it be 3D instead of 2D, with the 3rd dim being n_rs_seq. Or do you have a better idea?

Member


I think the cleanest is having it be 3D instead of 2D, with the 3rd dim being n_rs_seq.

Yes, let's try it like this.
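
For illustration, the change might look roughly like this at graph-build time (a sketch only; apart from the 3rd dim being n_rs_seq, the dimension order and the helper name are assumptions):

  // Sketch: allocate the recurrent-state tensor as 3D, with n_rs_seq as the
  // 3rd dim, instead of the current 2D layout. The first two dims and the
  // helper name are assumptions; only "3rd dim = n_rs_seq" comes from the
  // discussion above.
  #include "ggml.h"

  static ggml_tensor * build_rs_3d(ggml_context * ctx,
                                   int64_t state_size, int64_t n_seqs, int64_t n_rs_seq) {
      return ggml_new_tensor_3d(ctx, GGML_TYPE_F32, state_size, n_seqs, n_rs_seq);
  }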

am17an added 4 commits May 13, 2026 17:50
Currently, speculative decoding needs to restart from a checkpoint
after some draft tokens are not accepted, which wastes work re-running
the target. This PR adds the ability to roll back up to `draft_max`
tokens by storing the GDN intermediates.
@am17an am17an force-pushed the hybrid-mem-slot-rollback branch from 6ef86c2 to e01a801 Compare May 13, 2026 13:09
Comment thread include/llama.h
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch; // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback)
Member


Suggested change
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback)
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
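
If the field lands roughly as suggested, caller-side usage would presumably look like this (hypothetical, since the parameter only exists on this branch):

  // Hypothetical usage of the proposed n_rs_seq context parameter: keep one
  // recurrent-state snapshot per draft token so partial rollback is possible.
  #include "llama.h"

  llama_context_params make_ctx_params(uint32_t draft_max) {
      llama_context_params cparams = llama_context_default_params();
      cparams.n_rs_seq = draft_max;  // 0 would mean no rollback, as in the comment above
      return cparams;
  }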
