Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2055,10 +2055,12 @@ void common_prompt_checkpoint::clear() {
void common_prompt_checkpoint::update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max) {
llama_pos pos_max,
llama_pos pos_end) {
this->n_tokens = n_tokens;
this->pos_min = pos_min;
this->pos_max = pos_max;
this->pos_end = pos_end;
}

void common_prompt_checkpoint::update_tgt(
Expand Down
4 changes: 3 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ struct common_prompt_checkpoint {

llama_pos pos_min;
llama_pos pos_max;
llama_pos pos_end;

std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
Expand All @@ -1066,7 +1067,8 @@ struct common_prompt_checkpoint {
void update_pos(
int64_t n_tokens,
llama_pos pos_min,
llama_pos pos_max);
llama_pos pos_max,
llama_pos pos_end);

void update_tgt(
llama_context * ctx,
Expand Down
17 changes: 14 additions & 3 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,7 @@ struct server_context_impl {

auto & cur = slot.prompt.checkpoints.emplace_back();

cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max, slot.prompt.tokens.pos_next());

cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
Expand Down Expand Up @@ -2442,7 +2442,8 @@ struct server_context_impl {
slot.spec_ckpt.update_pos(
slot.prompt.n_tokens(),
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id),
slot.prompt.tokens.pos_next());

if (use_ckpt_dft) {
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
Expand Down Expand Up @@ -2714,6 +2715,10 @@ struct server_context_impl {
}

llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
const llama_pos prompt_end = slot.task->tokens.pos_next();
const bool is_recurrent_or_hybrid =
llama_model_is_recurrent(model_tgt) ||
llama_model_is_hybrid(model_tgt);

// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, pos_next - n_swa - 1);
Expand Down Expand Up @@ -2774,9 +2779,15 @@ struct server_context_impl {
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
[&, func_name = __func__](const auto & cur) {
if (cur.pos_end > prompt_end) {
return false;
}
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
if (is_recurrent_or_hybrid) {
return cur.pos_max < pos_next || cur.pos_min == 0;
}
return cur.pos_min < pos_min_thold || cur.pos_min == 0;
}
);
Expand Down Expand Up @@ -2806,7 +2817,7 @@ struct server_context_impl {
// erase any checkpoints with pos_max > pos_next
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_max > pos_next) {
if (cur.pos_end > prompt_end || cur.pos_max > pos_next) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
Expand Down