Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,26 @@ struct server_slot {
alora_invocation_start = -1;
}

void init_sampler() const {
const int64_t t_start = ggml_time_us();

common_sampler_reset(smpl.get());

int n_text = 0;

for (int i = 0; i < (int) prompt.tokens.size(); i++) {
const llama_token id = prompt.tokens[i];

if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(smpl.get(), id, false);
n_text++;
}
}

SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}

bool need_embd() const {
GGML_ASSERT(task);

Expand Down Expand Up @@ -288,11 +308,11 @@ struct server_slot {

// note: a slot can also be either a parent or a child
bool is_parent() const {
return is_processing() && task->n_children > 0;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `is_processing()` check also seemed redundant, so I removed it.

return task->n_children > 0;
}

// Reports whether this slot's task has a parent task, i.e. whether task->id_parent >= 0.
// NOTE(review): this span comes from a diff view and shows two return statements;
// presumably the first one (with the is_processing() check) is the removed line and the
// second one is its replacement - confirm against the merged file.
bool is_child() const {
return is_processing() && task->id_parent >= 0;
return task->id_parent >= 0;
}

void release() {
Expand Down Expand Up @@ -425,14 +445,22 @@ struct server_slot {
}

// Copies this slot's state into `other`: the memory sequence (llama_memory_seq_rm followed
// by llama_memory_seq_cp), the decode counters, the prompt-processing timings and token
// counts, and a clone of the prompt. Finally calls init_sampler() on `other`.
// NOTE(review): this span comes from a diff view, so the llama_memory_seq_rm call appears
// both before and after the assert; presumably the first is the removed line and the
// second is its replacement - confirm against the merged file.
void copy_state_to(server_slot & other) const {
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
// this slot must be in the DONE_PROMPT state before its state is copied
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);

// remove other's sequence over the range (0, -1), then copy this slot's sequence
// (id -> other.id) over the same range
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);

// decode counters and batch index
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;

// prompt-processing timings and token counts
other.t_start_process_prompt = t_start_process_prompt;
other.t_prompt_processing = t_prompt_processing;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;

// give `other` its own clone of the prompt, then rebuild its sampler via init_sampler()
other.prompt = prompt.clone();
other.init_sampler();
}
};

Expand Down Expand Up @@ -1182,7 +1210,7 @@ struct server_context_impl {
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");
SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());

return true;
}
Expand Down Expand Up @@ -2053,6 +2081,12 @@ struct server_context_impl {
continue;
}

// check if this is a child slot
if (slot.state == SLOT_STATE_WAIT_OTHER) {
SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
continue;
}

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
const auto & input_tokens = slot.task->tokens;
Expand Down Expand Up @@ -2455,16 +2489,6 @@ struct server_context_impl {

GGML_ASSERT(batch.n_tokens > 0);

common_sampler_reset(slot.smpl.get());

// Process all prompt tokens through sampler system
for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl.get(), id, false);
}
}

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;

Expand All @@ -2473,6 +2497,8 @@ struct server_context_impl {

SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);

slot.init_sampler();

const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

Expand Down Expand Up @@ -2519,11 +2545,6 @@ struct server_context_impl {
}
}

if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

if (slot_batched) {
Expand All @@ -2540,6 +2561,10 @@ struct server_context_impl {
llama_set_embeddings(ctx, slot_batched->need_embd());
}

if (batch.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
}

int32_t i_next = 0;

// process the created batch of tokens
Expand Down Expand Up @@ -2615,27 +2640,34 @@ struct server_context_impl {
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);

// handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
std::vector<server_slot *> child_slots;
SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);

std::vector<server_slot *> children;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
child_slots.push_back(&other);
children.push_back(&other);
}
}

// we can only proceed if all child slots have the correct tasks
if (child_slots.size() == slot.task->n_children) {
if (slot.task->n_children == (int) children.size()) {
// copy state to the child slots
for (auto & child : child_slots) {
SLT_INF(slot, "copying state to child %d\n", child->id);
for (auto & child : children) {
SLT_INF(slot, " - copying state to child %d\n", child->id);

GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);

slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}
}

for (auto & slot : slots) {
// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
Expand Down Expand Up @@ -2720,7 +2752,7 @@ struct server_context_impl {
continue;
}

size_t n_draft = slot.drafted.size();
const size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
Expand Down Expand Up @@ -2925,7 +2957,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(

if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
for (int j = 0; j < task.n_children; j++) {
server_task child = task.create_child(task.id, rd.get_new_id());

// use different sampling seed for each child
Expand Down
4 changes: 2 additions & 2 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ struct server_task {
int id_slot = -1;

// used by parallel sampling (multiple completions from same prompt)
size_t n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;
int n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;

// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
Expand Down
Loading