
Conversation

@ggerganov (Member) commented Oct 28, 2025

ref #16736 (comment)

Simplify the logic during memory module creation. Before, for KV caches, we used to pad the buffer size up to a multiple of 256 cells because the flash attention implementations did not support arbitrary K, V sizes. After the improvements in #16148 and related work, we no longer need this explicit padding.

For now, we keep support for llama_kv_cache size padding via the constructor's n_pad argument, although it is no longer used.

Note that we continue to pad n_kv - the tensor shape of the K and V tensors used by each graph:

```cpp
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const auto & cells = v_cells[sinfo.strm[s]];

        // pad the n_kv value so that the graph remains constant across batches and can be reused
        // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
        const uint32_t n_pad_cur = std::max(n_pad, 256u);

        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
    }

    return result;
}
```

We need this in order to reuse most of the compute graphs during the text generation phase. Additionally, it helps performance on some backends.
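For a quick illustration of the formula above, here is a minimal standalone sketch. The GGML_PAD macro is reproduced as defined in ggml.h; the 1000-cell cache size, the n_pad value, and the used_max_p1 samples are made-up inputs, not values from the PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// rounding-up-to-a-multiple macro, as defined in ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// standalone version of the per-stream n_kv computation above
static uint32_t n_kv_for(uint32_t cells_size, uint32_t used_max_p1, uint32_t n_pad) {
    const uint32_t n_pad_cur = std::max(n_pad, 256u);
    return std::min(cells_size, std::max(n_pad_cur, GGML_PAD(used_max_p1, n_pad_cur)));
}

int main() {
    // a hypothetical cache of 1000 cells with a small n_pad:
    // n_kv grows in steps of 256 and is clamped to the total cache size
    for (uint32_t used : {100u, 300u, 600u, 900u}) {
        std::printf("used_max_p1 = %3u -> n_kv = %u\n", used, n_kv_for(1000, used, 1));
    }
    return 0;
}
```

Running this prints n_kv = 256, 512, 768 and finally 1000: the value grows in steps of 256 and is clamped to the total cache size, which is why most graphs keep the same padded shapes and can be reused.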

Also, llama_model::create_memory() no longer mutates cparams.

Next, I will rebase #16736 on top of this change and finish it.

@ggerganov ggerganov requested a review from CISC as a code owner October 28, 2025 07:32
@JohannesGaessler (Collaborator)

FYI: in the CUDA backend, non-padded inputs will produce correct results, but performance will be worse. One reason is that the vector and mma kernels don't support it (in principle the wmma kernel doesn't either, but I want to remove that one soon). Another reason is that, as with e.g. ALiBi, I've added support for a non-padded KV cache only to the template specialization that doesn't have GQA-specific optimizations (to keep compilation times low and to mitigate the few % performance penalty from the OOB checks).

@ggerganov (Member, Author)

What is the optimal padding for CUDA - 128 or 256?

@JohannesGaessler (Collaborator)

As of right now the required padding is still 256. I've already adapted the vector and tile kernels so that they only need a padding of 128 (without OOB checks). The mma kernel, wmma kernel, and some utility kernels still need a padding of 256. My plan is to implement support for Volta, AMD WMMA instructions, and AMD MFMA instructions directly in the mma kernel; then I can simply remove the wmma kernel without having to make any changes to it. The hardware I need for development is a V100 (which should arrive today), an MI100 (which arrived yesterday), and an RDNA3+ GPU (already in hand). The mma kernel and the utility kernels can then be made to work with a padding of 128 without issue.

Caveat: the MI100 doesn't seem to work with the motherboard that I intended to use with it, so I may need to shuffle my hardware around a bit.

@ggerganov (Member, Author)

@JohannesGaessler The updated PR now continues to pad n_kv to a multiple of 256. This should keep performance essentially the same as on master. The only change is that the total size of the KV cache is no longer padded.

For example, we can now allocate a KV cache with 1000 cells (although this is not recommended), while on master it would have been automatically padded to 1024. Note that even with a total size of 1000 cells, most of the compute graphs will continue to have tensor shapes (n_kv) padded to a multiple of 256 - f.ex. 256, 512, 768.

The main goal here is to simplify the logic around per-sequence context size allocation and keep it localized to context creation.
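For context, a rough sketch of how such a 1000-cell context could be requested, assuming the current llama.h API; the model path is a placeholder and error handling is minimal:

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();

    // placeholder path - substitute a real GGUF model
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 1000; // no longer rounded up to 1024 when sizing the KV cache buffer

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        llama_model_free(model);
        return 1;
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

With this change the KV cache buffer for such a context is sized for 1000 cells instead of being rounded up to 1024, while the per-graph n_kv values stay padded as described above.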

```diff
 const auto n_ctx_train = llama_model_n_ctx_train(model);

-if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
+if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= slot.n_ctx) {
```
@ggerganov (Member, Author)

The old logic of using the training context as a generation limit seems dubious - the context size of the slot should impose the limit.

After #16736, slot.n_ctx will be capped to the training context anyway, so this change should not make a big difference in practice.

@ngxson (Collaborator) commented Oct 28, 2025

In this case, do you think this code branch can be removed altogether? The n_ctx cap is already imposed by the slot.n_past >= slot.n_ctx condition above (L2933).

@ggerganov (Member, Author)

Yes, good point.

@github-actions github-actions bot added the examples, python, server labels Oct 28, 2025
```cpp
}

// if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
```
Collaborator

Just note that slot.n_past won't reflect the correct number of tokens in the KV cache when the model uses M-RoPE. We should fix this later, but I'm not entirely sure how. I'm thinking of these two solutions:

  • Rely on server_tokens::size()
  • Add an API like llama_memory_seq_is_full which returns true if the memory is full (a rough sketch of how this could look follows below)
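For illustration only, a rough sketch of what the second option could look like. llama_memory_seq_is_full does not exist today; the name comes from the suggestion above, and the signature and usage are assumptions:

```cpp
// hypothetical API (not in llama.h): returns true when no more cells can be
// allocated for the given sequence in the given memory module
bool llama_memory_seq_is_full(llama_memory_t mem, llama_seq_id seq_id);

// hypothetical replacement for the slot.n_past based check in server.cpp,
// which can be off for M-RoPE models
if (!params_base.ctx_shift && llama_memory_seq_is_full(llama_get_memory(ctx), slot.id)) {
    // the sequence cannot grow any further - stop generation for this slot
}
```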

@ggerganov (Member, Author)

Yes, I'll open a PR to fix this.

@ggerganov ggerganov merged commit 85a7d86 into master Oct 28, 2025
71 checks passed
@ggerganov ggerganov deleted the gg/context-no-pad branch October 28, 2025 18:19