10 changes: 10 additions & 0 deletions common/arg.cpp
@@ -3447,6 +3447,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ckpt-num-tries"}, "N",
string_format("number of tries for speculative decoding with recurrent memory (default: %d)", params.speculative.ckpt_num_tries),
[](common_params & params, int value) {
if (value < 0 || value > 10) {
throw std::invalid_argument("number of tries must be between 0 and 10 inclusive");
}
params.speculative.ckpt_num_tries = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
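The new --spec-ckpt-num-tries option added above is registered for the server example only and accepts values 0..10. A value of 0 keeps the previous behaviour (no checkpoint-based retries); a positive value lets a slot keep re-drafting at the same position until that many drafts in a row have been rejected. A minimal sketch of that gating rule, mirroring the server-side condition further down in this diff (the helper name is illustrative and not part of the change):

    // illustrative helper, not part of this PR: decide whether the slot may attempt
    // another speculative draft at its current position
    bool spec_ckpt_may_draft(int ckpt_num_tries, int n_denials) {
        // 0 = feature disabled, draft as before; otherwise stop after ckpt_num_tries rejections
        return ckpt_num_tries == 0 || n_denials < ckpt_num_tries;
    }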
1 change: 1 addition & 0 deletions common/common.h
@@ -270,6 +270,7 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
uint16_t ckpt_num_tries = 0; // max. number of draft retries per position with recurrent memory (0 = disabled)

std::shared_ptr<common_ngram_mod> ngram_mod;

4 changes: 2 additions & 2 deletions common/ngram-map.cpp
@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}

if (map.idx_last_check > cur_len) {
if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());

map.last_draft_created = false;
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
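The functional change in this file is the flipped boolean: once a draft has been built and returned, last_draft_created must record that fact, otherwise whatever consults the flag on the next call treats the draft as if it never existed and skips its bookkeeping. A self-contained illustration of that pattern (hypothetical types, not the llama.cpp API):

    #include <cstdio>

    // hypothetical stand-in for the map's draft bookkeeping
    struct draft_bookkeeping {
        bool last_draft_created = false;
        int  n_updates          = 0;
    };

    void propose_draft(draft_bookkeeping & bk) {
        // ... build and return the draft ...
        bk.last_draft_created = true;   // the one-line fix above: record that a draft exists
    }

    void accept_feedback(draft_bookkeeping & bk, int n_accepted) {
        if (!bk.last_draft_created) {
            return;                     // flag stuck at false -> feedback is silently dropped
        }
        bk.n_updates += n_accepted;
        bk.last_draft_created = false;
    }

    int main() {
        draft_bookkeeping bk;
        propose_draft(bk);
        accept_feedback(bk, 3);
        std::printf("updates applied: %d\n", bk.n_updates); // 3 with the fix, 0 without it
    }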
Binary file modified tools/server/public/index.html.gz
Binary file not shown.
129 changes: 123 additions & 6 deletions tools/server/server-context.cpp
@@ -146,6 +146,13 @@ struct server_slot {
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;

// use of checkpoints in speculative mode
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position
int spec_ckpt_n_accepted = 0; // number of accepted tokens at current position
size_t spec_ckpt_size_part = 0; // size of partial checkpoint


// stats
size_t n_sent_text = 0; // number of sent text characters

@@ -184,6 +191,11 @@ struct server_slot {
n_draft_total = 0;
n_draft_accepted = 0;

spec_ckpt_n_denials = 0;
spec_ckpt_n_accepted = 0;
spec_has_ckpt = false;
spec_ckpt_size_part = 0;

task_prev = std::move(task);
task.reset();

@@ -742,7 +754,7 @@ struct server_context_impl {

const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
SRV_WRN("%s", "speculative decoding not supported by this context without checkpoints\n");
}

// initialize slots
@@ -757,7 +769,7 @@
slot.prompt.tokens.has_mtmd = mctx != nullptr;

// try speculative decoding
if (can_spec) {
if (can_spec || params_base.speculative.ckpt_num_tries > 0) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
@@ -2041,8 +2053,9 @@ struct server_context_impl {
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
const int n_draft_max = (slot.spec_ckpt_n_accepted > 0) ? slot.spec_ckpt_n_accepted : slot.get_n_draft_max();
if (n_draft_max > 0 && (params_base.speculative.ckpt_num_tries == 0
|| slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries)) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
@@ -2059,8 +2072,52 @@
draft.resize(n_draft_max);
}

const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
bool do_checkpoint = !draft.empty() && params_base.speculative.ckpt_num_tries > 0
&& slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
if (do_checkpoint && cached_text_tokens.size() > 5) {
SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
draft.size(), slot.spec_ckpt_n_denials,
slot.prompt.checkpoints.size(),
do_checkpoint ? "yes" : "no", pos_min, pos_max,
cached_text_tokens[cached_text_tokens.size() - 3],
cached_text_tokens[cached_text_tokens.size() - 2],
cached_text_tokens[cached_text_tokens.size() - 1]);
}

if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();

SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}

const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, 0);

auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});

const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

slot.spec_ckpt_size_part = n;
slot.spec_has_ckpt = true;
}

// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
SLT_DBG(slot, "before common_batch_add: sampled=%d, pos_next=%d, tokens.size=%zu, tokens.last=%d\n",
slot.sampled, slot.prompt.tokens.pos_next(), slot.prompt.tokens.size(), slot.prompt.tokens[slot.prompt.tokens.size() - 1]);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);

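Recurrent (and hybrid) memory cannot be rolled back by simply removing cells past a cut-off position, so before the draft tokens are evaluated the slot serializes the sequence state and keeps it as a checkpoint; the checkpoint is dropped again as soon as it is no longer needed (empty draft, full acceptance, or after a rollback). A condensed sketch of the save step, following the calls used in the hunk above (flags and signatures as they appear in this diff, error handling omitted):

    #include <cstdint>
    #include <vector>
    #include "llama.h"

    // sketch of the checkpoint save step used above; returns the number of bytes
    // actually written for the partial (recurrent-only) state
    size_t spec_ckpt_save(llama_context * ctx, llama_seq_id seq_id, std::vector<uint8_t> & out) {
        const size_t full_size = llama_state_seq_get_size_ext(ctx, seq_id, 0);
        out.resize(full_size);
        return llama_state_seq_get_data_ext(ctx, out.data(), full_size, seq_id,
                                            LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    }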
@@ -2070,6 +2127,15 @@
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();

if (slot.spec_has_ckpt) {
slot.spec_ckpt_n_accepted = 0;
slot.spec_ckpt_n_denials = 0;

// Delete Checkpoint
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;
}
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
@@ -2086,6 +2152,9 @@
// no speculative decoding
slot.i_batch = batch.n_tokens;

slot.spec_ckpt_n_denials = 0;
slot.spec_ckpt_n_accepted = 0;

common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

slot.prompt.tokens.push_back(slot.sampled);
@@ -2538,6 +2607,7 @@ struct server_context_impl {

// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);

// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
@@ -2797,12 +2867,49 @@ struct server_context_impl {

const int64_t t_current = ggml_time_us();

slot.n_decoded += ids.size();

if (slot.spec_has_ckpt && ids.size() < n_draft + 1) {
// the main model rejected some draft tokens, so roll back to the state saved before the draft was evaluated
auto & ckpt = slot.prompt.checkpoints.back();
SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
ids.size() - 1, n_draft,
ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx,
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != slot.spec_ckpt_size_part) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu)",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), slot.spec_ckpt_size_part, n);
}
SRV_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
ids.size() - 1, n_draft, n);

// roll the cached token list back to the checkpoint position
SLT_INF(slot, "partial acceptance: n_tokens=%d, n_draft=%zu, pos_max=%d\n",
slot.prompt.n_tokens(), n_draft, ckpt.pos_max);

slot.prompt.tokens.keep_first(ckpt.pos_max + 1);

// Delete Checkpoint
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;

// Inform the speculative implementation of the number of valid tokens.
// common_speculative_accept(slot.spec, ids.size() - 1);

slot.spec_ckpt_n_denials++;
slot.spec_ckpt_n_accepted = (slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries) ? (int) (ids.size() - 1) : 0;

common_batch_clear(batch);

continue;
}

slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
slot.spec_ckpt_n_accepted = 0;

// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
Expand All @@ -2814,7 +2921,17 @@ struct server_context_impl {
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
slot.spec_ckpt_n_denials = 0;
if (slot.spec_has_ckpt) {
// Delete Checkpoint
if (slot.prompt.checkpoints.empty()) {
GGML_ABORT("missing checkpoint to delete");
}
slot.prompt.checkpoints.pop_back();
slot.spec_has_ckpt = false;
} else {
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
}

for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
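The restore half: when the target model accepts only part of the draft, the saved state is written back with the same flag, the cached token list is truncated to the checkpoint's position, and the denial counter decides whether a shorter draft (capped at the number of tokens that were just accepted) is attempted or speculation is given up for this position. A sketch under the same assumptions as the save step above:

    // sketch of the rollback path mirrored from the hunk above (illustrative only)
    bool spec_ckpt_restore(llama_context * ctx, llama_seq_id seq_id,
                           const std::vector<uint8_t> & data, size_t expected_size) {
        const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id,
                                                      LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        // the restore must consume the same number of bytes the save step produced,
        // otherwise the sequence state would be left in an inconsistent shape
        return n == expected_size;
    }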