From e20085bac967011a3215146b570053a8e39a6e78 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Tue, 19 May 2026 21:16:27 +0530 Subject: [PATCH 1/3] Move to backend sampling for MTP draft path Run top_k(10) on the draft backend. D2H transfers happen only for the top 10 logits Make backend sampling more robust and fallback to CPU on failure cases, such as with "-sm tensor" or when a backend doesn't support TOP_K. --- common/speculative.cpp | 29 +++++++++++++++++++++++++++++ src/llama-context.cpp | 23 ++++++++++++++++++++++- src/llama-sampler.cpp | 11 +++++++++-- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 4d1b61a13ad..bfbb49bf565 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -414,6 +414,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::vector smpls; + // backend sampler chain per seq, attached to ctx_dft + std::vector backend_chains; + int32_t n_embd = 0; // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. @@ -469,6 +472,20 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } + // offload draft sampling to the backend + backend_chains.assign(n_seq, nullptr); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); + + if (!llama_set_sampler(ctx_dft, seq_id, chain)) { + LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); + llama_sampler_free(chain); + chain = nullptr; + } + backend_chains[seq_id] = chain; + } + llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); @@ -484,6 +501,18 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { } ~common_speculative_impl_draft_mtp() override { + auto * ctx_dft = this->params.ctx_dft; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) { + if (backend_chains[seq_id] == nullptr) { + continue; + } + if (ctx_dft) { + llama_set_sampler(ctx_dft, seq_id, nullptr); + } + llama_sampler_free(backend_chains[seq_id]); + } + backend_chains.clear(); + if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3cc8ffa6668..8d41b78ebbe 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1137,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1146,7 +1159,15 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { if (sampler && can_offload) { auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); - sampler->iface->backend_init(sampler, buft); + if (!sampler->iface->backend_init(sampler, buft)) { + LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, backend_init failed\n", + __func__, llama_sampler_name(sampler), seq_id); + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } sampling.samplers[seq_id] = sampler; diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 9bbc5dbde24..b592f3ce360 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -698,8 +698,6 @@ static bool llama_sampler_chain_backend_init( GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); - chain->is_init = true; - bool res = true; for (auto & smpl : chain->samplers) { @@ -721,6 +719,15 @@ static bool llama_sampler_chain_backend_init( res = res && res_cur; } + if (res) { + chain->is_init = true; + } else { + // partial failure - force CPU path for all samplers + for (auto & smpl : chain->samplers) { + smpl.is_backend = false; + } + } + return res; } From e961ca57fd8b7d57e817e89c34112191a74e6069 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Wed, 20 May 2026 12:47:28 +0530 Subject: [PATCH 2/3] Allow sampler chains to be partially offloaded to backend --- src/llama-context.cpp | 10 +--------- src/llama-sampler.cpp | 11 ++--------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8d41b78ebbe..ad36c06667d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1159,15 +1159,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { if (sampler && can_offload) { auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); - if (!sampler->iface->backend_init(sampler, buft)) { - LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, backend_init failed\n", - __func__, llama_sampler_name(sampler), seq_id); - if (sampling.samplers.count(seq_id) > 0) { - sched_need_reserve = true; - } - sampling.samplers.erase(seq_id); - return false; - } + sampler->iface->backend_init(sampler, buft); sampling.samplers[seq_id] = sampler; diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index b592f3ce360..9bbc5dbde24 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -698,6 +698,8 @@ static bool llama_sampler_chain_backend_init( GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); + chain->is_init = true; + bool res = true; for (auto & smpl : chain->samplers) { @@ -719,15 +721,6 @@ static bool llama_sampler_chain_backend_init( res = res && res_cur; } - if (res) { - chain->is_init = true; - } else { - // partial failure - force CPU path for all samplers - for (auto & smpl : chain->samplers) { - smpl.is_backend = false; - } - } - return res; } From b0061f7f9e7f62d045e9c94d0d2370e03d38ad0a Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Wed, 20 May 2026 15:27:47 +0530 Subject: [PATCH 3/3] Add --spec-draft-backend-sampling argument. Enabled by default. --- common/arg.cpp | 9 +++++++++ common/common.h | 2 ++ common/speculative.cpp | 20 +++++++++++--------- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 87462f49e76..24d9734b934 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3591,6 +3591,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); + add_opt(common_arg( + {"--spec-draft-backend-sampling"}, + {"--no-spec-draft-backend-sampling"}, + string_format("offload draft sampling to the backend (default: %s)", + params.speculative.draft.backend_sampling ? "enabled" : "disabled"), + [](common_params & params, bool value) { + params.speculative.draft.backend_sampling = value; + } + ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" diff --git a/common/common.h b/common/common.h index 53c689bc11d..dec90456afa 100644 --- a/common/common.h +++ b/common/common.h @@ -305,6 +305,8 @@ struct common_params_speculative_draft { float p_split = 0.1f; // speculative decoding split probability float p_min = 0.0f; // minimum speculative decoding probability (greedy) + bool backend_sampling = true; // offload draft sampling to the backend (default: on) + common_params_model mparams; llama_context * ctx_tgt = nullptr; diff --git a/common/speculative.cpp b/common/speculative.cpp index bfbb49bf565..fcb24e2d6db 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -448,7 +448,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); - LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__, this->params.n_gpu_layers, ggml_type_name(this->params.cache_type_k), @@ -474,16 +474,18 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { // offload draft sampling to the backend backend_chains.assign(n_seq, nullptr); - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { - llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); - llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); + if (this->params.backend_sampling) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); - if (!llama_set_sampler(ctx_dft, seq_id, chain)) { - LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); - llama_sampler_free(chain); - chain = nullptr; + if (!llama_set_sampler(ctx_dft, seq_id, chain)) { + LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); + llama_sampler_free(chain); + chain = nullptr; + } + backend_chains[seq_id] = chain; } - backend_chains[seq_id] = chain; } llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);