llama: allow partial seq_rm for GDN models for speculative decoding #22400
am17an wants to merge 4 commits into
Conversation
c5a52bc to 9332573 (force-push)
Okay, I'm trying to understand the machinery here; correct me if I'm wrong (I don't have any PoC code yet). It seems the main gains from this approach stem from the fact that instead of using full checkpointing, we keep a "sliding window attention"-style buffer of the last N recurrent states, so we can roll back within that window.

However, the approach requires major changes to the fused GDN kernels, which are needed so that if you have X prefill tokens but fewer than N generation tokens yet, you can still roll back N tokens (so you need to save the recurrent state and convolution state for the last N tokens of prefill). But if I understand correctly how speculative decoding works, this is not needed for this exact use case, since you never roll back beyond the prefill barrier. So, instead of this major rewrite, what could work is something like this:
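The sliding-window idea above can be sketched with a plain ring buffer. This is illustrative Python only (the class and method names are made up for the sketch, not llama.cpp code): keep the last N states, and rolling back k < N tokens is just dropping the k newest entries.

```python
from collections import deque

class RecurrentStateWindow:
    """Keep the last N recurrent states so any of the last N tokens can
    become the new head without recomputation.

    Sketch only; not llama.cpp's actual data structure."""

    def __init__(self, n: int):
        self.window = deque(maxlen=n)  # oldest states fall off automatically

    def push(self, state):
        """Record the state produced after processing one token."""
        self.window.append(state)

    def rollback(self, k: int):
        """Discard the newest k states and return the new head state."""
        assert 0 <= k < len(self.window)
        for _ in range(k):
            self.window.pop()
        return self.window[-1]
```

With `n = draft_max` this is exactly enough history to undo any number of rejected draft tokens in one round.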
Not exactly clear on what you mean, but rolling back the rejected tokens and skipping the accepted tokens for the next draft is an optimization. However, I think a lot of the benefit can be attributed to having no device-to-host copies in this PR (i.e. the slots are all kept on device), so perhaps something simpler would be keeping a device buffer for the spec checkpoint first.
@am17an what I mean is that speculative decoding works, as far as I understand, like this:

-> we prefill X tokens

So if I get the process correctly, we will never move back beyond X, which is what this whole machinery of modifying the GDN kernels is intended to do. I'll try to put up some PoC later on.
No, it is not like that. We first generate N draft tokens auto-regressively via the draft model, then the target model verifies them via a batch (prefill); say K are accepted. So the rollback required is N-K.

Yeah, but the rollback of N-K is within the auto-regressively generated tokens (which can normally just save their states one token at a time to the bigger cache). That's what I mean. The GDN kernel changes would only be needed if you actually wanted to roll back into the original prefill content.

These aren't auto-regressively generated by the target model, whose states we actually want. These are auto-regressively generated by the draft model; the target model verifies them via prefill (i.e. all drafts together in a batch). However, the fused GDN we use currently is token by token; we just need to expose where to write the state at each token, which is what this PR does.
Okay, now I get it :)
Working on a Metal port of the keep_intermediates path locally. I have a draft kernel that mirrors the CPU/CUDA logic: the token loop writes per-token snapshots into the dst state region, with pipeline lookup gated by op_params[0]. test-backend-ops passes 6/6 keep_intermediates cases bit-exact vs the CPU reference (head_size 16/32/64/128, n_seq_tokens 2/4, n_seqs 1/2, kda on/off, v_repeat 1/2). Happy to push to your branch, open a PR against it, or wait until merge and submit a follow-up, whichever fits your workflow.
86d9978 to 4e95702 (force-push)
```cpp
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

LLAMA_API bool llama_model_supports_recurrent_partial_rollback(const struct llama_model * model);
```
Instead of adding the concept of "rollback" I think you can add an integer parameter:

```cpp
uint32_t n_rs_seq; // number of recurrent states per sequence
```
Yeah, it would be nice to explore this. I prefer the current approach on master, as it does not introduce memory overhead for extra recurrent states, so we can have large drafts (e.g. 64 tokens long) for prompt-based speculative decoding without adding a lot of memory.
@ggerganov yes, I will explore keeping the recurrent checkpoint on device. In terms of memory, we still use less than a normal attention model. My initial idea was to support MTP with this, since I think the highest-quality draft is the MTP layer, which will require at most 5 draft tokens. In that case I think being able to do a partial rollback will matter for performance.
Here is the target-side deferred commit (SGLang implementation); maybe this can give you some more insights :)
@ruixiang63 this is what this PR also does
@ggerganov on the same prompt using the same parameters I get this (pasted below). I also tried
There also seems to be a slowdown with this PR; I'm not sure why the drafts went from ~340 to ~84.
@ggerganov after #22679 there is no difference between master and the PR (I also cherry-picked your changes) for high acceptance rates (> 85%) on a 5090. For lower acceptance rates I see some difference, but it could just be noise.
On DGX Spark, though, there is still a difference, as this PR doesn't help as much with unified memory. The Python script I'm using to bench is here: https://gist.github.com/am17an/228edfb84ed082aa88e3865d6fa27090
With #22679 I get these results on DGX Spark. The server command:

`make -j && ./bin/llama-server -hf unsloth/Qwen3.6-27B-GGUF:Q4_K_M unsloth/Qwen3.5-0.8B-GGUF:Q8_0`
I'm using the Q8_0 target model. Also
@ggerganov you can try 7f46fe6. I'm seeing a 15-20% slowdown but higher acceptance rates, so I'm investigating whether there's a bug in the partial rollback.
On a 5090, with no rollback + #22679 (i.e. 7f46fe6) vs the MTP PR: the difference is non-negligible, IMO worth keeping.
During drafting, do you stop when the probability of the drafted token drops below a threshold?

I think we have to focus on bringing in the checkpoint-based version first, because it is broadly compatible across the backends, uses a lot less memory, works in combination with ngram-based speculative decoding, and fits in the current
I added the p-min break in c8ee2ef; the acceptance rate goes up but the timing is slightly worse.

Yes, other models like GLM don't have recurrent memory, so MTP will be the "best" version there; we can focus on getting MTP in without the partial rollback. However, to your other points: I think the rollback can only be enabled for MTP + hybrid memory, as it requires ~500MB of extra VRAM. For any draft models or ngram, I don't think it makes sense because of the memory use, as you said.
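The p-min early stop discussed here can be sketched as follows. Illustrative Python only; `draft_step` is a hypothetical stand-in that returns a token and the draft model's probability for it. Drafting stops as soon as the draft's confidence falls below the threshold, trading shorter drafts for a higher acceptance rate.

```python
def draft_tokens(draft_step, max_draft, p_min):
    """Draft up to max_draft tokens, stopping early once the draft model's
    probability for its own choice falls below p_min.

    Sketch of the early-stop idea, not llama.cpp's implementation."""
    out = []
    for _ in range(max_draft):
        token, p = draft_step(out)
        if p < p_min:
            break  # low-confidence draft: stop, since it is likely to be rejected
        out.append(token)
    return out
```

A rejected token wastes one target-verification slot, so cutting low-confidence tails raises the acceptance rate even though each round may draft fewer tokens.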
```cpp
struct ggml_tensor * state,
bool keep_intermediates);
```
Wouldn't it be simpler (and backwards-compatible) to keep the signature as it is and extend the internal kernel logic to determine keep_intermediates based on the shape of the state tensor? This way we keep the graph static, and the idea is that even if we process a batch of n_tokens > n_rs_seq, we would still keep the states for the last tokens [n_tokens - n_rs_seq, n_tokens - 1].
Yes, that should be possible - this is how we do kda detection as well.
One question with this: how do we give build_rs the correct shape? I think the cleanest is having it be 3D instead of 2D, with the 3rd dim being n_rs_seq. Or do you have a better idea?
> I think the cleanest is having it be 3d instead of 2d with the 3rd dim being n_rs_seq.

Yes, let's try like this.
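Under the shape-based scheme being discussed, the kernel would infer keep_intermediates from the state tensor having a third dimension of size n_rs_seq, and when n_tokens > n_rs_seq only the newest n_rs_seq per-token states survive. A minimal sketch of that retention rule (illustrative Python, not the ggml kernel):

```python
def write_state_snapshots(states, n_rs_seq):
    """Given one recurrent state per processed token, keep only the last
    n_rs_seq of them: even when n_tokens > n_rs_seq, we retain the states
    for tokens [n_tokens - n_rs_seq, n_tokens - 1].

    Sketch of the retention rule only; the real kernel writes these
    directly into the 3D state tensor."""
    n_tokens = len(states)
    if n_rs_seq <= 1:
        return states[-1:]  # legacy behaviour: only the final state survives
    start = max(0, n_tokens - n_rs_seq)
    return states[start:]
```

When the batch is shorter than the window (n_tokens < n_rs_seq), every per-token state fits and nothing is dropped.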
Currently, speculative decoding needs to restart from a checkpoint after some draft tokens are not accepted, which leads to some wasted work re-running the target. This PR adds the ability to roll back up to `draft_max` tokens by storing the GDN intermediates.
6ef86c2 to e01a801 (force-push)
```cpp
uint32_t n_batch;   // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch;  // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_rs_seq;  // number of recurrent-state snapshots per seq for rollback (0 = no rollback)
```
Suggested change:

```cpp
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
```
Overview
Currently, speculative decoding needs to restart from a checkpoint after some draft tokens are not accepted, which leads to some wasted work re-running the target. This PR adds the ability to roll back up to `draft_max` tokens by storing the GDN intermediates. This is similar to how vLLM does it. This is a breaking change in the GDN API, so it doesn't work for anything except CPU and CUDA as of now. Also, there seems to be some issue with `-fit`.

Additional information
I'm not sure how to measure spec decoding comprehensively (probably different types of prompts), but for a simple prompt:

> Write a complete Python implementation of a doubly linked list with insert, delete, search, and reverse methods. Include type hints and docstrings.

On a 5090:

`llama-server -m /opt/models/Qwen3.5-27B-Q4_K_M.gguf -md /opt/models/Qwen3.5-0.8B-Q8_0.gguf --draft-max 16 -c 8192 -ngl 99`

Master:
PR:
So about 1.8x better.
Requirements