Move to backend sampling for MTP draft path #23287
I discussed this with @ggerganov earlier, and he said backend sampling has issues we need to resolve first. Pasting his reply: "In theory yes, but it needs quite some work to get usable:"
In this PR, I decided to fall back to CPU sampling in the case of tensor parallel. Enabling backend sampling for "tensor parallel" is actually quite an involved change, and I was planning to work on it as a follow-up.
ok, I was not aware of it. This PR just makes use of
Could you explain what is currently missing, and does #19833 address this?
So this is actually needed if we want to run the backend sampling for the target model. But here the idea is to run it on the draft. So for the time being, we can ignore this point.
Argmax is OK, I think. But with the merge of #23269 we would want to use a top-k sampler in order to be able to utilize the
Argmax is simpler to do than top-k, so yes, perf implications are expected. Based on my previous involvement in backend sampling, all of these should be significantly faster on CUDA than on CPU.
Force-pushed from 61f9de9 to 4c330e3.
I have rebased the PR on top of master. It now moves top_k(10) to the backend. On an RTX 5090, I'm seeing an improvement of ~8%. @ggerganov, can you check how it performs on Metal?
There isn't a noticeable impact on Metal perf with this change. I think the unoptimized top-k balances out the reduced data transfer and CPU sampling. Once it gets optimized, we should likely see some improvement there as well.
Ok, let's consider adding an argmax-based sampler as originally proposed, after we merge this version that uses top-k, since it should be a very simple change and does not involve changes to the
This does not require
No, it uses backend sampling by default for the MTP draft. It doesn't modify the sampling for the target model.
Does it make sense to add a toggle for this?
Vulkan performance on RTX 5090: average gain of ~4%.
I don't know how stable backend sampling is across all backends, so a toggle might be a good idea for debugging. @ggerganov?
Ok, let's add the argument - I am also not sure how stable the implementation is.
Added the argument. Kept it ON by default.
Run top_k(10) on the draft backend. D2H transfers happen only for the top 10 logits. Make backend sampling more robust and fall back to CPU on failure cases, such as with "-sm tensor" or when a backend doesn't support TOP_K.
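To illustrate why shrinking the D2H transfer to the top 10 logits matters, here is a minimal host-side sketch of top-k selection over one logits row, together with the bytes moved in each case. It is not the llama.cpp implementation: `candidate`, `top_k_candidates`, `full_row_bytes` and `top_k_bytes` are hypothetical names chosen for this example, and the tie-breaking rule (lower token id wins) is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One surviving candidate: vocabulary index plus its raw logit.
// (Hypothetical type for this sketch, not a llama.cpp struct.)
struct candidate {
    int32_t id;
    float   logit;
};

// Keep the k largest logits, ordered from largest to smallest.
// Ties are broken by the lower token id so the result is deterministic
// (an assumption of this sketch).
std::vector<candidate> top_k_candidates(const std::vector<float> & logits, std::size_t k) {
    assert(k > 0);
    std::vector<candidate> cands(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i) {
        cands[i] = { static_cast<int32_t>(i), logits[i] };
    }
    k = std::min(k, cands.size());
    auto better = [](const candidate & a, const candidate & b) {
        return a.logit != b.logit ? a.logit > b.logit : a.id < b.id;
    };
    // Only the first k positions need to be sorted.
    std::partial_sort(cands.begin(), cands.begin() + static_cast<std::ptrdiff_t>(k), cands.end(), better);
    cands.resize(k);
    return cands;
}

// Bytes copied per draft token when the full logits row goes to the host.
std::size_t full_row_bytes(std::size_t n_vocab) { return n_vocab * sizeof(float); }

// Bytes copied when only k (id, logit) pairs go to the host.
std::size_t top_k_bytes(std::size_t k) { return k * sizeof(candidate); }
```

For an illustrative vocabulary of 150,000 entries (an assumed figure, not taken from this PR), `full_row_bytes` gives 600,000 bytes per draft token, while `top_k_bytes(10)` is 80 bytes.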
Force-pushed from 9fd101a to b0061f7.
Have any changes related to "--fit" calculations been made? Before this PR, Qwen-35B-A3B loaded fine with "-fitt 1536" and ran with its context about 90% full without crashing. Now it crashes on the first prompt. It loads with "-fitt 2048". Even with less VRAM I got a ~10% improvement in TG, wow!
The extra VRAM usage might be fixed with #23433 - need to check |
* Move to backend sampling for MTP draft path

  Run top_k(10) on the draft backend. D2H transfers happen only for the top 10 logits.

  Make backend sampling more robust and fall back to CPU on failure cases, such as with "-sm tensor" or when a backend doesn't support TOP_K.

* Allow sampler chains to be partially offloaded to backend

* Add --spec-draft-backend-sampling argument. Enabled by default.
* origin/master: (138 commits)
  fix(flash-attn): replace f32 with kv_type and q_type (ggml-org#23372)
  tests : move save-load-state from examples to tests (ggml-org#23336)
  server: expose prompt token counts in /slots endpoint (ggml-org#23454)
  metal : optimize concat kernel and fix set kernel threads (ggml-org#23411)
  server : free draft/MTP resources on sleep to fix VRAM leak (ggml-org#23461)
  server: re-inject subcommand when router spawns children under unified binary (ggml-org#23442)
  app : add batched-bench, fit-params, quantize & perplexity (ggml-org#23459)
  mtp: use inp_out_ids for skipping logit computation (ggml-org#23433)
  vocab : add Carbon-3B (HybridDNATokenizer) support (ggml-org#23410)
  doc: fix spec mtp typo (ggml-org#23435)
  ui: Improve Git Hooks for UI development (ggml-org#23403)
  ggml : Check the right iface method before using the fallback 2d get (ggml-org#23306)
  llama-graph: fix null-buffer crash in llm_graph_input_attn_kv_iswa for SWA-only models (ggml-org#23131)
  hexagon: ssm-conv fix for large prompts (ggml-org#23307)
  app : show version (ggml-org#23426)
  mtmd, model : merge HunyuanOCR into HunyuanVL and fix OCR vision precision (ggml-org#23329)
  ui: Add max image size option (ggml-org#22849)
  Move to backend sampling for MTP draft path (ggml-org#23287)
  opencl: refactor backend initilization (ggml-org#23318)
  common/speculative : fix nullptr crash in get_devices_str (ggml-org#23386)
  ...
Notable upstream changes:
- MTP cleanup: rename state→impl, accept(is_other), p_min re-enabled, top_k=10, backend sampling (ggml-org#23287, ggml-org#23269)
- fit_params accounts for mmproj memory via mtmd_get_memory_usage (ggml-org#21489)
- Free draft/MTP resources on sleep (ggml-org#23461)
- MTP inp_out_ids optimization (ggml-org#23433)
- PDL for Hopper+ (ggml-org#22522)
- SWA-only model null-buffer fix (ggml-org#23131)
- Perplexity integer overflow fix (ggml-org#23496)

Fork conflict resolutions:
- speculative.cpp: updated fork classes (suffix, copyspec, recycle, dflash) to 3-arg accept() signature; renamed state→impl references
- server-context.cpp: integrated upstream mmproj memory measurement for non-swap path; kept fork's pre-doubling auto-fit for mmproj-gpu-swap path (now uses mtmd_get_memory_usage instead of file-size heuristic); added upstream's mtmd_helper_log_set to mmproj init

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* upstream/HEAD: (38 commits)
  vocab : add Carbon-3B (HybridDNATokenizer) support (ggml-org#23410)
  doc: fix spec mtp typo (ggml-org#23435)
  ui: Improve Git Hooks for UI development (ggml-org#23403)
  ggml : Check the right iface method before using the fallback 2d get (ggml-org#23306)
  llama-graph: fix null-buffer crash in llm_graph_input_attn_kv_iswa for SWA-only models (ggml-org#23131)
  hexagon: ssm-conv fix for large prompts (ggml-org#23307)
  app : show version (ggml-org#23426)
  mtmd, model : merge HunyuanOCR into HunyuanVL and fix OCR vision precision (ggml-org#23329)
  ui: Add max image size option (ggml-org#22849)
  Move to backend sampling for MTP draft path (ggml-org#23287)
  opencl: refactor backend initilization (ggml-org#23318)
  common/speculative : fix nullptr crash in get_devices_str (ggml-org#23386)
  mtmd : DeepSeek-OCR image processing fixes, img_tool::resize padding refactor (ggml-org#23345)
  vulkan: optimize operations in the IM2COL shader (ggml-org#22685)
  feat: Add WAV MIME type variants and improve audio format detection (ggml-org#23396)
  hexagon: HMX quantized matmul rework (ggml-org#23368)
  Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) (ggml-org#22522)
  app : introduce the llama unified executable (ggml-org#23296)
  refactor: Move text attachments up before the message content in chat completions payload (ggml-org#23406)
  mtmd: fit_params now take into account mmproj (ggml-org#21489)
  ...
ggml-org#23287 (commit 0bbdec4) flipped the default to true, offloading draft top-k(10) to the backend with D2H of only the top 10 logits. The optimization regressed V-J accept rate on Qwen3.5-35B-A3B-MTP-IQ4_XS from 79% (anchor 71a0c46) to 0.20% (tip 67edf5e), root-caused by worker mtp-accept-rate-bisect-2026-05-24 to a host-side argmax that no longer matches the in-graph argmax over the same logits row.

This default-flip is a STOP-GAP: it restores pre-0bbdec4aa semantics (host pulls full vocab via D2H + samples locally). Once the actual backend top-k bug is fixed, the optimization can be safely re-enabled. Users who want it now can pass --spec-draft-backend-sampling on the CLI.

Brief: kernel-work/orchestrator-inbox/completed/processed/orchestrator-brief-mtp-accept-rate-bisect-2026-05-24.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
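The mismatch described in this commit message, a host-side argmax disagreeing with the backend's pick over the same logits row, can be checked with a small diagnostic. The sketch below is only an illustration under assumptions: `argmax_full_row` and `draft_token_matches` are hypothetical helpers, and it assumes the backend returns its candidate ids ordered from most to least likely.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Index of the largest logit in the row; the first occurrence wins on ties.
int32_t argmax_full_row(const std::vector<float> & logits) {
    assert(!logits.empty());
    std::size_t best = 0;
    for (std::size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return static_cast<int32_t>(best);
}

// True when the first candidate id reported by the backend equals the
// host-side argmax over the full row. An empty candidate list counts as
// a mismatch. (Assumes backend_ids is ordered best-first.)
bool draft_token_matches(const std::vector<float> & logits,
                         const std::vector<int32_t> & backend_ids) {
    return !backend_ids.empty() && backend_ids.front() == argmax_full_row(logits);
}
```

Running such a check over a handful of draft steps, with the full row still copied to the host for comparison, would show whether the two paths agree before the reduced transfer is relied on.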
PR ggml-org#23287 enabled backend draft sampling by default for the MTP path, attaching a per-seq_id sampler chain (top_k=10) to the draft context. This adds compute-buffer footprint that scales with n_seq, so configs that fit comfortably in VRAM at --parallel N>1 on b9246 now OOM during the first decode on b9410+ (see ggml-org#23903 for the bisect: b9246 fit two slots in 15.6 GB, while b9426 needs essentially the full 16 GB for one slot under the same model and flags).

Default the new behavior off so the regression does not fire on configs that worked before. Users wanting backend sampling can opt back in with --spec-draft-backend-sampling (already wired by PR ggml-org#23287). The help text auto-reflects the new default via string_format("default: %s", ... ? "enabled" : "disabled").
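The help-text pattern mentioned in this commit message derives the displayed default from the same boolean that controls the behavior, so flipping the default cannot leave the help string stale. A minimal sketch of that idea follows; `format_default` is a hypothetical stand-in built on `std::snprintf`, not the `string_format` helper referenced above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>

// Build the "default: ..." fragment of a help string from the current
// default value. (Hypothetical helper for illustration only.)
std::string format_default(bool enabled) {
    char buf[64];
    const int n = std::snprintf(buf, sizeof(buf), "default: %s",
                                enabled ? "enabled" : "disabled");
    // The formatted text is short, so it always fits in the buffer.
    assert(n > 0 && static_cast<std::size_t>(n) < sizeof(buf));
    return std::string(buf, static_cast<std::size_t>(n));
}
```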
The observation was that the MTP draft is quite small, so draft sampling can dominate the draft execution time. This PR tries to optimize MTP draft sampling.
Replace D2H logit copies and CPU-side sort with argmax on the backend.
Make backend sampling more robust and fall back to CPU on failure cases, such as with "-sm tensor" or when a backend doesn't support ARG_MAX.
Introduce a new sampler in the API that is greedy but doesn't expose logits: llama_sampler_init_greedy_token_only.
Performance on 2x RTX 5090 with Qwen3.6-35B-A3B-UD-Q4_K_M.gguf and --spec-draft-n-max 3 improves by ~7%.
Command:
./llama-server -m Qwen3.6-35B-A3B-UD-Q4_K_M.gguf --spec-type draft-mtp --spec-draft-n-max 3
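The fallback behavior described above (use argmax on the backend when possible, otherwise copy the logits and sample on the CPU) can be sketched as follows. This is an illustration under assumptions, not the PR's code: `backend_argmax_fn`, `fetch_logits_fn`, `cpu_argmax` and `sample_greedy` are hypothetical names, and the backend is modeled as a callback that reports failure by returning false.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical backend hook: writes the greedy token id and returns true,
// or returns false when the backend cannot run the op (for example,
// ARG_MAX unsupported or a split mode that rules out backend sampling).
using backend_argmax_fn = std::function<bool(int32_t &)>;

// Hypothetical host hook: copies the full logits row from the device.
using fetch_logits_fn = std::function<std::vector<float>()>;

// CPU greedy pick: index of the largest logit, first occurrence on ties.
int32_t cpu_argmax(const std::vector<float> & logits) {
    assert(!logits.empty());
    std::size_t best = 0;
    for (std::size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return static_cast<int32_t>(best);
}

// Try the backend first; only pay for the full D2H copy when it fails
// or when no backend hook is available.
int32_t sample_greedy(const backend_argmax_fn & on_backend,
                      const fetch_logits_fn & fetch_logits) {
    int32_t tok = -1;
    if (on_backend && on_backend(tok)) {
        return tok; // token id only, logits never leave the device
    }
    return cpu_argmax(fetch_logits());
}
```

The point of this shape is that the expensive logits copy sits entirely on the fallback path, so the common case transfers a single token id.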