10 changes: 5 additions & 5 deletions common/arg.cpp
@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
{"--ctx-checkpoints"}, "N",
string_format("max number of context checkpoints to create per slot (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
params.n_ctx_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
2 changes: 1 addition & 1 deletion common/common.h
@@ -424,7 +424,7 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
5 changes: 4 additions & 1 deletion include/llama.h
@@ -543,6 +543,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

// Returns true if the model is hybrid (like Jamba, Granite, etc.)
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);

// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

@@ -791,7 +794,7 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);

#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
#define LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY 1
Member

We need a bit better name than this. The old name does not work, but the proposed new name is confusing.

The purpose of this flag is to indicate that we want to save only the "small" caches such as SWA, "recr", etc. But I can't think of a good name for it.

Contributor Author

I see what you mean. I can't think of anything better at the moment.
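For context, a minimal sketch (not part of this diff) of how a caller could use this flag to snapshot and restore only the "small" caches of a single sequence. It relies solely on the `llama_state_seq_*_ext` functions and the flag as they appear in this PR, and assumes a valid `llama_context * ctx` and sequence id:

#include <cstdint>
#include <vector>

#include "llama.h"

// snapshot only the non-rollbackable ("small") caches of one sequence
static std::vector<uint8_t> checkpoint_save(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
    return buf;
}

// restore such a snapshot; returns false if the number of bytes read does not match
static bool checkpoint_restore(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
    const size_t n = llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
    return n == buf.size();
}

The server changes below apply the same pattern per slot, clamping slot.n_past to the restored checkpoint's pos_max.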


typedef uint32_t llama_state_seq_flags;

4 changes: 2 additions & 2 deletions src/llama-kv-cache-iswa.cpp
@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
}

void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}

kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

16 changes: 8 additions & 8 deletions src/llama-memory-hybrid.cpp
@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
mem_attn->state_write(io, seq_id, flags);
}
mem_recr->state_write(io, seq_id, flags);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
mem_attn->state_read(io, seq_id, flags);
}
mem_recr->state_read(io, seq_id, flags);
}

llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
9 changes: 8 additions & 1 deletion src/llama-memory-recurrent.cpp
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
}

bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;

if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
// invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} else {
// seq_id is negative, then the range should include everything or nothing
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
return false;
}
}
@@ -689,6 +692,8 @@ size_t llama_memory_recurrent::size_s_bytes() const {
}

void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
// the LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY flag is acknowledged but does not change
// behavior here, as there is no notion of a partial state for a recurrent context
GGML_UNUSED(flags);

std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
@@ -729,6 +734,8 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
// the LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY flag is acknowledged but does not change
// behavior here, as there is no notion of a partial state for a recurrent context
GGML_UNUSED(flags);

uint32_t cell_count;
4 changes: 4 additions & 0 deletions src/llama-model.cpp
@@ -19945,6 +19945,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
return llm_arch_is_recurrent(model->arch);
}

bool llama_model_is_hybrid(const llama_model * model) {
return llm_arch_is_hybrid(model->arch);
}

bool llama_model_is_diffusion(const llama_model * model) {
return llm_arch_is_diffusion(model->arch);
}
93 changes: 45 additions & 48 deletions tools/server/server.cpp
@@ -764,7 +764,7 @@ struct completion_token_output {
}
};

struct swa_checkpoint {
struct ctx_checkpoint {
llama_pos pos_min;
llama_pos pos_max;

@@ -1460,7 +1460,7 @@ struct server_slot {

std::vector<completion_token_output> generated_token_probs;

std::vector<swa_checkpoint> swa_checkpoints;
std::vector<ctx_checkpoint> ctx_checkpoints;

bool has_next_token = true;
bool has_new_line = false;
@@ -3555,38 +3555,38 @@ struct server_context {
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

// search for a SWA checkpoint
// search for a context checkpoint
const auto it = std::find_if(
slot.swa_checkpoints.rbegin(),
slot.swa_checkpoints.rend(),
slot.ctx_checkpoints.rbegin(),
slot.ctx_checkpoints.rend(),
[&](const auto & cur) {
return cur.pos_min <= pos_min_thold;
}
);

bool do_reset = it == slot.swa_checkpoints.rend();
bool do_reset = it == slot.ctx_checkpoints.rend();
//printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");

if (!do_reset) {
// restore the checkpoint
const size_t swa_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
// restore the context checkpoint
const size_t ctx_checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

if (n != swa_size) {
SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
if (n != ctx_checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
slot.n_past = std::min(slot.n_past, it->pos_max);

SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
}
}

if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

slot.n_past = 0;
slot.swa_checkpoints.clear();
slot.ctx_checkpoints.clear();
}
}
}
@@ -3595,21 +3595,20 @@ struct server_context {
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

// erase any checkpoints with pos_min > pos_min_thold
for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
const auto & cur = slot.swa_checkpoints[i];
for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
const auto & cur = slot.ctx_checkpoints[i];
if (cur.pos_min > pos_min_thold) {
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);

SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
}
}
}
}

if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
}

slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,17 +3622,17 @@ struct server_context {
}
}

// keep only the common part
// truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

// there is no common part left
slot.n_past = 0;
slot.n_prompt_tokens_cache = 0;
}

SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
SLT_INF(slot, "n_past = %d\n", slot.n_past);

// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3853,35 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;

// make a checkpoint with the SWA memory
// checkpoints are needed only if we are not using "--swa-full"
if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
{
const auto & cur = slot.swa_checkpoints.back();

SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}

slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
// make a checkpoint of the parts of memory that cannot be rolled back.
// checkpoints are needed only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);

Member

I'm a bit torn on this logic for determining when to do checkpoints. It should be centred around the memory module or the context, rather than the model.

Just making a note for the future - no need to change anything in this PR.
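As a reference point for this note, here is the model-centred condition the PR introduces, restated as a small standalone helper (the helper name is editorial, not part of the PR; it uses only model queries that appear in this diff):

#include "llama.h"

// model-centred decision used by this PR: checkpoints are needed whenever parts of
// the per-sequence state cannot be partially rolled back
static bool model_needs_ctx_checkpoints(const llama_model * model, bool swa_full) {
    // recurrent and hybrid architectures always carry such state
    if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
        return true;
    }
    // SWA models need checkpoints only when the full attention window is not kept
    return llama_model_n_swa(model) > 0 && !swa_full;
}

A memory-centred variant, as suggested above, would instead ask the context's memory module whether it holds such state; no such query exists in the public API at the time of this PR.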

if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.ctx_checkpoints.back();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}

const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
/*.data = */ std::vector<uint8_t>(swa_size),
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});

llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

float size_total = 0.0f;
for (const auto & checkpoint : slot.swa_checkpoints) {
size_total += (float) checkpoint.data.size() / 1024 / 1024;
}
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);

SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots