Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 29 additions & 47 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
float * llama_context::get_sampled_logits_ith(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
if (!has_sampled) {
return nullptr;
}

Expand Down Expand Up @@ -873,7 +873,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
size_t llama_context::get_sampled_logits_count(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
if (!has_sampled) {
return model.vocab.n_tokens();
}

Expand Down Expand Up @@ -1746,11 +1746,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
has_embd = true;
}

// Check which sampling modes are needed for the current batch.
// TODO: avoid this branching by working with the worst-case
bool has_sampling = false;
bool cpu_logits = false;

has_sampled = false;
bool cpu_logits = false;
if (batch.logits) {
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
Expand All @@ -1759,7 +1756,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
llama_seq_id seq_id = batch.seq_id[i][j];
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
has_sampling = true;
has_sampled = true;
} else {
cpu_logits = true;
}
Expand All @@ -1778,21 +1775,13 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

// TODO: avoid this branching by working with the worst-case
if (!has_sampling) {
sampling.logits_size = 0;
sampling.probs_size = 0;
sampling.sampled_size = 0;
sampling.candidates_size = 0;
} else {
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;
sampling.logits_size = n_vocab*n_outputs_max;
sampling.probs_size = n_vocab*n_outputs_max;
sampling.sampled_size = n_outputs_max;
sampling.candidates_size = n_vocab*n_outputs_max;

backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;
}
backend_float_count = sampling.logits_size + sampling.probs_size;
backend_token_count = sampling.sampled_size + sampling.candidates_size;

if (output_ids.empty()) {
// init, never resized afterwards
Expand Down Expand Up @@ -1848,37 +1837,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
embd = has_embd ? (float *) (base + offset) : nullptr;
offset += embd_size * sizeof(float);

sampling.logits = nullptr;
sampling.probs = nullptr;
sampling.sampled = nullptr;
sampling.candidates = nullptr;
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);

if (has_sampling) {
sampling.logits = (float *) (base + offset);
offset += sampling.logits_size * sizeof(float);
sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);

sampling.probs = (float *) (base + offset);
offset += sampling.probs_size * sizeof(float);
sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);

sampling.sampled = (llama_token *) (base + offset);
offset += sampling.sampled_size * sizeof(llama_token);
sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);

sampling.candidates = (llama_token *) (base + offset);
offset += sampling.candidates_size * sizeof(llama_token);
// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.

// The count vectors keep track of the actual number of logits/probs/candidates
// copied from the backend for each output row.
sampling.logits_count.resize(n_outputs_max);
sampling.probs_count.resize(n_outputs_max);
sampling.candidates_count.resize(n_outputs_max);

sampling.logits_count.resize(n_outputs_max);
sampling.probs_count.resize(n_outputs_max);
sampling.candidates_count.resize(n_outputs_max);
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);

std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);

std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
}
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);

// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
Expand Down Expand Up @@ -1908,7 +1890,7 @@ void llama_context::output_reorder() {
}
}

if (sampling.logits && sampling.logits_size > 0) {
if (has_sampled && sampling.logits_size > 0) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
}
Expand Down
1 change: 1 addition & 0 deletions src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ struct llama_context {
};

sampling_info sampling;
bool has_sampled = false;

// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
Expand Down
Loading