Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2551,6 +2551,10 @@ struct server_context_impl {
// Report the n_ctx value stored on the last entry of `slots`.
// NOTE(review): slots.back() is called unconditionally - presumably `slots` is never empty here; confirm.
int get_slot_n_ctx() {
    const auto & last_slot = slots.back();
    return last_slot.n_ctx;
}

// Create a response reader bound to this context's queue_tasks and queue_results,
// with HTTP_POLLING_SECONDS passed as the reader's polling interval (in seconds).
// NOTE(review): the reader's destructor calls stop(), so this presumably relies on the
// return value being constructed in place (no intermediate copy) - confirm before restructuring.
server_response_reader get_response_reader() {
return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
}
};

//
Expand Down Expand Up @@ -2580,8 +2584,8 @@ llama_context * server_context::get_llama_context() const {
return impl->ctx;
}

std::pair<server_queue &, server_response &> server_context::get_queues() {
return { impl->queue_tasks, impl->queue_results };
// Public accessor: forwards to get_response_reader() on the implementation object (`impl`)
// and returns the reader it creates.
server_response_reader server_context::get_response_reader() {
return impl->get_response_reader();
}


Expand All @@ -2590,7 +2594,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
struct server_res_generator : server_http_res {
server_response_reader rd;
server_res_generator(server_context_impl & ctx_server)
: rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
void ok(const json & response_data) {
status = 200;
data = safe_json_to_str(response_data);
Expand Down Expand Up @@ -2623,9 +2627,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
try {
std::vector<server_task> tasks;

// tracking generation state and partial tool calls
std::vector<task_result_state> states;

const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
Expand All @@ -2641,7 +2642,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
Expand All @@ -2660,7 +2660,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);

if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
Expand All @@ -2669,15 +2668,13 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.id,
ctx_server.queue_tasks.get_new_id(),
idx++);
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
}
}

tasks.push_back(std::move(task));
}

rd.set_states(std::move(states));

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

states looks unused now - shouldn't it be removed?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I forgot - it's removed in e25bf4b

rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
Expand Down Expand Up @@ -3407,7 +3404,7 @@ void server_routes::init_routes() {

// create and queue the task
json responses = json::array();
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
server_response_reader rd = ctx_server.get_response_reader();
{
std::vector<server_task> tasks;
tasks.reserve(documents.size());
Expand Down Expand Up @@ -3667,7 +3664,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons

// create and queue the task
json responses = json::array();
server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
server_response_reader rd = ctx_server.get_response_reader();
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
Expand Down
5 changes: 2 additions & 3 deletions tools/server/server-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ struct server_context {
// get the underlying llama_context
llama_context * get_llama_context() const;

// get the underlying queue_tasks and queue_results
// used by CLI application
std::pair<server_queue &, server_response &> get_queues();
// get a new response reader, used by CLI application
server_response_reader get_response_reader();
};


Expand Down
13 changes: 11 additions & 2 deletions tools/server/server-queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,21 @@ void server_response::terminate() {
// server_response_reader
//

void server_response_reader::set_states(std::vector<task_result_state> && states) {
this->states = std::move(states);
// Submit a single task through this reader.
//
// The reader records the task id and a freshly created state for the task,
// registers the id with the result queue as a waiting task, and finally
// moves the task into the task queue. Asserts that no task ids have been
// recorded on this reader yet.
void server_response_reader::post_task(server_task && task) {
    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");

    // read the id up front: `task` is moved away at the end of this function
    const int task_id = task.id;

    id_tasks.insert(task_id);
    states.push_back(task.create_state());

    queue_results.add_waiting_task_id(task_id);
    queue_tasks.post(std::move(task));
}

void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
id_tasks = server_task::get_list_id(tasks);
states.reserve(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
states.push_back(tasks[i].create_state());
}
queue_results.add_waiting_tasks(tasks);
queue_tasks.post(std::move(tasks));
}
Expand Down
6 changes: 3 additions & 3 deletions tools/server/server-queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ struct server_response_reader {
std::vector<task_result_state> states;

// should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
// Construct a reader from a task queue, a result queue and a polling interval in seconds;
// the three arguments initialize the members of the same names.
server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
: queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
// Destroying a reader always calls stop().
~server_response_reader() {
stop();
}

void set_states(std::vector<task_result_state> && states);
void post_task(server_task && tasks);
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;

Expand Down
44 changes: 25 additions & 19 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ struct task_params {
json to_json(bool only_metrics = false) const;
};

// state associated with one task while its results are being produced (e.g., for streaming)
struct task_result_state {
    // diffs tracked for partial tool calls
    std::vector<common_chat_msg_diff> diffs;
    common_chat_syntax oaicompat_chat_syntax;
    common_chat_msg chat_msg;
    std::string generated_text; // new chunks of generated text get appended here
    std::vector<std::string> generated_tool_call_ids;

    // only the chat syntax is taken from the caller; all other members start default-initialized
    task_result_state(const common_chat_syntax & syntax)
        : oaicompat_chat_syntax(syntax) {}

    // parses partial tool calls and updates the internal state
    common_chat_msg update_chat_msg(
            const std::string & text_added,
            bool is_partial,
            std::vector<common_chat_msg_diff> & diffs);
};

struct server_task {
int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
Expand Down Expand Up @@ -146,6 +165,12 @@ struct server_task {
copy.tokens = tokens.clone();
return copy;
}

// The task itself gets moved into the queue and later onto a slot, whereas
// its state has to stay with the caller (e.g., the HTTP thread).
// Builds that state from this task's oaicompat_chat_syntax parameter.
task_result_state create_state() const {
    const auto & syntax = params.oaicompat_chat_syntax;
    return task_result_state(syntax);
}
};

struct result_timings {
Expand Down Expand Up @@ -177,25 +202,6 @@ struct result_prompt_progress {
json to_json() const;
};

// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;

// initializes oaicompat_chat_syntax from the argument; the other members are default-initialized
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}

// parse partial tool calls and update the internal state
// NOTE(review): `diffs` is a non-const reference, presumably an output parameter - confirm against the definition
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};

struct server_task_result {
int id = -1;
int id_slot = -1;
Expand Down
Loading