Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions common/ngram-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

// Construct a fixed-size modular n-gram hash table.
// n:    the n-gram order (number of context tokens hashed per lookup)
// size: number of hash buckets; `entries` and `scores` are parallel arrays
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
entries.resize(size);
// per-bucket score, kept in lock-step with `entries`
scores.resize(size, SCORE_INIT);

// reset() (re)initializes the contents of both arrays
reset();
}
Expand All @@ -27,8 +28,12 @@ void common_ngram_mod::add(const entry_t * tokens) {

if (entries[i] == EMPTY) {
used++;
scores[i] = SCORE_INS;
} else if (entries[i] != tokens[n]) {
// a different token hashes to the same bucket
++collisions;
}

// keep existing score if entry already occupied
entries[i] = tokens[n];
}

Expand All @@ -40,7 +45,9 @@ common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {

// Wipe the table: mark every bucket free, return every score to the
// baseline, and clear the usage/collision counters.
void common_ngram_mod::reset() {
    entries.assign(entries.size(), EMPTY);
    scores.assign(scores.size(), SCORE_INIT);

    used       = 0;
    collisions = 0;
}

size_t common_ngram_mod::get_n() const {
Expand All @@ -56,5 +63,85 @@ size_t common_ngram_mod::size() const {
}

// Total memory footprint of the table in bytes: the token entries plus
// the parallel per-entry score array.
size_t common_ngram_mod::size_bytes() const {
    // NOTE: a stale pre-change return statement (counting only `entries`)
    // was left above the updated one, so the scores array was never
    // included in the reported size; keep only the corrected return.
    return entries.size() * sizeof(entries[0]) + scores.size() * sizeof(scores[0]);
}

// Expose the hash-bucket index of an n-gram so callers can do external
// bookkeeping (e.g. remember which bucket produced a drafted token).
size_t common_ngram_mod::index(const entry_t * tokens) const {
return idx(tokens);
}

void common_ngram_mod::inc_score(const entry_t * tokens) {
const size_t i = idx(tokens);
if (scores[i] < common_ngram_mod::SCORE_MAX) {
++scores[i];
}
}

void common_ngram_mod::dec_score(const entry_t * tokens) {
const size_t i = idx(tokens);
if (scores[i] > common_ngram_mod::SCORE_MIN) {
--scores[i];
}
}

// Increase the score of bucket i, saturating at SCORE_MAX.
// Out-of-range indices are silently ignored.
void common_ngram_mod::inc_score_by_index(size_t i) {
    if (i >= scores.size()) {
        return;
    }
    if (scores[i] < SCORE_MAX) {
        ++scores[i];
    }
}

// Decrease the score of bucket i, saturating at SCORE_MIN.
// Out-of-range indices are silently ignored.
void common_ngram_mod::dec_score_by_index(size_t i) {
    if (i >= scores.size()) {
        return;
    }
    if (scores[i] > SCORE_MIN) {
        --scores[i];
    }
}

void common_ngram_mod::prune_low_score() {
used = 0;
for (size_t i = 0; i < entries.size(); ++i) {
if (entries[i] != EMPTY) {
if (scores[i] < common_ngram_mod::SCORE_THR) {
entries[i] = EMPTY;
scores[i] = 0;
} else {
++used;
}
}
}
}

// Number of add() calls that hashed a different token into an
// already-occupied bucket since the last reset().
size_t common_ngram_mod::get_collisions() const {
return collisions;
}

// Count of buckets with score < SCORE_THR, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_below_thr() const {
return count_below_thr;
}

// Count of buckets saturated at SCORE_MIN, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_at_min() const {
return count_at_min;
}

// Count of buckets saturated at SCORE_MAX, as of the last
// update_score_stats() call.
size_t common_ngram_mod::get_at_max() const {
return count_at_max;
}

// Count of buckets still at the insertion score SCORE_INS, as of the
// last update_score_stats() call.
size_t common_ngram_mod::get_at_ins() const {
return count_at_ins;
}

