14 changes: 14 additions & 0 deletions common/arg.cpp
@@ -3447,6 +3447,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-use-checkpoints"}, "[on|off|auto]",
string_format("use checkpoints to rewind token history in recurrent models ('on', 'off', or 'auto', default: %s)",
params.speculative.use_checkpoints ? "on" : "off"),
[](common_params & params, const std::string & value) {
if (is_truthy(value) || is_autoy(value)) {
params.speculative.use_checkpoints = true;
} else if (is_falsey(value)) {
params.speculative.use_checkpoints = false;
} else {
throw std::invalid_argument("invalid value for --spec-use-checkpoints");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
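For context, the tri-state parsing above relies on the is_truthy/is_falsey/is_autoy helpers in common/arg.cpp. A minimal sketch of the intended mapping, with stand-in helper definitions whose exact accepted spellings are an assumption here:

```cpp
#include <string>

// Stand-in helpers; the real is_truthy/is_falsey/is_autoy live in
// common/arg.cpp and may accept different or additional spellings.
static bool is_truthy(const std::string & v) { return v == "on"  || v == "enabled"  || v == "1"; }
static bool is_falsey(const std::string & v) { return v == "off" || v == "disabled" || v == "0"; }
static bool is_autoy (const std::string & v) { return v == "auto" || v == "-1"; }

// --spec-use-checkpoints maps both "on" and "auto" to true ("auto" has no
// separate behavior yet), "off" to false, and rejects anything else.
```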
2 changes: 2 additions & 0 deletions common/common.h
@@ -270,6 +270,8 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
bool use_checkpoints = false; // use checkpoints to rewind token history in recurrent models

std::shared_ptr<common_ngram_mod> ngram_mod;

4 changes: 2 additions & 2 deletions common/ngram-map.cpp
@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}

-if (map.idx_last_check > cur_len) {
+if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());

-map.last_draft_created = false;
+map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
257 changes: 257 additions & 0 deletions common/speculative.cpp
@@ -1072,3 +1072,260 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf.c_str());
}
}


// server callbacks
//

common_speculative_callback::~common_speculative_callback() = default;

// server session
//
struct common_speculative_session::impl {
common_speculative_callback & callback;
common_params_speculative params_spec;

llama_context * ctx_tgt = nullptr;

common_speculative * spec = nullptr;

// `i_batch_dft`, the indices of the draft tokens in the main batch, is stored by the caller

llama_tokens draft;

// use of checkpoints in speculative mode
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1)
size_t spec_ckpt_size_part = 0; // size of partial checkpoint

// Speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted

impl(common_speculative_callback & callback,
const common_params_speculative & params,
llama_context * ctx_tgt)
: callback(callback), params_spec(params), ctx_tgt(ctx_tgt) {
spec = common_speculative_init(params_spec, ctx_tgt);
}

void begin(const llama_tokens & prompt_history) {
common_speculative_begin(spec, prompt_history);
}

bool has_batch_dft() {
return !draft.empty();
}

void clear_draft() {
draft.clear();
spec_ckpt_n_denials = 0;
}

llama_tokens compute_draft(
const llama_tokens & cached_text_tokens,
llama_token id_last,
const int n_draft_max) {
if (spec == nullptr) {
// no implementation, nothing to do
clear_draft();
return draft;
}

if (n_draft_max == 0) {
clear_draft();
return draft;
}
if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) {
// We shouldn't get two denials.
LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__,
cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
clear_draft();
return draft;
}

if (spec_ckpt_n_denials == 1) {
// the previous speculation was not accepted in full
if (draft.empty()) {
LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__);
clear_draft();
return draft;
}
// reuse the shortened draft from the previous speculation
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
cached_text_tokens.size(), id_last, draft.size());
} else {
// call the speculative implementation to create a draft
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, draft.size());
if (draft.empty()) {
clear_draft();
return draft;
}
}

if (draft.size() > (size_t) n_draft_max) {
LOG_WRN("draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}

bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) {
LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n",
__func__,
cached_text_tokens.size(),
draft.size(), spec_ckpt_n_denials,
do_checkpoint ? "yes" : "no", id_last,
cached_text_tokens[cached_text_tokens.size() - 3],
cached_text_tokens[cached_text_tokens.size() - 2],
cached_text_tokens[cached_text_tokens.size() - 1],
draft[0], draft[1], draft[2]);
}

if (params_spec.n_min > (int) draft.size()) {
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min);
clear_draft();
return draft;
}

if (do_checkpoint) {
const size_t n = callback.create_checkpoint();
if (n == 0) {
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size());
clear_draft();
return draft;
}
spec_ckpt_size_part = n;
spec_has_ckpt = true;
}

// add last sampled token to the batch
callback.batch_add_token(id_last, true);

// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
callback.batch_add_token(draft[i], true);
}

return draft;
}

common_speculative_accept_response sample_and_accept() {
const size_t n_draft = draft.size();

// the accepted tokens from the speculation
auto ids = callback.sampler_sample_and_accept_n(draft);

LOG_DBG("%s: n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
if (ids.size() < n_draft + 1) {
// the main model rejected some tokens

// we shorten the draft
draft.resize(ids.size() - 1);
if (spec_has_ckpt) {
// we need to rollback to the state before sampling the draft tokens
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
__func__,
ids.size() - 1, n_draft, n);

// delete checkpoint
callback.delete_checkpoint();
spec_has_ckpt = false;

spec_ckpt_n_denials++;
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
// we will do the batch again but with the shortened draft
return common_speculative_accept_response(std::move(ids), n_draft, true);
}

LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
draft.clear();

// use the sampled token only
ids.resize(1);
// drafted tokens in prompt have been deleted in restore_checkpoint(...).
return common_speculative_accept_response{std::move(ids), 0, false};
}
}
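// reached on full acceptance, or on partial acceptance without a
// checkpoint (the draft was already shortened to the accepted prefix above)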
const size_t draft_size_accepted = draft.size();
LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
common_speculative_accept(spec, draft_size_accepted);
draft.clear();

return common_speculative_accept_response{std::move(ids), n_draft, false};
}

void rewind(const llama_pos p0) {
spec_ckpt_n_denials = 0;
if (spec_has_ckpt) {
// delete checkpoint
callback.delete_checkpoint();
spec_has_ckpt = false;
} else {
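// no checkpoint was taken, so drop the drafted tokens from the memory directly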
callback.memory_seq_rm(p0, -1);
}
}

void print_stats() const {
if (spec == nullptr) {
return;
}

common_speculative_print_stats(spec);
}

void reset() {
if (spec == nullptr) {
return;
}

clear_draft();

spec_has_ckpt = false;
spec_ckpt_size_part = 0;
}
};

common_speculative_session::common_speculative_session(
common_speculative_callback & callback,
const common_params_speculative & params,
llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
}

common_speculative_session::~common_speculative_session() {
common_speculative_free(p_impl->spec);
delete p_impl;
}

void common_speculative_session::begin(const llama_tokens & prompt_history) {
p_impl->begin(prompt_history);
}

bool common_speculative_session::has_batch_dft() {
return p_impl->has_batch_dft();
}

llama_tokens common_speculative_session::compute_draft(
const llama_tokens & prompt,
llama_token id_last,
int n_draft_max_slot) {
return p_impl->compute_draft(prompt, id_last, n_draft_max_slot);
}

common_speculative_accept_response common_speculative_session::sample_and_accept() {
return p_impl->sample_and_accept();
}

void common_speculative_session::rewind(const llama_pos p0) {
p_impl->rewind(p0);
}

void common_speculative_session::print_stats() const {
p_impl->print_stats();
}

void common_speculative_session::reset() {
p_impl->reset();
}
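
Taken together, the session owns the draft and checkpoint bookkeeping while delegating all context mutation to the callback. Below is a minimal sketch of how a server slot might implement the callback; the method signatures are inferred from the call sites in compute_draft() and sample_and_accept() above, so the authoritative declarations in common/speculative.h should be checked before relying on this:

```cpp
// Hypothetical callback wiring the session to a server slot. Signatures are
// inferred from the call sites, not copied from common/speculative.h.
struct slot_speculative_callback : common_speculative_callback {
    size_t create_checkpoint() override {
        // snapshot the (recurrent) state; return its size, 0 on failure
        return 0; // stub
    }
    size_t restore_checkpoint(size_t size_part) override {
        // roll back to the snapshot taken before the draft was decoded
        return size_part; // stub
    }
    void delete_checkpoint() override {
        // drop the snapshot once it is no longer needed
    }
    void batch_add_token(llama_token id, bool need_logits) override {
        // append the token to the slot's next decode batch
        (void) id; (void) need_logits;
    }
    llama_tokens sampler_sample_and_accept_n(const llama_tokens & draft) override {
        // verify the draft against the target model; returns the sampled
        // tokens, of which all but the last matched the draft
        (void) draft;
        return {};
    }
    void memory_seq_rm(llama_pos p0, llama_pos p1) override {
        // remove [p0, p1) from the sequence memory (p1 = -1 means "to the end")
        (void) p0; (void) p1;
    }
};

// Sketch of one speculative step driven through the session:
//
//   session.begin(prompt_history);                              // once per request
//   auto draft = session.compute_draft(prompt, id_last, n_max); // fills the batch
//   if (session.has_batch_dft()) {
//       // ... llama_decode() the batch filled via batch_add_token() ...
//       auto res = session.sample_and_accept();
//       // on partial acceptance with a checkpoint, res asks the caller to
//       // decode again using the shortened draft kept inside the session
//   }
```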
