Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 32 additions & 56 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,7 @@ llama_context::llama_context(

// graph outputs buffer
{
// resized during inference when a batch uses more outputs
// Create a dummy batch for initialization.
llama_batch dummy_batch = {};
dummy_batch.n_tokens = 0;
if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
throw std::runtime_error("failed to reserve initial output buffer");
}

Expand Down Expand Up @@ -1225,7 +1221,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
n_queued_tokens += n_tokens;

// reserve output buffer
if (output_reserve(n_tokens, batch_inp) < n_tokens) {
if (output_reserve(n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
return -2;
};
Expand Down Expand Up @@ -1456,6 +1452,23 @@ static void copy_tensor_async_candidates(
}
}

static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (!ubatch.output[i]) {
continue;
}

// Check if the output token has at least one sequence without a backend sampler.
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
llama_seq_id seq_id = ubatch.seq_id[i][j];
if (samplers.find(seq_id) == samplers.end()) {
return true;
}
}
}
return false; // all sequences use backend sampling
}

int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

Expand Down Expand Up @@ -1588,7 +1601,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

// reserve output buffer
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};
Expand Down Expand Up @@ -1661,10 +1674,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

// extract logits
// For multi-sequence batches that mix backend samplers and CPU sampler
// this is currently inefficient as we copy all logits even for the
// backend sampled tokens.
if (logits && t_logits && n_outputs > 0) {
if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
Expand Down Expand Up @@ -1734,11 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}

// This flag indicates whether a backend sampler has actually sampled a specific
// token, or if it has produced probabilities. If true, we can skip the normal copying of logits and embeddings.
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();

if (has_samplers && has_sampled) {
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
const auto stride = n_vocab;

Expand Down Expand Up @@ -1813,7 +1820,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
// output
//

uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
uint32_t llama_context::output_reserve(int32_t n_outputs) {

const auto & hparams = model.hparams;
const auto & vocab = model.vocab;

Expand All @@ -1832,45 +1840,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
has_embd = true;
}

// Check which sampling modes are needed for the current batch.
// TODO: avoid this branching by working with the worst-case
bool has_sampling = false;
bool cpu_logits = false;

if (batch.logits) {
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
llama_seq_id seq_id = batch.seq_id[i][j];
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
has_sampling = true;
} else {
cpu_logits = true;
}
}
}
} else {
// When batch.logits is nullptr (when loading state with a dummy batch),
// allocate CPU logits.
cpu_logits = true;
}

size_t backend_float_count = 0;
size_t backend_token_count = 0;

// Allocate CPU logits buffer only if needed by sequences in this batch
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

// TODO: avoid this branching by working with the worst-case
if (!has_sampling) {
sampling.logits_size = 0;
sampling.probs_size = 0;
sampling.sampled_size = 0;
sampling.candidates_size = 0;
} else {
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
Expand Down Expand Up @@ -1928,7 +1907,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
size_t offset = 0;
uint8_t * base = (uint8_t *) output_base;

logits = (has_logits && cpu_logits) ? output_base : nullptr;
logits = has_logits ? output_base : nullptr;
offset += logits_size * sizeof(float);

embd = has_embd ? (float *) (base + offset) : nullptr;
Expand Down Expand Up @@ -2620,10 +2599,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));

// Create a dummy batch for state loading.
llama_batch dummy_batch = {};
dummy_batch.n_tokens = 0;
if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}

Expand Down Expand Up @@ -2868,7 +2844,7 @@ void llama_context::opt_epoch_iter(
}

// reserve output buffer
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};
Expand Down
2 changes: 1 addition & 1 deletion src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ struct llama_context {

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
uint32_t output_reserve(int32_t n_outputs);

void output_reorder();

Expand Down
Loading