diff --git a/common/arg.cpp b/common/arg.cpp index e0f6c606608..cf82d2bc07b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3031,6 +3031,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.timeout_write = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT")); + add_opt(common_arg( + {"--sse-keepalive-interval"}, "N", + string_format("interval in seconds between SSE keepalive comments during streaming (default: %d; 0 = disabled)", params.keepalive_interval), + [](common_params & params, int value) { + params.keepalive_interval = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSE_KEEPALIVE_INTERVAL")); add_opt(common_arg( {"--threads-http"}, "N", string_format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http), diff --git a/common/common.h b/common/common.h index 99898800d1d..11d49c6fbd2 100644 --- a/common/common.h +++ b/common/common.h @@ -590,6 +590,7 @@ struct common_params { bool reuse_port = false; // allow multiple sockets to bind to the same port int32_t timeout_read = 3600; // http read timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t keepalive_interval = 0; // sse keepalive interval in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting bool cache_prompt = true; // whether to enable prompt caching diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index bfe3443c1de..5f0fa982b87 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3731,6 +3731,8 @@ std::unique_ptr server_routes::handle_completions_impl( } } } else { + int32_t keepalive_interval = params.keepalive_interval; + // in streaming mode, the first error must be treated as non-stream response // this is to match the OAI 
API behavior // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 @@ -3764,7 +3766,7 @@ std::unique_ptr server_routes::handle_completions_impl( } res->status = 200; res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool { + res->next = [res_this = res.get(), res_type, &req, keepalive_interval](std::string & output) -> bool { static auto format_error = [](task_response_type res_type, const json & res_json) { if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { return format_anthropic_sse({ @@ -3809,13 +3811,20 @@ std::unique_ptr server_routes::handle_completions_impl( } // receive subsequent results - auto result = rd.next(req.should_stop); + auto result = rd.next(req.should_stop, keepalive_interval); if (result == nullptr) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); GGML_ASSERT(req.should_stop()); return false; // should_stop condition met } + // send keepalive as a comment (line starting with a COLON character) + // see 9.2.6 of https://html.spec.whatwg.org/multipage/server-sent-events.html + if (dynamic_cast<server_task_result_keepalive *>(result.get()) != nullptr) { + output = ": keepalive\n\n"; + return true; + } + // send the results if (result->is_error()) { json res_json = result->to_json(); diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 588e1a82b18..3b88122c07d 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -375,7 +375,8 @@ bool server_response_reader::has_next() const { // return nullptr if should_stop() is true before receiving a result // note: if one error is received, it will stop further processing and return error result -server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) { +server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop, int keepalive_interval_seconds) { + int64_t time_last_keepalive_msg_ms = ggml_time_ms(); while (true) { 
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); if (result == nullptr) { @@ -387,6 +388,11 @@ server_task_result_ptr server_response_reader::next(const std::function } return nullptr; } + // check if keepalive message needs to be sent + if (keepalive_interval_seconds > 0 && ggml_time_ms() > time_last_keepalive_msg_ms + keepalive_interval_seconds * 1000 ) { + time_last_keepalive_msg_ms = ggml_time_ms(); + return std::make_unique<server_task_result_keepalive>(); + } } else { if (result->is_error()) { stop(); // cancel remaining tasks diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 8ce32c69fb0..250d132c418 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -193,7 +193,7 @@ struct server_response_reader { // return nullptr if should_stop() is true before receiving a result // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function<bool()> & should_stop); + server_task_result_ptr next(const std::function<bool()> & should_stop, int keepalive_interval_seconds = 0); struct batch_response { bool is_terminated = false; // if true, indicates that processing was stopped before all results were received diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index ff80be6ccba..6ee04362b22 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1978,6 +1978,14 @@ json server_task_result_apply_lora::to_json() { return json {{ "success", true }}; } +// +// server_task_result_keepalive +// + +json server_task_result_keepalive::to_json() { + return nullptr; +} + // // server_prompt_cache // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index d47dc690cff..0d794725cbd 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -566,6 +566,11 @@ struct server_task_result_apply_lora : server_task_result { virtual json to_json() override; }; +struct 
server_task_result_keepalive : server_task_result { + virtual json to_json() override; +}; + + struct server_prompt_data { std::vector main; std::vector drft;