diff --git a/common/arg.cpp b/common/arg.cpp index 87462f49e76..24d9734b934 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3591,6 +3591,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); + add_opt(common_arg( + {"--spec-draft-backend-sampling"}, + {"--no-spec-draft-backend-sampling"}, + string_format("offload draft sampling to the backend (default: %s)", + params.speculative.draft.backend_sampling ? "enabled" : "disabled"), + [](common_params & params, bool value) { + params.speculative.draft.backend_sampling = value; + } + ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" diff --git a/common/common.h b/common/common.h index 53c689bc11d..dec90456afa 100644 --- a/common/common.h +++ b/common/common.h @@ -305,6 +305,8 @@ struct common_params_speculative_draft { float p_split = 0.1f; // speculative decoding split probability float p_min = 0.0f; // minimum speculative decoding probability (greedy) + bool backend_sampling = true; // offload draft sampling to the backend (default: on) + common_params_model mparams; llama_context * ctx_tgt = nullptr; diff --git a/common/speculative.cpp b/common/speculative.cpp index 4d1b61a13ad..fcb24e2d6db 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -414,6 +414,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::vector smpls; + // backend sampler chain per seq, attached to ctx_dft + std::vector backend_chains; + int32_t n_embd = 0; // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. @@ -445,7 +448,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); - LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__, this->params.n_gpu_layers, ggml_type_name(this->params.cache_type_k), @@ -469,6 +472,22 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } + // offload draft sampling to the backend + backend_chains.assign(n_seq, nullptr); + if (this->params.backend_sampling) { + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); + + if (!llama_set_sampler(ctx_dft, seq_id, chain)) { + LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); + llama_sampler_free(chain); + chain = nullptr; + } + backend_chains[seq_id] = chain; + } + } + llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); @@ -484,6 +503,18 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { } ~common_speculative_impl_draft_mtp() override { + auto * ctx_dft = this->params.ctx_dft; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) { + if (backend_chains[seq_id] == nullptr) { + continue; + } + if (ctx_dft) { + llama_set_sampler(ctx_dft, seq_id, nullptr); + } + llama_sampler_free(backend_chains[seq_id]); + } + backend_chains.clear(); + if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3cc8ffa6668..ad36c06667d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1137,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init &&