void common_ngram_mod::update_score_stats() {
// reset counters
count_below_thr = 0;
count_at_min = 0;
count_at_max = 0;
count_at_ins = 0;

for (size_t i = 0; i < scores.size(); ++i) {
const int8_t s = scores[i];
if (s < SCORE_THR) ++count_below_thr;
if (s == SCORE_MIN) ++count_at_min;
if (s == SCORE_MAX) ++count_at_max;
if (s == SCORE_INS) ++count_at_ins;
}
}
35 changes: 35 additions & 0 deletions common/ngram-mod.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ struct common_ngram_mod {

static constexpr entry_t EMPTY = -1;

static constexpr int8_t SCORE_INIT = 0;
static constexpr int8_t SCORE_MIN = -5;
static constexpr int8_t SCORE_MAX = 20;
static constexpr int8_t SCORE_THR = 0;
static constexpr int8_t SCORE_INS = 3;

common_ngram_mod(uint16_t n, size_t size);

size_t idx(const entry_t * tokens) const;
Expand All @@ -23,9 +29,27 @@ struct common_ngram_mod {

void reset();

// expose the hash index for external bookkeeping
size_t index(const entry_t * tokens) const;

// score handling
void inc_score(const entry_t * tokens);
void dec_score(const entry_t * tokens);
void inc_score_by_index(size_t i);
void dec_score_by_index(size_t i);
void prune_low_score(); // remove entries below SCORE_THR

size_t get_n() const;
size_t get_used() const;

void update_score_stats();

size_t get_collisions() const;
size_t get_below_thr() const;
size_t get_at_min() const;
size_t get_at_max() const;
size_t get_at_ins() const;

size_t size() const;
size_t size_bytes() const;

Expand All @@ -35,4 +59,15 @@ struct common_ngram_mod {
size_t used;

std::vector<entry_t> entries;
// per-entry score, range SCORE_MIN .. SCORE_MAX
std::vector<int8_t> scores;

// stats
// count of hash collisions
size_t collisions = 0;
// counts for score
size_t count_below_thr = 0;
size_t count_at_min = 0;
size_t count_at_max = 0;
size_t count_at_ins = 0;
};
32 changes: 28 additions & 4 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {

// consecutive accept rounds with low acceptance fraction (< 0.5)
int n_low = 0;
// hash indices of ngrams consulted during the most recent draft
std::vector<size_t> used_hashes;

// enable trace logging if LLAMA_TRACE is set
const bool verbose;
Expand Down Expand Up @@ -558,7 +560,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {

constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting (collisions=%zu)\n", __func__, f, f_thold, mod.get_collisions());

mod.reset();
}
Expand All @@ -572,6 +574,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
GGML_UNUSED(params);

n_draft_last = 0;
used_hashes.clear();

const size_t cur_len = prompt_tgt.size();
if (cur_len < mod.get_n()) {
Expand Down Expand Up @@ -607,6 +610,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
break;
}
result[n + i] = token;
// remember which hash entry produced this token
used_hashes.push_back(mod.index(result.data() + i));
}

// only return the m tokens that were drafted
Expand All @@ -627,18 +632,37 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
// compute acceptance fraction if we have a recorded draft length
if (n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)n_draft_last;

// update per-ngram scores based on acceptance outcome
for (size_t i = 0; i < n_draft_last; ++i) {
if (i < static_cast<size_t>(n_accepted)) {
mod.inc_score_by_index(used_hashes[i]);
} else {
mod.dec_score_by_index(used_hashes[i]);
}
}

if (f_acc < 0.5) {
n_low++;
if (n_low >= 3) {
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low);

mod.reset();
LOG_WRN("%s: low acceptance streak (%d) - pruning ngram_mod (collisions=%zu)\n", __func__, n_low, mod.get_collisions());
// Log detailed score metrics before pruning
mod.update_score_stats();
LOG_WRN("%s: before prune scores - below_thr=%zu, at_min=%zu, at_max=%zu, at_ins=%zu\n",
__func__,
mod.get_below_thr(),
mod.get_at_min(),
mod.get_at_max(),
mod.get_at_ins());

mod.prune_low_score();
n_low = 0;
}
} else {
n_low = 0;
}
}
used_hashes.clear();
}
};

Expand Down