Merged
30 changes: 25 additions & 5 deletions llama.cpp
@@ -1316,8 +1316,8 @@ static bool llama_kv_cache_find_slot(

     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += cache.size - cache.head;
Contributor Author
@KerfuffleV2 — Oct 5, 2023
This might be wrong, but I really don't understand the logic of incrementing n_tested by the full cache size in this case. That means we just give up searching, but it's possible there's an empty slot of n_tokens at some point before the current cache.head, right?

At the least, n_tested += n_ctx - cache.head is kind of redundant, since cache.head just got set to 0 — so it amounts to n_tested += n_ctx - 0.


I also really wanted to write something like

// Check if there's an empty slot past what we allocated. If so, we can
// set head to it and immediately find an empty slot next time. Otherwise
// just reset head to 0.
if (cache.head + n_tokens < n_ctx && cache.cells[cache.head + n_tokens].pos < 0 && cache.cells[cache.head + n_tokens].seq_id.empty()) {
    cache.head += n_tokens;
} else {
    cache.head = 0;
}

at the end of the function, but it seems like cache.head is used to communicate to the caller where the tokens for the current evaluation start (and the caller also does cache.head += n_tokens afterward).
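For illustration, the wrap-around accounting under discussion can be sketched as a self-contained toy model (hypothetical `slot_cache`/`toy_find_slot` names; heavily simplified from the real `llama_kv_cache_find_slot`, which tracks positions and sequence IDs rather than a boolean occupancy map):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model: look for n_tokens contiguous free cells starting at `head`,
// wrapping at most once. `n_tested` counts cells examined so that a full
// cache is detected instead of looping forever.
struct slot_cache {
    uint32_t          head = 0;
    std::vector<bool> used;   // true = occupied cell
};

static bool toy_find_slot(slot_cache & cache, uint32_t n_tokens) {
    const uint32_t size = (uint32_t) cache.used.size();
    if (n_tokens > size) {
        return false; // can never fit
    }

    uint32_t n_tested = 0;

    while (true) {
        if (cache.head + n_tokens > size) {
            // count only the cells actually skipped at the tail (the fix
            // discussed above), then wrap to the beginning
            n_tested  += size - cache.head;
            cache.head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; ++i) {
            if (cache.used[cache.head + i]) {
                // collision: skip past the occupied cell and keep searching
                found       = false;
                cache.head += i + 1;
                n_tested   += i + 1;
                break;
            }
        }

        if (found) {
            return true; // cache.head now points at the free slot
        }
        if (n_tested >= size) {
            return false; // every cell examined, no slot of this size
        }
    }
}
```

With the old accounting (adding the full cache size on wrap), a single wrap would exhaust the n_tested budget and the search could give up even though a large enough run of free cells existed before the starting head.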

Member
Oops, that was a bug! The fix is OK, but we should use either cache.size or n_ctx at all places in the function. No need to mix them.

             cache.head = 0;
-            n_tested += n_ctx - cache.head;
             continue;
         }

@@ -1368,13 +1368,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
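The tokens_rm change above can be sketched in isolation (hypothetical `rm_cache`/`toy_tokens_rm` names; the real function clears full cells with positions and sequence IDs):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model: clear cells in [c0, c1) and point head at the first cleared
// index, which is guaranteed free, so the next slot search starts there.
struct rm_cache {
    uint32_t             head = 0;
    std::vector<int32_t> pos;   // -1 = free cell
};

static void toy_tokens_rm(rm_cache & cache, int32_t c0, int32_t c1) {
    if (c0 < 0) c0 = 0;
    if (c1 < 0) c1 = (int32_t) cache.pos.size();

    for (int32_t i = c0; i < c1; ++i) {
        cache.pos[i] = -1;
    }

    // Searching for a free slot can start here since we know it is empty.
    cache.head = (uint32_t) c0;
}
```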

 static void llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
                  llama_seq_id   seq_id,
                     llama_pos   p0,
                     llama_pos   p1) {
+    uint32_t new_head = cache.size;
Contributor Author
The logic for these few functions follows basically the same pattern: we set new_head to the first freed slot if there is one. new_head = cache.size is a safe but invalid sentinel that we know fits in uint32_t.

This could maybe be improved by only changing cache.head if it 1) doesn't already point at a free slot, and 2) points at an index greater than the one that got freed up. The idea is to maximize use of slots near the beginning of the cache. I'm not sure that's worth the complexity, though.

These changes (even in this simple form) will slow down the cache manipulation functions a little. I think that's worth it, because searching for a slot is probably the most common operation.
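The new_head pattern described above can be sketched as a toy version of seq_rm (hypothetical `toy_kv_cache`/`toy_seq_rm` names; the real function also filters by a position range):

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Toy model: while removing a sequence, remember the first cell that became
// free and move `head` there, so the next slot search starts at a
// known-free cell instead of scanning from wherever head happened to be.
struct toy_cell {
    int32_t           pos = -1;
    std::set<int32_t> seq_id;
};

struct toy_kv_cache {
    uint32_t              head = 0;
    std::vector<toy_cell> cells;
};

static void toy_seq_rm(toy_kv_cache & cache, int32_t seq_id) {
    const uint32_t size = (uint32_t) cache.cells.size();
    uint32_t new_head = size; // safe but invalid "not found" sentinel

    for (uint32_t i = 0; i < size; ++i) {
        if (cache.cells[i].seq_id.count(seq_id)) {
            cache.cells[i].seq_id.erase(seq_id);
            if (cache.cells[i].seq_id.empty()) {
                cache.cells[i].pos = -1;
                if (new_head == size) new_head = i; // first freed slot
            }
        }
    }

    // If we freed up a slot, start the next search there.
    if (new_head != size) cache.head = new_head;
}
```

Note that a cell is only freed when its last sequence ID is removed; cells still shared with another sequence stay occupied.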


     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

@@ -1383,9 +1388,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }

 static void llama_kv_cache_seq_cp(
@@ -1397,6 +1406,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1405,12 +1416,18 @@ static void llama_kv_cache_seq_cp(
 }

 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }

 static void llama_kv_cache_seq_shift(
@@ -1419,6 +1436,8 @@ static void llama_kv_cache_seq_shift(
                     llama_pos   p0,
                     llama_pos   p1,
                     llama_pos   delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

@@ -1428,12 +1447,17 @@ static void llama_kv_cache_seq_shift(
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
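The seq_shift variant can be sketched the same way (hypothetical `shift_cell`/`shift_cache`/`toy_seq_shift` names; simplified from the diff above): positions in [p0, p1) for the sequence are shifted by delta, cells whose position would go negative are freed, and head moves to the first freed slot, or 0 otherwise.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

struct shift_cell {
    int32_t           pos   = -1;
    int32_t           delta = 0;
    std::set<int32_t> seq_id;
};

struct shift_cache {
    uint32_t                head      = 0;
    bool                    has_shift = false;
    std::vector<shift_cell> cells;
};

static void toy_seq_shift(shift_cache & cache, int32_t seq_id, int32_t p0, int32_t p1, int32_t delta) {
    const uint32_t size = (uint32_t) cache.cells.size();
    uint32_t new_head = size; // "not found" sentinel

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<int32_t>::max();

    for (uint32_t i = 0; i < size; ++i) {
        if (cache.cells[i].seq_id.count(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.cells[i].pos += delta;
            if (cache.cells[i].pos < 0) {
                // shifted off the front: free the cell
                cache.cells[i].pos = -1;
                cache.cells[i].seq_id.clear();
                if (new_head == size) new_head = i;
            } else {
                // remember the shift so attention can account for it later
                cache.has_shift     = true;
                cache.cells[i].delta = delta;
            }
        }
    }

    // First freed slot if any, otherwise restart the search from 0.
    cache.head = new_head != size ? new_head : 0;
}
```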

//
@@ -4454,10 +4478,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }

-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }