diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6283393ff38..d6cd5f66ace 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -541,10 +541,65 @@ bool llama_memory_recurrent::resize(uint32_t new_mem_size) { cells.resize(new_mem_size); size = new_mem_size; + uint32_t used_new = 0; /* number of cells that still hold at least one seq_id after the cleanup below */ + for (auto & cell : cells) { + cell.tail = -1; /* cleared for every cell; tails are reassigned by the seq_id loop that follows this one */ + + for (auto it = cell.seq_id.begin(); it != cell.seq_id.end();) { + if (*it < 0 || (uint32_t) *it >= size) { /* seq_id is negative or not below the new size: log it and erase it */ + LLAMA_LOG_WARN("%s: dropping seq_id %d after resize %u -> %u\n", + __func__, *it, old_size, new_mem_size); + it = cell.seq_id.erase(it); + } else { + ++it; + } + } + + if (cell.seq_id.empty()) { /* no seq_id left: set pos, src and src0 to -1 and do not count the cell in used_new */ + cell.pos = -1; + cell.src = -1; + cell.src0 = -1; + continue; + } + + if (cell.src >= (int32_t) size) { /* src at or beyond the new size is reset to -1 */ + LLAMA_LOG_WARN("%s: clearing out-of-range src %d after resize %u -> %u\n", + __func__, cell.src, old_size, new_mem_size); + cell.src = -1; + } + if (cell.src0 >= (int32_t) size) { /* same check for src0 */ + LLAMA_LOG_WARN("%s: clearing out-of-range src0 %d after resize %u -> %u\n", + __func__, cell.src0, old_size, new_mem_size); + cell.src0 = -1; + } + + used_new++; + } + + for (uint32_t i = 0; i < size; ++i) { + for (llama_seq_id seq_id : cells[i].seq_id) { + cells[seq_id].tail = i; /* seq_id is used as an index into cells; if several cells hold the same seq_id the highest i is the one kept */ + } + } + + used = used_new; + if (size == 0) { /* no cells: head and n become 0 and rs_z becomes -1 */ + head = 0; + n = 0; + rs_z = -1; + } else { /* limit head to size - 1 and n to size; rs_z becomes -1 when it is not below size */ + head = std::min(head, size - 1); + n = std::min(n, size); + if (rs_z >= (int32_t) size) { + rs_z = -1; + } + } + const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: resized %u -> %u cells, R: %7.2f MiB, S: %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: resized %u -> %u cells, used=%u, head=%u, n=%u, rs_z=%d, R: %7.2f MiB, S: %7.2f MiB\n", __func__, old_size, new_mem_size, + used, head, n, rs_z, (float)memory_size_r / (1024.0f * 1024.0f), (float)memory_size_s / (1024.0f * 1024.0f)); @@ -1346,5 +1401,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t 
il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const int32_t cell_idx = (int32_t) mem->head + i; /* i is an offset from mem->head */ + const int32_t fallback = get_rs_z() >= 0 ? get_rs_z() : 0; /* returned for any out-of-range index below: get_rs_z() when non-negative, otherwise 0. NOTE(review): in the otherwise case this is row 0, which the log text still calls zero state - confirm row 0 is a safe source */ + + if (cell_idx < 0 || (uint32_t) cell_idx >= mem->size) { /* the addressed cell itself is outside [0, mem->size) */ + LLAMA_LOG_WARN("%s: source cell index out of range: i=%d head=%u size=%u, using zero state %d\n", + __func__, i, mem->head, mem->size, fallback); + return fallback; + } + + const int32_t src = mem->cells[cell_idx].src0; /* src0 of the addressed cell; range-checked before it is returned */ + if (src < 0 || (uint32_t) src >= mem->size) { + LLAMA_LOG_WARN("%s: recurrent source row out of range: i=%d cell=%d src=%d size=%u, using zero state %d\n", + __func__, i, cell_idx, src, mem->size, fallback); + return fallback; + } + + return src; } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c1f2116b7d5..b8f9734fed9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -752,6 +752,35 @@ struct server_context_impl { prompt_cache->update(); } + void recurrent_shrink_for_prefill(const char * reason) { /* Removes sequence slot.id + n_parallel_user for every slot, then asks llama_memory_recurrent_shrink to shrink the memory to n_parallel_user. Returns early without doing anything unless recurrent_expanded and needs_reeval are set and n_seq_max_full is greater than n_parallel_user, or when any slot is processing or has a draft backup. reason is only used in the log messages. */ + if (!recurrent_expanded || !needs_reeval || n_seq_max_full <= n_parallel_user) { + return; + } + + for (const server_slot & slot : slots) { /* first pass: a single busy slot cancels the whole shrink */ + if (slot.is_processing() || slot.has_draft_backup) { + SRV_DBG("not shrinking recurrent state for prefill (%s): slot %d processing=%d has_backup=%d\n", + reason, slot.id, slot.is_processing(), slot.has_draft_backup); + return; + } + } + + auto * mem = llama_get_memory(ctx); + for (const server_slot & slot : slots) { + const llama_seq_id seq_backup = slot.id + n_parallel_user; /* sequence id removed for this slot: slot.id offset by n_parallel_user */ + llama_memory_seq_rm(mem, seq_backup, -1, -1); /* the return value is not checked */ + } + + if (llama_memory_recurrent_shrink(mem, n_parallel_user)) { + recurrent_expanded = false; + SRV_INF("shrunk recurrent state to %d cells for prefill (%s, removed %d backup cells)\n", + n_parallel_user, reason, n_seq_max_full - n_parallel_user); + } else { /* NOTE(review): recurrent_expanded stays true on this path although the seq_rm loop above has already run - confirm that is intended */ + SRV_ERR("failed to shrink recurrent state to %d cells for prefill 
(%s)\n", + n_parallel_user, reason); + } + } + void handle_sleeping_state(bool new_state) { GGML_ASSERT(sleeping != new_state); if (new_state) { @@ -1289,6 +1318,8 @@ struct server_context_impl { } if (ret) { + recurrent_shrink_for_prefill("before prompt cache save/load"); + const auto & tokens = ret->prompt.tokens; update_cache = update_cache && prompt_cache; @@ -2791,6 +2822,11 @@ struct server_context_impl { if (!do_reset) { // restore the context checkpoint const size_t checkpoint_size = it->data.size(); + SLT_DBG(slot, + "restoring context checkpoint data=%.3f MiB ring=%.3f MiB recurrent_expanded=%d n_parallel_user=%d n_seq_max_full=%d\n", + (float) it->data.size() / 1024 / 1024, + (float) it->ring_data.size() / 1024 / 1024, + recurrent_expanded, n_parallel_user, n_seq_max_full); const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) {