11 changes: 9 additions & 2 deletions common/arg.cpp
@@ -3004,7 +3004,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.use_jinja = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
@@ -3035,7 +3035,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(
@@ -3346,6 +3346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));
add_opt(common_arg(
{"--eagle3"},
"use EAGLE3 speculative decoding with the draft model",
[](common_params & params) {
params.speculative.eagle3 = true;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
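Reviewer note: the new --eagle3 flag only flips a field on common_params_speculative (see common/common.h below). A minimal, hypothetical sketch of the equivalent programmatic setup, with a placeholder model path that is not part of this PR:

common_params params;
params.speculative.eagle3 = true; // what --eagle3 sets via the arg parser
params.speculative.mparams_dft.path = "eagle3-draft.gguf"; // placeholder draft model path
// common_speculative_init(params.speculative, ctx_tgt) then selects the EAGLE3
// implementation instead of the standard draft one (see common/speculative.cpp below)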
3 changes: 3 additions & 0 deletions common/common.h
@@ -280,10 +280,13 @@ struct common_params_speculative {

struct common_params_model mparams_dft;

llama_model * model_tgt = nullptr; // the target model
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts

llama_context_params cparams_dft; // these are the parameters for the draft llama_context

bool eagle3 = false; // use EAGLE3 speculative decoding

int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)

208 changes: 185 additions & 23 deletions common/speculative.cpp
@@ -47,6 +47,7 @@ struct common_speculative_config {
const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
};


static bool common_speculative_are_compatible(
const llama_model * model_tgt,
const llama_model * model_dft) {
@@ -210,7 +211,9 @@ struct common_speculative_state_draft : public common_speculative_state {
~common_speculative_state_draft() override {
llama_perf_context_print(ctx_dft);

llama_free(ctx_dft);
if (ctx_dft) {
llama_free(ctx_dft);
}

common_sampler_free(smpl);

@@ -228,11 +231,11 @@
llama_tokens & result) override {
auto * spec = this;

auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;

auto * mem_dft = llama_get_memory(ctx_dft);

@@ -438,7 +441,52 @@
};

struct common_speculative_state_eagle3 : public common_speculative_state {
common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {}
llama_context * ctx_tgt;

common_sampler * smpl;

llama_batch batch;

struct llama_context * ctx_dft_enc = nullptr;
struct llama_context * ctx_dft_dec = nullptr;

int32_t eagle3_n_past = 0; // number of verified positions in decoder KV cache

common_speculative_state_eagle3(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft_enc,
llama_context * ctx_dft_dec)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft_enc(ctx_dft_enc)
, ctx_dft_dec(ctx_dft_dec)
{
batch = llama_batch_init(llama_n_batch(ctx_dft_dec), 0, 1);

// Initialize sampler for EAGLE3 decoder
common_params_sampling params;
params.no_perf = false;
params.top_k = 10; // set to 1 for greedy sampling (argmax) to match vLLM's default behavior; values > 1 consistently yield a higher acceptance rate for EAGLE3
params.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
smpl = common_sampler_init(llama_get_model(ctx_dft_dec), params);
}

~common_speculative_state_eagle3() override {
llama_perf_context_print(ctx_dft_dec);

if (ctx_dft_dec) {
llama_free(ctx_dft_dec);
}

if (ctx_dft_enc) {
llama_free(ctx_dft_enc);
}

common_sampler_free(smpl);

llama_batch_free(batch);
}
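
// EAGLE3 draft flow, summarized for review: after the target verifies tokens,
// its per-position hidden features are fed to the encoder context, which
// produces g_embeddings via llama_encode(); the decoder context then combines
// those embeddings with shifted token ids to draft new tokens, reusing its KV
// cache for the eagle3_n_past positions that are already verified.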

void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
@@ -448,12 +496,97 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & draft_tokens) override {
// TODO: implement
GGML_UNUSED(params);
GGML_UNUSED(prompt_tgt);
GGML_UNUSED(id_last);
GGML_UNUSED(draft_tokens);
llama_tokens & result) override {
auto * spec = this;

auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft_enc = spec->ctx_dft_enc;
auto & ctx_dft_dec = spec->ctx_dft_dec;
auto & smpl = spec->smpl;

const int n_embd = llama_model_n_embd(llama_get_model(ctx_dft_enc));
const int n = (int)prompt_tgt.size();
const int n_new = n - spec->eagle3_n_past;

GGML_ASSERT(n >= 1 && "prompt_tgt is empty");
GGML_ASSERT(n_new >= 1 && "must have at least 1 new token");

// Clear draft positions from decoder KV cache [n_past, inf)
llama_memory_seq_rm(llama_get_memory(ctx_dft_dec), 0, spec->eagle3_n_past, -1);

// Encoder: features → g_embeddings
const float * features = llama_get_eagle3_target_features(ctx_tgt);
GGML_ASSERT(features && "no target features");

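// Fields left null fall back to the llama_batch defaults documented in llama.h:
// positions are tracked automatically and a null seq_id means sequence 0.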
llama_batch enc_batch = {
/*.n_tokens =*/ n_new,
/*.token =*/ nullptr,
/*.embd =*/ const_cast<float*>(features),
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
GGML_ASSERT(llama_encode(ctx_dft_enc, enc_batch) == 0);

const float * g_embd = llama_get_embeddings(ctx_dft_enc);
GGML_ASSERT(g_embd && "encoder output failed");

// Decoder batch: process new tokens with KV cache reuse
llama_set_eagle3_g_embeddings(ctx_dft_dec, g_embd, n_embd, n_new);

common_batch_clear(batch);
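// EAGLE-style shifted inputs: the feature at position pos is paired with the
// token that follows it, so the last new position carries id_last (the token
// the target has just sampled).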
for (int i = 0; i < n_new; i++) {
const int pos = spec->eagle3_n_past + i;
const llama_token tok = (pos < n - 1) ? prompt_tgt[pos + 1] : id_last;
common_batch_add(batch, tok, pos, {0}, true);
}

GGML_ASSERT(llama_decode(ctx_dft_dec, batch) == 0);

spec->eagle3_n_past = n; // update verified positions

// Sample draft tokens
result.clear();
common_sampler_reset(smpl);

// Sample and check probability (consistent with standard speculative decoding)
auto sample_and_check = [&](int idx) -> bool {
common_sampler_sample(smpl, ctx_dft_dec, idx);

const auto * cur_p = common_sampler_get_candidates(smpl, true);
const llama_token id = cur_p->data[0].id;

common_sampler_accept(smpl, id, true);
result.push_back(id);

return cur_p->data[0].p >= params.p_min;
};

// First draft token from batch decode
if (!sample_and_check(n_new - 1)) {
return;
}

// Autoregressive: use prenorm as g_embd (-1 = last output)
const float * prenorm = llama_get_embeddings_ith(ctx_dft_dec, -1);

for (int i = 1; i < params.n_max; i++) {
GGML_ASSERT(prenorm && "prenorm failed");
llama_set_eagle3_g_embeddings(ctx_dft_dec, prenorm, n_embd, 1);

common_batch_clear(batch);
common_batch_add(batch, result.back(), n - 1 + i, {0}, true);
GGML_ASSERT(llama_decode(ctx_dft_dec, batch) == 0);

prenorm = llama_get_embeddings_ith(ctx_dft_dec, -1);

if (!sample_and_check(0)) {
break;
}
}
}

void accept(uint16_t n_accepted) override {
@@ -840,19 +973,43 @@ common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
llama_context * ctx_dft = nullptr;

llama_context * ctx_dft_enc = nullptr;
llama_context * ctx_dft_dec = nullptr;

if (params.model_dft) {
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
if (params.eagle3) {
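// EAGLE3 runs the draft weights in two contexts: an encoder with no target
// model attached, which maps target features to g_embeddings, and a decoder
// linked to the target model, which drafts tokens from those embeddings.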
llama_context_params params_enc = params.cparams_dft;
params_enc.target_model = nullptr;
params_enc.embeddings = true;
ctx_dft_enc = llama_init_from_model(params.model_dft, params_enc);
if (!ctx_dft_enc) {
LOG_ERR("failed to create EAGLE3 encoder context\n");
return nullptr;
}

llama_context_params params_dec = params.cparams_dft;
params_dec.target_model = params.model_tgt;
params_dec.embeddings = true;
ctx_dft_dec = llama_init_from_model(params.model_dft, params_dec);
if (!ctx_dft_dec) {
LOG_ERR("failed to create EAGLE3 decoder context\n");
return nullptr;
}
} else {
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
}
}
}

// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_draft_eagle3 = params.eagle3;

bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -893,10 +1050,11 @@ common_speculative * common_speculative_init(
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
} else {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
}
}
}

@@ -916,7 +1074,11 @@
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft_enc = */ ctx_dft_enc,
/* .ctx_dft_dec = */ ctx_dft_dec
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {