Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions tests/test-save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "llama-cpp.h"

#include <clocale>
#include <random>
#include <vector>

struct llama_batch_ptr {
Expand All @@ -23,16 +24,15 @@ struct llama_batch_ptr {
const llama_batch & get() const { return batch; }
};

static std::string generate_tokens(llama_context * ctx, llama_sampler * smpl, int & n_past, int32_t n_predict, llama_seq_id seq_id) {
std::string result;
static llama_tokens generate_tokens(llama_context * ctx, llama_sampler * smpl, int & n_past, int32_t n_predict, llama_seq_id seq_id) {
llama_tokens result;
llama_batch_ptr batch(1, 0, 1);

for (int i = 0; i < n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = common_token_to_piece(ctx, next_token);
auto next_token = llama_sampler_sample(smpl, ctx, -1);

LOG("%s", next_token_str.c_str());
result += next_token_str;
LOG("%d ", next_token);
result.push_back(next_token);

common_batch_clear(batch.get());
common_batch_add(batch.get(), next_token, n_past, {seq_id}, true);
Expand All @@ -48,28 +48,24 @@ static std::string generate_tokens(llama_context * ctx, llama_sampler * smpl, in
}

// Test 1: baseline
// - tokenize the prompt
// - decode all but the last token
// - save state to disk
// - decode the last token
// - generate n_predict tokens
static std::string test_baseline(struct llama_model * model, const struct common_params & params) {
static llama_tokens test_baseline(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens) {
auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};

auto sparams = llama_sampler_chain_default_params();
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

auto tokens = common_tokenize(ctx.get(), params.prompt, true);

auto n_past = 0;
if (!common_prompt_batch_decode(ctx.get(), tokens, (int)tokens.size(), n_past, params.n_batch, params.out_file, true)) {
LOG_ERR("%s: failed to decode prompt\n", __func__);
return {};
}

LOG("\n=== Test 1: baseline ===\n");
LOG("%s", params.prompt.c_str());

auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 0);
if (result.empty()) {
Expand All @@ -87,20 +83,17 @@ static std::string test_baseline(struct llama_model * model, const struct common
// - load state from file
// - replay the last prompt token
// - generate n_predict tokens and compare against expected result
static bool test_state_load(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
static bool test_state_load(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};

auto sparams = llama_sampler_chain_default_params();
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

auto tokens = common_tokenize(ctx.get(), params.prompt, true);

LOG("\n=== Test 2: state load ===\n");
LOG("%s", params.prompt.c_str());

// Load state from file
std::vector<llama_token> unused_sts(tokens.size());
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;

if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
Expand Down Expand Up @@ -139,7 +132,7 @@ static bool test_state_load(struct llama_model * model, const struct common_para
// - replay the last prompt token
// - migrate KV cache from seq 0 to seq 1 via the CPU path
// - generate n_predict tokens on seq 1 and compare against expected result
static bool test_seq_cp_host(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
static bool test_seq_cp_host(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto params_ctx = common_context_params_to_llama(params);
params_ctx.n_seq_max = 2;
auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};
Expand All @@ -148,13 +141,10 @@ static bool test_seq_cp_host(struct llama_model * model, const struct common_par
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

auto tokens = common_tokenize(ctx.get(), params.prompt, true);

LOG("\n=== Test 3: seq copy (host) ===\n");
LOG("%s", params.prompt.c_str());

// Load state from file
std::vector<llama_token> unused_sts(tokens.size());
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;

if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
Expand Down Expand Up @@ -214,7 +204,7 @@ static bool test_seq_cp_host(struct llama_model * model, const struct common_par
// - replay the last prompt token
// - migrate KV cache from seq 0 to seq 1 via the on-device path
// - generate n_predict tokens on seq 1 and compare against expected result
static bool test_seq_cp_device(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
static bool test_seq_cp_device(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto params_ctx = common_context_params_to_llama(params);
params_ctx.n_seq_max = 2;
auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};
Expand All @@ -223,13 +213,10 @@ static bool test_seq_cp_device(struct llama_model * model, const struct common_p
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

auto tokens = common_tokenize(ctx.get(), params.prompt, true);

LOG("\n=== Test 4: seq copy (device) ===\n");
LOG("%s", params.prompt.c_str());

// Load state from file
std::vector<llama_token> unused_sts(tokens.size());
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;

if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
Expand Down Expand Up @@ -287,7 +274,8 @@ int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");

common_params params;
params.prompt = "The quick brown fox";
params.prompt = "";
params.n_batch = 100;
params.out_file = "dump_state.bin";
params.sampling.seed = 1234;

Expand Down Expand Up @@ -318,24 +306,49 @@ int main(int argc, char ** argv) {

GGML_ASSERT(llama_init->context() == nullptr);

// Tokenize prompt or generate random tokens
llama_tokens tokens;
if (params.prompt.empty()) {
const int n_prompt = params.n_batch;

// this path is useful for model files that do not have a tokenizer: the prompt is built directly from random token ids, so common_tokenize is never called
LOG_INF("%s: no prompt provided, generating %d (n_batch) random tokens\n", __func__, n_prompt);

const auto * vocab = llama_model_get_vocab(model);
const auto n_vocab = llama_vocab_n_tokens(vocab);

std::mt19937 rng(params.sampling.seed);
std::uniform_int_distribution<llama_token> dist(0, n_vocab - 1);
for (int i = 0; i < n_prompt; i++) {
tokens.push_back(dist(rng));
}
} else {
LOG_INF("%s: tokenizing prompt '%s'\n", __func__, params.prompt.c_str());

auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};
tokens = common_tokenize(ctx.get(), params.prompt, true);
}

LOG_INF("%s: the input prompt is %d tokens\n", __func__, (int)tokens.size());

// Test 1: baseline (saves state to disk)
auto result_baseline = test_baseline(model, params);
auto result_baseline = test_baseline(model, params, tokens);
if (result_baseline.empty()) {
return 1;
}

// Test 2: state load
if (!test_state_load(model, params, result_baseline)) {
if (!test_state_load(model, params, tokens, result_baseline)) {
return 1;
}

// Test 3: seq copy (host)
if (!test_seq_cp_host(model, params, result_baseline)) {
if (!test_seq_cp_host(model, params, tokens, result_baseline)) {
return 1;
}

// Test 4: seq copy (device)
if (!test_seq_cp_device(model, params, result_baseline)) {
if (!test_seq_cp_device(model, params, tokens, result_baseline)) {
return 1;
}

Expand Down