49 changes: 36 additions & 13 deletions src/llama-context.cpp
@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // make the outputs have the same order they had in the user-provided batch
     // note: this is mostly relevant for recurrent models atm
     if (!sorted_output) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint64_t n_embd = model.hparams.n_embd;
-
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
         // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 continue;
             }
             std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
+
+            // remember the swaps and apply them lazily upon logits/embeddings access
+            output_swaps.push_back({ i, j_min });
         }
 
         std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
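The decode() change above replaces the eager row swaps with bookkeeping: the selection sort still reorders out_ids, but instead of moving whole n_vocab-sized logits rows and n_embd-sized embedding rows on the spot, it only records the (i, j_min) index pairs and replays them in output_reorder() when the data is actually read. Below is a minimal, self-contained sketch of that record-then-replay idea; the swap_rec struct, the toy 4×3 buffer, and all names in it are invented for illustration and are not the llama.cpp code itself.

```cpp
// Standalone sketch of "record swaps now, apply them lazily later".
// All names here are illustrative; only the idea mirrors the diff above.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct swap_rec { uint32_t i0, i1; };

int main() {
    // 4 output rows, 3 floats per row; rows are stored in "compute" order.
    const uint32_t n_rows = 4, row_size = 3;
    std::vector<int32_t> out_ids = { 2, 0, 3, 1 };  // user-batch position of each row
    std::vector<float>   rows    = { 20,21,22,  0,1,2,  30,31,32,  10,11,12 };

    // Selection sort on out_ids, recording swaps instead of moving row data.
    std::vector<swap_rec> swaps;
    for (uint32_t i = 0; i + 1 < n_rows; ++i) {
        uint32_t j_min = i;
        for (uint32_t j = i + 1; j < n_rows; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        swaps.push_back({ i, j_min });  // cheap: no row data touched yet
    }

    // Later, on first access, replay the recorded swaps on the heavy row data.
    for (const auto & s : swaps) {
        for (uint32_t k = 0; k < row_size; ++k) {
            std::swap(rows[s.i0*row_size + k], rows[s.i1*row_size + k]);
        }
    }
    swaps.clear();  // replay is a one-time step; a second pass would be a no-op loop

    // Rows now come out in user-batch order: 0..2, 10..12, 20..22, 30..32.
    for (uint32_t r = 0; r < n_rows; ++r) {
        printf("row %u: %.0f %.0f %.0f\n", r, rows[r*row_size], rows[r*row_size + 1], rows[r*row_size + 2]);
    }
    return 0;
}
```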
9 changes: 9 additions & 0 deletions src/llama-context.h
@@ -181,6 +181,8 @@ struct llama_context {
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    void output_reorder();
+
     //
     // graph
     //
@@ -250,6 +252,13 @@ struct llama_context {
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;
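The header adds the swap_info record and the output_swaps vector, and every getter in the .cpp diff calls output_reorder() before handing out a pointer. Because output_reorder() clears output_swaps when it finishes (and decode() clears it up front), the row swaps run at most once per decode; subsequent accessor calls fall through an empty loop. The sketch below shows that accessor pattern in isolation, assuming an invented output_buffer class with set()/data()/reorder() members; it is not the llama_context API.

```cpp
// Illustrative "reorder lazily on access" pattern; class and member names
// are invented for this sketch, not the real llama_context interface.
#include <cstdint>
#include <utility>
#include <vector>

class output_buffer {
public:
    // decode() analogue: store rows in compute order plus the pending swaps.
    void set(std::vector<float> rows, std::vector<std::pair<uint32_t, uint32_t>> swaps, uint32_t row_size) {
        rows_     = std::move(rows);
        swaps_    = std::move(swaps);
        row_size_ = row_size;
    }

    // get_logits()/get_embeddings() analogue: reorder lazily, then expose the data.
    const float * data() {
        reorder();
        return rows_.data();
    }

private:
    void reorder() {
        for (const auto & s : swaps_) {
            for (uint32_t k = 0; k < row_size_; ++k) {
                std::swap(rows_[s.first*row_size_ + k], rows_[s.second*row_size_ + k]);
            }
        }
        swaps_.clear(); // later data() calls skip the copy work entirely
    }

    std::vector<float> rows_;
    std::vector<std::pair<uint32_t, uint32_t>> swaps_;
    uint32_t row_size_ = 0;
};
```

The practical effect of this design is that the potentially large per-row copies are deferred until logits or embeddings are actually requested, and repeated reads after the first one cost nothing extra.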