161 changes: 114 additions & 47 deletions tools/server/server-context.cpp
@@ -79,6 +79,8 @@ struct server_slot {

common_speculative * spec = nullptr;

// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
std::unique_ptr<const server_task> task;
std::unique_ptr<const server_task> task_prev; // used for debugging

@@ -153,7 +155,7 @@ struct server_slot {

common_sampler_ptr smpl;

- llama_token sampled; // in speculative mode, this is the last accepted token
+ llama_token sampled;      // in speculative mode, this is the last accepted token
llama_tokens drafted;

// stats
@@ -201,12 +203,46 @@ struct server_slot {
alora_invocation_start = -1;
}

// remove cached prompt + tokens
void clear(bool allow_processing) {
if (!allow_processing) {
GGML_ASSERT(!is_processing());
}

SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());

llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}

void init_sampler() const {
const int64_t t_start = ggml_time_us();

common_sampler_reset(smpl.get());

int n_text = 0;

for (int i = 0; i < (int) prompt.tokens.size(); i++) {
const llama_token id = prompt.tokens[i];

if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(smpl.get(), id, false);
n_text++;
}
}

SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}

// TODO: move to server_task
bool need_embd() const {
GGML_ASSERT(task);

return server_task_type_need_embd(task->type);
}

// TODO: move to server_task
bool need_logits() const {
GGML_ASSERT(task);

@@ -258,10 +294,13 @@ struct server_slot {
SLT_WRN(*this, "%s", "slot is not processing\n");
return;
}

generated_token_probs.push_back(token);
}

int get_n_draft_max() const {
GGML_ASSERT(task);

if (!can_speculate()) {
return 0;
}
@@ -287,12 +326,14 @@
}

// note: a slot can also be either a parent or a child
// TODO: move to server_task
bool is_parent() const {
- return is_processing() && task->n_children > 0;
Review comment: This is_processing() check also seemed redundant so removed it.

+ return task->n_children > 0;
}

// TODO: move to server_task
bool is_child() const {
- return is_processing() && task->id_parent >= 0;
+ return task->id_parent >= 0;
}

void release() {
@@ -301,10 +342,16 @@

SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);

- t_last_used = ggml_time_us();
+ t_last_used        = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;

state = SLOT_STATE_IDLE;

// do not keep context of the child slots - the parent's context is enough
if (is_child()) {
clear(false);
}

task_prev = std::move(task);
task.reset();

@@ -425,14 +472,22 @@ struct server_slot {
}

void copy_state_to(server_slot & other) const {
- llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
- llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+ GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
+
+ llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
+ llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);

other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;

other.t_start_process_prompt = t_start_process_prompt;
other.t_prompt_processing = t_prompt_processing;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;

other.prompt = prompt.clone();
other.init_sampler();
}
};

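A note on the memory API calls in copy_state_to above: in llama.cpp's memory interface a negative p0/p1 clamps to the start/end of the sequence, so the switch from (0, -1) to (-1, -1) does not change the covered range; both span the whole sequence, and the new form simply matches the convention used by the new slot clear() method. A minimal sketch of the two calls, with hypothetical ctx, parent_id and child_id standing in for the slot fields:

// a negative position clamps to the start/end of the sequence,
// so (-1, -1) covers everything, same as (0, -1)
llama_memory_t mem = llama_get_memory(ctx);

llama_memory_seq_rm(mem, /*seq_id=*/child_id, /*p0=*/-1, /*p1=*/-1);                  // drop the child's old cells
llama_memory_seq_cp(mem, /*seq_id_src=*/parent_id, /*seq_id_dst=*/child_id, -1, -1); // duplicate the parent's cells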
@@ -745,6 +800,7 @@ struct server_context_impl {
}

slots.clear();

for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;

@@ -993,7 +1049,7 @@ struct server_context_impl {
ret->prompt_save(*prompt_cache);

if (!ret->prompt_load(*prompt_cache, task.tokens)) {
- clear_slot(*ret);
+ ret->clear(false);
}

prompt_cache->update();
@@ -1005,17 +1061,6 @@
return ret;
}

- void clear_slot(server_slot & slot, bool allow_processing = false) const {
- if (!allow_processing) {
- GGML_ASSERT(!slot.is_processing());
- }
-
- SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
-
- llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
- slot.prompt.tokens.clear();
- }

// return true if at least one slot has been cleared
// TODO: improve logic
// - smarter decision which slot to clear (LRU or longest prompt?)
@@ -1036,7 +1081,7 @@
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

- clear_slot(slot);
+ slot.clear(false);

res = true;

@@ -1182,7 +1227,7 @@ struct server_context_impl {
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");
SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());

return true;
}
@@ -1819,7 +1864,7 @@ struct server_context_impl {
// Erase token cache
const size_t n_erased = slot->prompt.tokens.size();

- clear_slot(*slot);
+ slot->clear(false);

auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
@@ -2053,8 +2098,29 @@ struct server_context_impl {
continue;
}

// check if this is a child slot
if (slot.state == SLOT_STATE_WAIT_OTHER) {
SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
continue;
}

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
// wait for all children to be launched
if (slot.is_parent()) {
int n_launched = 0;
for (auto & other : slots) {
if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
++n_launched;
}
}

if (n_launched < slot.task->n_children) {
SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
continue;
}
}

const auto & input_tokens = slot.task->tokens;

// TODO: maybe move branch to outside of this loop in the future
@@ -2355,7 +2421,7 @@ struct server_context_impl {
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);

- clear_slot(slot, /*allow_processing=*/true);
+ slot.clear(true);

// there is no common part left
slot.n_prompt_tokens_cache = 0;
@@ -2455,16 +2521,6 @@ struct server_context_impl {

GGML_ASSERT(batch.n_tokens > 0);

- common_sampler_reset(slot.smpl.get());
-
- // Process all prompt tokens through sampler system
- for (int i = 0; i < slot.task->n_tokens(); ++i) {
- llama_token id = input_tokens[i];
- if (id != LLAMA_TOKEN_NULL) {
- common_sampler_accept(slot.smpl.get(), id, false);
- }
- }

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;

@@ -2473,6 +2529,8 @@

SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);

+ slot.init_sampler();

const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

@@ -2519,11 +2577,6 @@ struct server_context_impl {
}
}

- if (batch.n_tokens == 0) {
- SRV_WRN("%s", "no tokens to decode\n");
- return;
- }

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

if (slot_batched) {
@@ -2540,6 +2593,10 @@ struct server_context_impl {
llama_set_embeddings(ctx, slot_batched->need_embd());
}

+ if (batch.n_tokens == 0) {
+ SRV_WRN("%s", "no tokens to decode\n");
+ }

int32_t i_next = 0;

// process the created batch of tokens
@@ -2591,7 +2648,7 @@ struct server_context_impl {

// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
- clear_slot(slot);
+ slot.clear(false);
}
}

@@ -2615,27 +2672,34 @@ struct server_context_impl {
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);

// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
- std::vector<server_slot *> child_slots;
+ SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
+
+ std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
- child_slots.push_back(&other);
+ children.push_back(&other);
}
}

// we can only proceed if all child slots have the correct tasks
- if (child_slots.size() == slot.task->n_children) {
+ if (slot.task->n_children == (int) children.size()) {
// copy state to the child slots
- for (auto & child : child_slots) {
- SLT_INF(slot, "copying state to child %d\n", child->id);
+ for (auto & child : children) {
+ SLT_INF(slot, " - copying state to child %d\n", child->id);

GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);

slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}
}

for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2720,7 +2784,7 @@ struct server_context_impl {
continue;
}

- size_t n_draft = slot.drafted.size();
+ const size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
@@ -2923,9 +2987,11 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = meta->model_name;

// prepare child tasks
if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
- for (size_t j = 0; j < task.n_children; j++) {
+ for (int j = 0; j < task.n_children; j++) {
server_task child = task.create_child(task.id, rd.get_new_id());

// use different sampling seed for each child
@@ -2938,7 +3004,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}
}

- tasks.push_back(std::move(task));
+ // note: the parent task always launches first
+ tasks.insert(tasks.begin(), std::move(task));
}

rd.post_tasks(std::move(tasks));
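Taken together, the server-context.cpp changes implement the n_cmpl > 1 fan-out: the parent slot processes the shared prompt once, each child slot parks in SLOT_STATE_WAIT_OTHER, and when the parent reaches SLOT_STATE_DONE_PROMPT its KV cells and sampler state are cloned into every child via copy_state_to(). Below is a condensed sketch of that hand-off using simplified stand-in types, not the server's actual structs:

#include <vector>

// simplified stand-ins for the server's slot machinery; illustrative only
enum mini_state { MINI_STARTED, MINI_DONE_PROMPT, MINI_WAIT_OTHER, MINI_GENERATING };

struct mini_slot {
    int        id         = -1;
    mini_state state      = MINI_STARTED;
    int        task_id    = -1;
    int        id_parent  = -1; // >= 0 marks a child
    int        n_children = 0;  // >  0 marks a parent

    bool is_parent() const { return n_children > 0; }
    bool is_child () const { return id_parent >= 0; }
};

// after each decode step: once a parent has finished its prompt, activate all
// children that share it; in the real server, copy_state_to() clones the
// KV cells (llama_memory_seq_cp) and replays the sampler (init_sampler())
static void activate_children(std::vector<mini_slot> & slots) {
    for (auto & parent : slots) {
        if (parent.state != MINI_DONE_PROMPT || !parent.is_parent()) {
            continue;
        }

        std::vector<mini_slot *> children;
        for (auto & other : slots) {
            if (other.state == MINI_WAIT_OTHER && other.id_parent == parent.task_id) {
                children.push_back(&other);
            }
        }

        // proceed only when every child slot has picked up its task
        if ((int) children.size() == parent.n_children) {
            for (auto * child : children) {
                child->state = MINI_DONE_PROMPT; // from here on, all generate in parallel
            }
        }
    }
}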
6 changes: 4 additions & 2 deletions tools/server/server-task.h
@@ -121,8 +121,8 @@ struct server_task {
int id_slot = -1;

// used by parallel sampling (multiple completions from same prompt)
- size_t n_children = 0; // number of tasks reusing this prompt
- int id_parent = -1;
+ int n_children = 0; // number of tasks reusing this prompt
+ int id_parent = -1;

// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
@@ -173,11 +173,13 @@ struct server_task {

server_task create_child(int id_parent, int id_child) const {
server_task copy;

copy.id = id_child;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();

return copy;
}

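Finally, a small usage sketch of create_child as the completions handler drives it: the parent fans out n_cmpl - 1 children that clone its prompt tokens and params, each child receives a fresh task id (and, per the handler above, its own sampling seed), and the parent is queued first so the children can attach to its prompt state. The make_task_id helper and the ids here are hypothetical:

#include <utility>
#include <vector>

// hypothetical id allocator standing in for the server's request dispatcher
static int make_task_id() { static int next_id = 100; return next_id++; }

static std::vector<server_task> build_tasks(server_task task, int n_cmpl) {
    std::vector<server_task> tasks;

    if (n_cmpl > 1) {
        task.n_children = n_cmpl - 1;

        for (int j = 0; j < task.n_children; j++) {
            // each child gets its own id, the parent's id as id_parent,
            // and a copy of params/type plus a clone of the prompt tokens
            server_task child = task.create_child(task.id, make_task_id());
            tasks.push_back(std::move(child));
        }
    }

    // the parent task always launches first
    tasks.insert(tasks.begin(), std::move(task));
    return tasks;
}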