Merged
8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -1506,6 +1506,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
11 changes: 6 additions & 5 deletions common/common.h
@@ -385,11 +385,12 @@ struct common_params {
std::string cls_sep = "\t"; // separator of classification sequences

// server params
int32_t port = 8080; // server listens on this network port
int32_t timeout_read = 600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t port = 8080; // server listens on this network port
int32_t timeout_read = 600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
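The new --swa-checkpoints option and the n_swa_checkpoints field only carry the per-slot limit through common_params; the server-side bookkeeping that uses it is not part of this diff. Below is a minimal sketch of how a slot could cap its checkpoints with the llama_state_seq_*_ext API added in include/llama.h further down; the swa_checkpoint struct and make_swa_checkpoint helper are invented here for illustration:

    #include <cstdint>
    #include <deque>
    #include <vector>

    #include "llama.h"

    // hypothetical per-slot checkpoint record (not part of this PR)
    struct swa_checkpoint {
        std::vector<uint8_t> data; // serialized SWA-only sequence state
    };

    static void make_swa_checkpoint(llama_context * ctx, llama_seq_id seq_id,
                                    std::deque<swa_checkpoint> & checkpoints,
                                    int32_t n_swa_checkpoints) {
        if (n_swa_checkpoints <= 0) {
            return; // checkpointing disabled
        }

        // evict the oldest checkpoints so at most n_swa_checkpoints remain after the push
        while ((int32_t) checkpoints.size() >= n_swa_checkpoints) {
            checkpoints.pop_front();
        }

        // capture only the SWA portion of the sequence state
        swa_checkpoint cp;
        cp.data.resize(llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY));
        llama_state_seq_get_data_ext(ctx, cp.data.data(), cp.data.size(), seq_id,
                                     LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

        checkpoints.push_back(std::move(cp));
    }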
23 changes: 23 additions & 0 deletions include/llama.h
@@ -870,6 +870,29 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);

#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

typedef uint32_t llama_state_seq_flags;

LLAMA_API size_t llama_state_seq_get_size_ext(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);

LLAMA_API size_t llama_state_seq_get_data_ext(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id,
llama_state_seq_flags flags);

LLAMA_API size_t llama_state_seq_set_data_ext(
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);

//
// Decoding
//
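The declarations above extend the per-sequence state API with a flags word; LLAMA_STATE_SEQ_FLAGS_SWA_ONLY restricts the payload to the SWA portion of the KV cache. A minimal round-trip sketch against the new entry points (error handling omitted, helper names are illustrative):

    #include <vector>

    #include "llama.h"

    // capture only the SWA part of a sequence's state into a memory buffer
    static std::vector<uint8_t> save_swa_state(llama_context * ctx, llama_seq_id seq_id) {
        std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY));
        const size_t n = llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id,
                                                      LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
        buf.resize(n);
        return buf;
    }

    // restore a previously captured SWA-only snapshot into the same sequence
    static bool restore_swa_state(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
        return llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id,
                                            LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) != 0;
    }

Passing flags = 0 instead serializes the full sequence state, as before.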
44 changes: 28 additions & 16 deletions src/llama-context.cpp
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
}
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
llama_io_write_dummy io;
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
llama_io_write_buffer io(dst, size);
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
llama_io_read_buffer io(src, size);
try {
return state_seq_read_data(io, seq_id);
return state_seq_read_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t nread = state_seq_read_data(io, seq_id);
const size_t nread = state_seq_read_data(io, seq_id, 0);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file

// save the context state using stream saving
llama_io_write_file io(&file);
state_seq_write_data(io, seq_id);
state_seq_write_data(io, seq_id, 0);

const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
return io.n_bytes();
}

size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);

if (memory) {
memory->state_write(io, seq_id);
memory->state_write(io, seq_id, flags);
}

return io.n_bytes();
}

size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);

if (memory) {
memory->state_read(io, seq_id);
memory->state_read(io, seq_id, flags);
}

return io.n_bytes();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
}

size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
return ctx->state_seq_get_size(seq_id);
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}

size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
}

size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
return ctx->state_seq_get_size(seq_id, flags);
}

size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();

return ctx->state_seq_get_data(seq_id, dst, size);
return ctx->state_seq_get_data(seq_id, dst, size, flags);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();

return ctx->state_seq_set_data(seq_id, src, size);
return ctx->state_seq_set_data(seq_id, src, size, flags);
}

size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
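Note that the pre-existing entry points are kept and now simply forward to the _ext variants with flags = 0, so existing callers continue to serialize the full sequence state. For example,

    size_t n = llama_state_seq_get_size(ctx, seq_id);

behaves exactly like

    size_t n = llama_state_seq_get_size_ext(ctx, seq_id, 0);

and the file-based paths (state_seq_save_file / state_seq_load_file) also pass flags = 0.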
10 changes: 5 additions & 5 deletions src/llama-context.h
@@ -111,9 +111,9 @@ struct llama_context {
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);

size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

bool state_load_file(
const char * filepath,
@@ -212,8 +212,8 @@ struct llama_context {
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);

size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);

//
// members
18 changes: 12 additions & 6 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_base->state_write(io, seq_id);
kv_swa ->state_write(io, seq_id);
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}

kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_base->state_read(io, seq_id);
kv_swa ->state_read(io, seq_id);
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

kv_swa->state_read(io, seq_id, flags);
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
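With LLAMA_STATE_SEQ_FLAGS_SWA_ONLY set, the iSWA cache skips the base (full-attention) KV cache and serializes only kv_swa, which is what keeps per-slot SWA checkpoints small; with flags = 0 both caches are written and read as before. Since the payload layout differs between the two modes, a buffer written with the flag set should presumably be restored with the flag set as well.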
4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified-iswa.h
@@ -56,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_kv_cache_unified_iswa specific API
8 changes: 6 additions & 2 deletions src/llama-kv-cache-unified.cpp
@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
return false;
}

void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

io.write(&n_stream, sizeof(n_stream));

for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
}
}

void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

uint32_t n_stream_cur;
4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified.h
@@ -136,8 +136,8 @@ class llama_kv_cache_unified : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_kv_cache_unified specific API
8 changes: 6 additions & 2 deletions src/llama-memory-hybrid.cpp
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}
4 changes: 2 additions & 2 deletions src/llama-memory-hybrid.h
@@ -74,8 +74,8 @@ class llama_memory_hybrid : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_memory_hybrid specific API
8 changes: 6 additions & 2 deletions src/llama-memory-recurrent.cpp
@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
return size_s_bytes;
}

void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;

@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
state_write_data(io, cell_ranges);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));

4 changes: 2 additions & 2 deletions src/llama-memory-recurrent.h
@@ -63,8 +63,8 @@ class llama_memory_recurrent : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
4 changes: 2 additions & 2 deletions src/llama-memory.h
@@ -104,8 +104,8 @@ struct llama_memory_i {
// state write/read
//

virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
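The llama_memory_i interface change means every memory implementation now accepts the flags parameter (defaulting to 0). In this diff only the iSWA composite cache acts on LLAMA_STATE_SEQ_FLAGS_SWA_ONLY; the unified, hybrid, and recurrent memories mark it GGML_UNUSED and keep their previous behavior.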