-
Notifications
You must be signed in to change notification settings - Fork 13.7k
kv cache slot search improvements #3493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1316,8 +1316,8 @@ static bool llama_kv_cache_find_slot( | |
|
|
||
| while (true) { | ||
| if (cache.head + n_tokens > n_ctx) { | ||
| n_tested += cache.size - cache.head; | ||
| cache.head = 0; | ||
| n_tested += n_ctx - cache.head; | ||
| continue; | ||
| } | ||
|
|
||
|
|
@@ -1368,13 +1368,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, | |
| cache.cells[i].pos = -1; | ||
| cache.cells[i].seq_id.clear(); | ||
| } | ||
|
|
||
| // Searching for a free slot can start here since we know it will be empty. | ||
| cache.head = uint32_t(c0); | ||
| } | ||
|
|
||
| static void llama_kv_cache_seq_rm( | ||
| struct llama_kv_cache & cache, | ||
| llama_seq_id seq_id, | ||
| llama_pos p0, | ||
| llama_pos p1) { | ||
| uint32_t new_head = cache.size; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic for these few functions follows basically the same pattern. We'll set `cache.head` to the first slot that was freed up, so the next search for a free slot can start there. This could maybe be improved by only changing `cache.head` when the freed slot comes before the current head. These changes (even in the simple form) will slow down the cache manipulation functions a little. I think that at least is worth it because searching for a slot is probably the most common case. |
||
|
|
||
| if (p0 < 0) p0 = 0; | ||
| if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||
|
|
||
|
|
@@ -1383,9 +1388,13 @@ static void llama_kv_cache_seq_rm( | |
| cache.cells[i].seq_id.erase(seq_id); | ||
| if (cache.cells[i].seq_id.empty()) { | ||
| cache.cells[i].pos = -1; | ||
| if (new_head == cache.size) new_head = i; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // If we freed up a slot, set head to it so searching can start there. | ||
| if (new_head != cache.size) cache.head = new_head; | ||
| } | ||
|
|
||
| static void llama_kv_cache_seq_cp( | ||
|
|
@@ -1397,6 +1406,8 @@ static void llama_kv_cache_seq_cp( | |
| if (p0 < 0) p0 = 0; | ||
| if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||
|
|
||
| cache.head = 0; | ||
|
|
||
| for (uint32_t i = 0; i < cache.size; ++i) { | ||
| if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||
| cache.cells[i].seq_id.insert(seq_id_dst); | ||
|
|
@@ -1405,12 +1416,18 @@ static void llama_kv_cache_seq_cp( | |
| } | ||
|
|
||
| static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { | ||
| uint32_t new_head = cache.size; | ||
|
|
||
| for (uint32_t i = 0; i < cache.size; ++i) { | ||
| if (!cache.cells[i].has_seq_id(seq_id)) { | ||
| cache.cells[i].pos = -1; | ||
| cache.cells[i].seq_id.clear(); | ||
| if (new_head == cache.size) new_head = i; | ||
| } | ||
| } | ||
|
|
||
| // If we freed up a slot, set head to it so searching can start there. | ||
| if (new_head != cache.size) cache.head = new_head; | ||
| } | ||
|
|
||
| static void llama_kv_cache_seq_shift( | ||
|
|
@@ -1419,6 +1436,8 @@ static void llama_kv_cache_seq_shift( | |
| llama_pos p0, | ||
| llama_pos p1, | ||
| llama_pos delta) { | ||
| uint32_t new_head = cache.size; | ||
|
|
||
| if (p0 < 0) p0 = 0; | ||
| if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||
|
|
||
|
|
@@ -1428,12 +1447,17 @@ static void llama_kv_cache_seq_shift( | |
| if (cache.cells[i].pos < 0) { | ||
| cache.cells[i].pos = -1; | ||
| cache.cells[i].seq_id.clear(); | ||
| if (new_head == cache.size) new_head = i; | ||
| } else { | ||
| cache.has_shift = true; | ||
| cache.cells[i].delta = delta; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // If we freed up a slot, set head to it so searching can start there. | ||
| // Otherwise we just start the next search from the beginning. | ||
| cache.head = new_head != cache.size ? new_head : 0; | ||
| } | ||
|
|
||
| // | ||
|
|
@@ -4454,10 +4478,6 @@ static int llama_decode_internal( | |
| batch.seq_id = seq_id.data(); | ||
| } | ||
|
|
||
| // we always start to search for a free slot from the start of the cache | ||
| // TODO: better strategies can be implemented | ||
| kv_self.head = 0; | ||
|
|
||
| if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||
| return 1; | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be wrong, but I really don't understand the logic of setting
`n_tested` to the full cache size in this case. That means we just give up searching in that case, but it's possible there's an empty slot of `n_tokens` at a point before the current `cache.head`, right?

At the least, `n_tested += n_ctx - cache.head` is kind of redundant since `cache.head` just got set to 0. So this is like `n_tested += n_ctx - 0`.

I also really wanted to reset `cache.head` at the end of the function, but it seems like `cache.head` is used to communicate the start of the tokens to evaluate for the actual evaluation (and it also does `cache.head += n_tokens` after).

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, that was a bug! The fix is OK, but we should either use
`cache.size` or `n_ctx` at all places in the function. No need to mix them.