From 9d9be8bae4e9911f336723908139ae5b2e17f61e Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 13 Jan 2026 15:10:09 +0100 Subject: [PATCH] sampling : remove sampling branching in output_reserve This commit updates output_reserve in llama-context.cpp to always allocate sampling buffers regardless of whether sampling is needed for the current batch. The motivation for this is to avoid reallocations and branching based on the sampling requirements of the batch. --- src/llama-context.cpp | 76 +++++++++++++++++-------------------------- src/llama-context.h | 1 + 2 files changed, 30 insertions(+), 47 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f220010a1b..d1ae58d244 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -818,7 +818,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!has_sampled) { return nullptr; } @@ -873,7 +873,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!has_sampled) { return model.vocab.n_tokens(); } @@ -1746,11 +1746,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba has_embd = true; } - // Check which sampling modes are needed for the current batch. - // TODO: avoid this branching by working with the worst-case - bool has_sampling = false; - bool cpu_logits = false; - + has_sampled = false; + bool cpu_logits = false; if (batch.logits) { for (int32_t i = 0; i < batch.n_tokens; i++) { if (!batch.logits[i]) { @@ -1759,7 +1756,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { llama_seq_id seq_id = batch.seq_id[i][j]; if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { - has_sampling = true; + has_sampled = true; } else { cpu_logits = true; } @@ -1778,21 +1775,13 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; embd_size = has_embd ? n_embd_out*n_outputs_max : 0; - // TODO: avoid this branching by working with the worst-case - if (!has_sampling) { - sampling.logits_size = 0; - sampling.probs_size = 0; - sampling.sampled_size = 0; - sampling.candidates_size = 0; - } else { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; + sampling.logits_size = n_vocab*n_outputs_max; + sampling.probs_size = n_vocab*n_outputs_max; + sampling.sampled_size = n_outputs_max; + sampling.candidates_size = n_vocab*n_outputs_max; - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; - } + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; if (output_ids.empty()) { // init, never resized afterwards @@ -1848,37 +1837,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba embd = has_embd ? (float *) (base + offset) : nullptr; offset += embd_size * sizeof(float); - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + sampling.logits = (float *) (base + offset); + offset += sampling.logits_size * sizeof(float); - if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.probs = (float *) (base + offset); + offset += sampling.probs_size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.sampled = (llama_token *) (base + offset); + offset += sampling.sampled_size * sizeof(llama_token); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.candidates = (llama_token *) (base + offset); + offset += sampling.candidates_size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + // The count vectors keep track of the actual number of logits/probs/candidates + // copied from the backend for each output row. - // The count vectors keep track of the actual number of logits/probs/candidates - // copied from the backend for each output row. + sampling.logits_count.resize(n_outputs_max); + sampling.probs_count.resize(n_outputs_max); + sampling.candidates_count.resize(n_outputs_max); - sampling.logits_count.resize(n_outputs_max); - sampling.probs_count.resize(n_outputs_max); - sampling.candidates_count.resize(n_outputs_max); + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); - std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); - std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); - } + std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1908,7 +1890,7 @@ void llama_context::output_reorder() { } } - if (sampling.logits && sampling.logits_size > 0) { + if (has_sampled && sampling.logits_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); } diff --git a/src/llama-context.h b/src/llama-context.h index b29edf4db2..a581cea895 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -293,6 +293,7 @@ struct llama_context { }; sampling_info sampling; + bool has_sampled = false; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE