Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1863,3 +1863,66 @@ float lr_opt::get_lr(float epoch) const {
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}

bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s: failed to replay last token\n", __func__);
return false;
}
return true;
}

bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch) {
const int n_eval = tokens.size();
if (n_eval == 0) {
return true;
}

if (save_state && is_last_batch && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1;

GGML_ASSERT(n_eval <= n_batch);

// Decode all but the last token so we can save the memory state before decoding the last token.
// This is done so we can restore the session state later and replay the last token.
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_tokens_before_last;

llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last);
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);

llama_token last_token = tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
int32_t pos = n_past;
batch.pos = &pos;

if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}

return true;
}
19 changes: 19 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ggml-opt.h"
#include "llama-cpp.h"

#include <filesystem>
#include <set>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -779,6 +780,24 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);

// decodes a single batch of tokens for a prompt and manages session tokens
//
// Note: We save state before the last token so that we can replay it to ensure
// compatibility with all memory types. Recurrent/hybrid models cannot remove
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & embd,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch = true);

// replays the last token after loading state to regenerate logits
// used after loading session state to ensure the sampling context has valid logits
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);

//
// Token utils
//
Expand Down
94 changes: 36 additions & 58 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
#include "common.h"
#include "llama.h"

#include <filesystem>
#include <vector>
#include <cstdio>


int main(int argc, char ** argv) {
common_params params;

params.prompt = "The quick brown fox";
params.sampling.seed = 1234;

std::filesystem::path state_file = "dump_state.bin";

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
Expand Down Expand Up @@ -53,35 +57,16 @@ int main(int argc, char ** argv) {
// tokenize prompt
auto tokens = common_tokenize(ctx, params.prompt, true);

// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token

// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;

// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
fclose(fp_write);

fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
const bool save_state = true;
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
return 1;
}

// save state (last tokens)
const auto n_past_saved = n_past;

// first run
printf("\nfirst run: %s", params.prompt.c_str());

llama_batch batch = llama_batch_init(1, 0, 1);

for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = common_token_to_piece(ctx, next_token);
Expand Down Expand Up @@ -111,27 +96,23 @@ int main(int argc, char ** argv) {

printf("\nsecond run: %s", params.prompt.c_str());

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}
// load state from file
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
size_t n_token_count_out = 0;

fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}

fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);

// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
return 1;
}
++n_past;

// second run
for (auto i = 0; i < params.n_predict; i++) {
Expand Down Expand Up @@ -160,7 +141,9 @@ int main(int argc, char ** argv) {
}

// make new context
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
auto params_ctx3 = common_context_params_to_llama(params);
params_ctx3.n_seq_max = 2;
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);

llama_sampler * smpl3 = llama_sampler_chain_init(sparams);

Expand All @@ -169,26 +152,21 @@ int main(int argc, char ** argv) {
printf("\nsingle seq run: %s", params.prompt.c_str());

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
n_token_count_out = 0;

if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}

fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}

fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);

// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
return 1;
}
++n_past;

// save seq 0 and load into seq 1
{
Expand Down
122 changes: 0 additions & 122 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2491,64 +2491,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
// TODO: add more model-specific info which should prevent loading the session file if not identical
}

// write output ids
{
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);

const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;

std::vector<int32_t> w_output_pos;

w_output_pos.resize(n_outputs);

// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}

io.write(&n_outputs, sizeof(n_outputs));

if (n_outputs) {
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
}
}

// [TAG_CONTEXT_STATE_LOGITS]
// write logits
{
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);

const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());

io.write(&logits_size, sizeof(logits_size));

if (logits_size) {
io.write(logits.data, logits_size * sizeof(float));
}
}

// write embeddings
{
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);

const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);

io.write(&embd_size, sizeof(embd_size));

if (embd_size) {
io.write(embd.data, embd_size * sizeof(float));
}
}

// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004

if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
Expand All @@ -2574,70 +2516,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
// TODO: add more info which needs to be identical but which is not verified otherwise
}

// read output ids
{
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);

auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));

if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}

std::vector<int32_t> output_pos;

if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));

for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
if ((uint32_t) id >= n_batch()) {
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
}
this->output_ids[id] = i;
}

this->n_outputs = n_outputs;
}
}

// read logits
{
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);

uint64_t logits_size;
io.read_to(&logits_size, sizeof(logits_size));

if (this->logits.size < logits_size) {
throw std::runtime_error("logits buffer too small");
}

if (logits_size) {
io.read_to(this->logits.data, logits_size * sizeof(float));
}
}

// read embeddings
{
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);

uint64_t embd_size;
io.read_to(&embd_size, sizeof(embd_size));

if (this->embd.size < embd_size) {
throw std::runtime_error("embeddings buffer too small");
}

if (embd_size) {
io.read_to(this->embd.data, embd_size * sizeof(float));
}
}

// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004

if (memory) {
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

Expand Down
Loading