Merged
11 changes: 5 additions & 6 deletions src/llama-kv-cache.cpp
@@ -957,10 +957,14 @@ bool llama_kv_cache::get_has_shift() const {
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;

// pad the n_kv value so that the graph remains constant across batches and can be reused
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
const uint32_t n_pad_cur = std::max(n_pad, 256u);

for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
const auto & cells = v_cells[sinfo.strm[s]];

result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
}

return result;
@@ -2010,8 +2014,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}

uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}
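
For reference, a minimal standalone sketch (not the actual llama.cpp code) of the rounding that get_n_kv now performs: the per-stream usage is rounded up to a multiple of max(n_pad, 256) and clamped to the stream's cell count, so the reported n_kv repeats across batches and the compiled graph can be reused. The pad_to_multiple helper stands in for GGML_PAD; the function name and parameters are illustrative only.

```cpp
#include <algorithm>
#include <cstdint>

// stand-in for GGML_PAD(x, n): round x up to the next multiple of n
static uint32_t pad_to_multiple(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

// hypothetical single-stream version of the padded n_kv computation
//   n_pad       - the cache's configured padding
//   used_max_p1 - index of the highest used cell, plus one
//   n_cells     - total number of cells in the stream
static uint32_t padded_n_kv(uint32_t n_pad, uint32_t used_max_p1, uint32_t n_cells) {
    const uint32_t n_pad_cur = std::max(n_pad, 256u); // new: enforce a floor of 256
    const uint32_t padded    = std::max(n_pad_cur, pad_to_multiple(used_max_p1, n_pad_cur));
    return std::min(n_cells, padded);                 // never report more than the cache holds
}
```

With n_pad = 32 and used_max_p1 = 300, for example, this yields 512, and the same value comes back for any usage between 257 and 512 cells, which is what keeps the graph shape stable.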
2 changes: 0 additions & 2 deletions src/llama-kv-cache.h
@@ -19,8 +19,6 @@ struct llama_context;

class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);

struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
23 changes: 4 additions & 19 deletions src/llama-model.cpp
@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
}
};

llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
llama_memory_i * res;

switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}

const auto padding = llama_kv_cache::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);

uint32_t n_ctx_per_stream = cparams.n_ctx;

if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream;
}

LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

llama_memory_i::layer_reuse_cb reuse = nullptr;

if (arch == LLM_ARCH_GEMMA3N) {
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding,
1,
nullptr,
reuse);
} else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max,
padding,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
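
To illustrate what is left after the padding removal (a sketch under the assumptions visible in this diff, not the actual create_memory code): the per-stream KV size is now just the ceiling division of n_ctx by n_seq_max, and cparams.n_ctx is no longer rounded up or written back.

```cpp
#include <cstdint>

// illustrative only: mirrors the surviving per-stream context computation
//   n_ctx      - total context size requested by the user
//   n_seq_max  - maximum number of parallel sequences
//   kv_unified - whether all sequences share a single KV stream
static uint32_t n_ctx_per_stream(uint32_t n_ctx, uint32_t n_seq_max, bool kv_unified) {
    if (kv_unified) {
        return n_ctx;                           // one shared stream gets the full context
    }
    return (n_ctx + n_seq_max - 1) / n_seq_max; // ceiling division, no GGML_PAD step anymore
}
```

For example, n_ctx = 10000 with n_seq_max = 3 now gives 3334 cells per stream, whereas previously this value was rounded up to a multiple of llama_kv_cache::get_padding(cparams) and cparams.n_ctx was rewritten to match.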
3 changes: 1 addition & 2 deletions src/llama-model.h
@@ -500,9 +500,8 @@ struct llama_model {

ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;

// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;

// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;
27 changes: 3 additions & 24 deletions tools/server/server.cpp
@@ -2866,10 +2866,12 @@ struct server_context {

// if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
Collaborator:

Just note that slot.n_past won't reflect the correct number of tokens in the KV cache when the model uses M-RoPE. We should fix this later, but I'm not entirely sure how. I'm thinking of these two solutions:

  • Rely on server_tokens::size()
  • Add an API like llama_memory_seq_is_full which returns true if the memory is full

Member Author:

Yes, I'll open a PR to fix this.

slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;

SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
}

// check the limits
@@ -2929,36 +2931,13 @@
}
}

// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_past >= slot.n_ctx) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;

SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
}

if (llama_vocab_is_eog(vocab, result.tok)) {
slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false;

SLT_DBG(slot, "%s", "stopped by EOS\n");
}

const auto n_ctx_train = llama_model_n_ctx_train(model);

if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction

SLT_WRN(slot,
"n_predict (%d) is set for infinite generation. "
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
slot.task->params.n_predict, n_ctx_train);
}

SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

return slot.has_next_token; // continue
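
Regarding the review thread above about M-RoPE: the sketch below is purely hypothetical. Neither should_stop_for_context nor a llama_memory_seq_is_full API exists in llama.cpp today; the names are placeholders for the condition the server checks in this diff and for the idea the reviewer proposed.

```cpp
#include <cstdint>

// sketch of the consolidated stop condition from this diff: with context shift
// disabled, the slot stops as soon as one more token would no longer fit
static bool should_stop_for_context(bool ctx_shift, uint32_t n_past, uint32_t n_ctx) {
    return !ctx_shift && n_past + 1 >= n_ctx;
}

// hypothetical variant following the review suggestion: instead of relying on n_past
// (which does not reflect the KV cache usage with M-RoPE), the caller would pass the
// answer from a llama_memory_seq_is_full-style query or derive it from server_tokens::size()
static bool should_stop_for_context_v2(bool ctx_shift, bool seq_is_full) {
    return !ctx_shift && seq_is_full;
}
```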
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_ctx_shift.py
@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():

@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server