Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ struct server_context_impl {
return true;
}

void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress, bool is_begin = false) {
auto res = std::make_unique<server_task_result_cmpl_partial>();

res->id = slot.task->id;
Expand All @@ -1746,6 +1746,9 @@ struct server_context_impl {
res->progress.cache = slot.n_prompt_tokens_cache;
res->progress.processed = slot.prompt.tokens.size();
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
}
if (is_begin) {
res->is_begin = true;
} else {
res->content = tkn.text_to_send;
res->tokens = { tkn.tok };
Expand Down Expand Up @@ -2828,10 +2831,15 @@ struct server_context_impl {

slot.prompt.tokens.keep_first(n_past);

// send initial 0% progress update if needed
// this is to signal the client that the request has started processing
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
if (slot.task->params.stream) {
if (slot.task->params.return_progress) {
// send initial 0% progress update if needed
send_partial_response(slot, {}, true);
} else {
// otherwise, for streaming without progress, signal HTTP to send the headers (i.e. 200 status)
send_partial_response(slot, {}, false, true);
}
}
}

Expand Down Expand Up @@ -3745,7 +3753,9 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// next responses are streamed
// to be sent immediately
json first_result_json = first_result->to_json();
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
if (first_result_json == nullptr) {
res->data = ""; // simply send HTTP headers and status code
} else if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
res->data = format_anthropic_sse(first_result_json);
} else if (res_type == TASK_RESPONSE_TYPE_OAI_RESP) {
res->data = format_oai_resp_sse(first_result_json);
Expand Down
3 changes: 3 additions & 0 deletions tools/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,9 @@ void server_task_result_cmpl_partial::update(task_result_state & state) {

json server_task_result_cmpl_partial::to_json() {
GGML_ASSERT(is_updated && "update() must be called before to_json()");
if (is_begin) {
return nullptr; // simply signal to HTTP handler to send the headers and status code
}
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
Expand Down
4 changes: 3 additions & 1 deletion tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ enum stop_type {
};

struct task_params {
bool stream = true;
bool stream = false;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
Expand Down Expand Up @@ -418,6 +418,8 @@ struct server_task_result_cmpl_partial : server_task_result {

bool post_sampling_probs;
bool is_progress = false;
bool is_begin = false; // whether to send 200 status to HTTP client (begin of SSE stream)
// ref: https://github.com/ggml-org/llama.cpp/pull/23884
completion_token_output prob_output;
result_timings timings;
result_prompt_progress progress;
Expand Down
Loading