Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,10 +541,65 @@ bool llama_memory_recurrent::resize(uint32_t new_mem_size) {
cells.resize(new_mem_size);
size = new_mem_size;

uint32_t used_new = 0;
for (auto & cell : cells) {
cell.tail = -1;

for (auto it = cell.seq_id.begin(); it != cell.seq_id.end();) {
if (*it < 0 || (uint32_t) *it >= size) {
LLAMA_LOG_WARN("%s: dropping seq_id %d after resize %u -> %u\n",
__func__, *it, old_size, new_mem_size);
it = cell.seq_id.erase(it);
} else {
++it;
}
}

if (cell.seq_id.empty()) {
cell.pos = -1;
cell.src = -1;
cell.src0 = -1;
continue;
}

if (cell.src >= (int32_t) size) {
LLAMA_LOG_WARN("%s: clearing out-of-range src %d after resize %u -> %u\n",
__func__, cell.src, old_size, new_mem_size);
cell.src = -1;
}
if (cell.src0 >= (int32_t) size) {
LLAMA_LOG_WARN("%s: clearing out-of-range src0 %d after resize %u -> %u\n",
__func__, cell.src0, old_size, new_mem_size);
cell.src0 = -1;
}

used_new++;
}

for (uint32_t i = 0; i < size; ++i) {
for (llama_seq_id seq_id : cells[i].seq_id) {
cells[seq_id].tail = i;
}
}

used = used_new;
if (size == 0) {
head = 0;
n = 0;
rs_z = -1;
} else {
head = std::min(head, size - 1);
n = std::min(n, size);
if (rs_z >= (int32_t) size) {
rs_z = -1;
}
}

const size_t memory_size_r = size_r_bytes();
const size_t memory_size_s = size_s_bytes();
LLAMA_LOG_INFO("%s: resized %u -> %u cells, R: %7.2f MiB, S: %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: resized %u -> %u cells, used=%u, head=%u, n=%u, rs_z=%d, R: %7.2f MiB, S: %7.2f MiB\n", __func__,
old_size, new_mem_size,
used, head, n, rs_z,
(float)memory_size_r / (1024.0f * 1024.0f),
(float)memory_size_s / (1024.0f * 1024.0f));

Expand Down Expand Up @@ -1346,5 +1401,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
}

// Returns the source row (src0) recorded for the i-th cell counted from mem->head.
//
// Both the computed cell index and the stored src0 value are validated against
// mem->size before use. If either is out of range, a warning is logged and a
// fallback row is returned instead: get_rs_z() when it is non-negative, else 0.
// NOTE(review): the log text calls the fallback the "zero state" row - confirm
// that row 0 is an acceptable substitute when rs_z is unset.
//
// @param i  offset of the cell relative to mem->head
// @return   a row index taken from the cell's src0, or the fallback row
int32_t llama_memory_recurrent_context::s_copy(int i) const {
    const int32_t cell_idx = (int32_t) mem->head + i;

    // query rs_z once; it is only used when one of the range checks below fails
    const auto    rs_z     = get_rs_z();
    const int32_t fallback = rs_z >= 0 ? rs_z : 0;

    // the cell index must be checked BEFORE touching mem->cells, otherwise the
    // lookup itself is an out-of-bounds access
    if (cell_idx < 0 || (uint32_t) cell_idx >= mem->size) {
        LLAMA_LOG_WARN("%s: source cell index out of range: i=%d head=%u size=%u, using zero state %d\n",
                __func__, i, mem->head, mem->size, fallback);
        return fallback;
    }

    // the stored source row can also be out of range (negative or >= size)
    const int32_t src = mem->cells[cell_idx].src0;
    if (src < 0 || (uint32_t) src >= mem->size) {
        LLAMA_LOG_WARN("%s: recurrent source row out of range: i=%d cell=%d src=%d size=%u, using zero state %d\n",
                __func__, i, cell_idx, src, mem->size, fallback);
        return fallback;
    }

    return src;
}
36 changes: 36 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,35 @@ struct server_context_impl {
prompt_cache->update();
}

// Tries to shrink the recurrent memory back to n_parallel_user cells.
//
// Does nothing unless recurrent_expanded and needs_reeval are both set and
// n_seq_max_full is larger than n_parallel_user. The shrink is also skipped
// (with a debug message) while any slot is processing or holds a draft backup.
// Otherwise the sequences at slot.id + n_parallel_user are removed for every
// slot and llama_memory_recurrent_shrink() is called; recurrent_expanded is
// cleared only when that call succeeds.
//
// @param reason  free-form text included in the log messages
void recurrent_shrink_for_prefill(const char * reason) {
    const bool shrink_wanted = recurrent_expanded && needs_reeval && n_seq_max_full > n_parallel_user;
    if (!shrink_wanted) {
        return;
    }

    // bail out as soon as one slot is still busy or has a draft backup
    for (const server_slot & cur : slots) {
        if (!cur.is_processing() && !cur.has_draft_backup) {
            continue;
        }

        SRV_DBG("not shrinking recurrent state for prefill (%s): slot %d processing=%d has_backup=%d\n",
                reason, cur.id, cur.is_processing(), cur.has_draft_backup);
        return;
    }

    auto * mem = llama_get_memory(ctx);

    // remove the per-slot sequence that sits n_parallel_user above the slot id
    for (const server_slot & cur : slots) {
        const llama_seq_id backup_id = cur.id + n_parallel_user;
        llama_memory_seq_rm(mem, backup_id, -1, -1);
    }

    if (!llama_memory_recurrent_shrink(mem, n_parallel_user)) {
        SRV_ERR("failed to shrink recurrent state to %d cells for prefill (%s)\n",
                n_parallel_user, reason);
        return;
    }

    recurrent_expanded = false;
    SRV_INF("shrunk recurrent state to %d cells for prefill (%s, removed %d backup cells)\n",
            n_parallel_user, reason, n_seq_max_full - n_parallel_user);
}

void handle_sleeping_state(bool new_state) {
GGML_ASSERT(sleeping != new_state);
if (new_state) {
Expand Down Expand Up @@ -1289,6 +1318,8 @@ struct server_context_impl {
}

if (ret) {
recurrent_shrink_for_prefill("before prompt cache save/load");

const auto & tokens = ret->prompt.tokens;

update_cache = update_cache && prompt_cache;
Expand Down Expand Up @@ -2791,6 +2822,11 @@ struct server_context_impl {
if (!do_reset) {
// restore the context checkpoint
const size_t checkpoint_size = it->data.size();
SLT_DBG(slot,
"restoring context checkpoint data=%.3f MiB ring=%.3f MiB recurrent_expanded=%d n_parallel_user=%d n_seq_max_full=%d\n",
(float) it->data.size() / 1024 / 1024,
(float) it->ring_data.size() / 1024 / 1024,
recurrent_expanded, n_parallel_user, n_seq_max_full);
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

if (n != checkpoint_size) {
Expand Down