Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 27 additions & 43 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,23 @@ static void copy_tensor_async_candidates(
}
}

static bool needs_cpu_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
Comment thread
danbev marked this conversation as resolved.
Outdated
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (!ubatch.output[i]) {
continue;
}

// Check if the output token has at least one sequence without a backend sampler.
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
llama_seq_id seq_id = ubatch.seq_id[i][j];
if (samplers.find(seq_id) == samplers.end()) {
return true;
}
}
}
return false; // all sequences use backend sampling
}

int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

Expand Down Expand Up @@ -1661,10 +1678,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

// extract logits
// For multi-sequence batches that mix backend samplers and CPU sampler
// this is currently inefficient as we copy all logits even for the
// backend sampled tokens.
if (logits && t_logits && n_outputs > 0) {
if (logits && t_logits && n_outputs > 0 && needs_cpu_logits(ubatch, sampling.samplers)) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
Expand Down Expand Up @@ -1734,11 +1748,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}

// This flag indicates whether a backend sampler has actually sampled a specific
// token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();

if (has_samplers && has_sampled) {
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
const auto stride = n_vocab;

Expand Down Expand Up @@ -1814,6 +1825,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
//

uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
GGML_UNUSED(batch);
Comment thread
danbev marked this conversation as resolved.
Outdated

const auto & hparams = model.hparams;
const auto & vocab = model.vocab;

Expand All @@ -1832,45 +1845,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
has_embd = true;
}

// Check which sampling modes are needed for the current batch.
// TODO: avoid this branching by working with the worst-case
bool has_sampling = false;
bool cpu_logits = false;

if (batch.logits) {
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
llama_seq_id seq_id = batch.seq_id[i][j];
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
has_sampling = true;
} else {
cpu_logits = true;
}
}
}
} else {
// When batch.logits is nullptr (when loading state with a dummy batch),
// allocate CPU logits.
cpu_logits = true;
}

size_t backend_float_count = 0;
size_t backend_token_count = 0;

// Allocate CPU logits buffer only if needed by sequences in this batch
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

// TODO: avoid this branching by working with the worst-case
if (!has_sampling) {
sampling.logits_size = 0;
sampling.probs_size = 0;
sampling.sampled_size = 0;
sampling.candidates_size = 0;
} else {
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
Expand Down Expand Up @@ -1928,7 +1912,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
size_t offset = 0;
uint8_t * base = (uint8_t *) output_base;

logits = (has_logits && cpu_logits) ? output_base : nullptr;
logits = has_logits ? output_base : nullptr;
offset += logits_size * sizeof(float);

embd = has_embd ? (float *) (base + offset) : nullptr;
Expand Down
Loading