Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3031,6 +3031,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.timeout_write = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT"));
// Registers the "--sse-keepalive-interval N" option.
// The handler copies the parsed integer into params.keepalive_interval unchanged
// (no range check is performed here). Per the help text, N is the number of
// seconds between SSE keepalive comments while streaming and 0 disables them;
// the default shown in the help is whatever params.keepalive_interval holds when
// the parser is initialized.
// The option is tagged for LLAMA_EXAMPLE_SERVER only and is associated with the
// LLAMA_ARG_SSE_KEEPALIVE_INTERVAL environment variable name.
add_opt(common_arg(
{"--sse-keepalive-interval"}, "N",
string_format("interval in seconds between SSE keepalive comments during streaming (default: %d; 0 = disabled)", params.keepalive_interval),
[](common_params & params, int value) {
params.keepalive_interval = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSE_KEEPALIVE_INTERVAL"));
add_opt(common_arg(
{"--threads-http"}, "N",
string_format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http),
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ struct common_params {
bool reuse_port = false; // allow multiple sockets to bind to the same port
int32_t timeout_read = 3600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t keepalive_interval = 0; // sse keepalive interval in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
bool cache_prompt = true; // whether to enable prompt caching
Expand Down
13 changes: 11 additions & 2 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3731,6 +3731,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}
}
} else {
int32_t keepalive_interval = params.keepalive_interval;

// in streaming mode, the first error must be treated as non-stream response
// this is to match the OAI API behavior
// ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
Expand Down Expand Up @@ -3764,7 +3766,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool {
res->next = [res_this = res.get(), res_type, &req, keepalive_interval](std::string & output) -> bool {
static auto format_error = [](task_response_type res_type, const json & res_json) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
return format_anthropic_sse({
Expand Down Expand Up @@ -3809,13 +3811,20 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
}

// receive subsequent results
auto result = rd.next(req.should_stop);
auto result = rd.next(req.should_stop, keepalive_interval);
if (result == nullptr) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
GGML_ASSERT(req.should_stop());
return false; // should_stop condition met
}

// send keepalive as a comment (line starting with a COLON character)
// see 9.2.6 of https://html.spec.whatwg.org/multipage/server-sent-events.html
if (dynamic_cast<server_task_result_keepalive*>(result.get()) != nullptr) {
output = ": keepalive\n\n";
return true;
}

// send the results
if (result->is_error()) {
json res_json = result->to_json();
Expand Down
8 changes: 7 additions & 1 deletion tools/server/server-queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ bool server_response_reader::has_next() const {

// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop, int keepalive_interval_seconds) {
int64_t time_last_keepalive_msg_ms = ggml_time_ms();
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
if (result == nullptr) {
Expand All @@ -387,6 +388,11 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
}
return nullptr;
}
// check if keepalive message needs to be sent
if (keepalive_interval_seconds > 0 && ggml_time_ms() > time_last_keepalive_msg_ms + keepalive_interval_seconds * 1000 ) {
time_last_keepalive_msg_ms = ggml_time_ms();
return std::make_unique<server_task_result_keepalive>();
}
} else {
if (result->is_error()) {
stop(); // cancel remaining tasks
Expand Down
2 changes: 1 addition & 1 deletion tools/server/server-queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ struct server_response_reader {

// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop);
server_task_result_ptr next(const std::function<bool()> & should_stop, int keepalive_interval_seconds = 0);

struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
Expand Down
8 changes: 8 additions & 0 deletions tools/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1978,6 +1978,14 @@ json server_task_result_apply_lora::to_json() {
return json {{ "success", true }};
}

//
// server_task_result_keepalive
//

// A keepalive result has nothing to serialize, so its JSON form is a null value.
json server_task_result_keepalive::to_json() {
    return json(nullptr);
}

//
// server_prompt_cache
//
Expand Down
5 changes: 5 additions & 0 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,11 @@ struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};

// Placeholder result used to signal "send an SSE keepalive now".
// server_response_reader::next() creates one when no task result was received
// and the configured keepalive interval has elapsed; the streaming handler
// detects it with dynamic_cast and writes a ": keepalive" SSE comment instead of
// serializing it. It declares no data members of its own, and to_json() yields
// a JSON null.
struct server_task_result_keepalive : server_task_result {
virtual json to_json() override;
};


struct server_prompt_data {
std::vector<uint8_t> main;
std::vector<uint8_t> drft;
Expand Down